范围:余弦相似度的值范围从 -1 到 1。
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Two example vectors, each stored as a single-row 2D array (shape (1, 3)).
vector_a = np.array([[1, 2, 3]])
vector_b = np.array([[4, 5, 6]])

# Compute the cosine similarity; with one row on each side the result
# holds a single entry, which is read back with [0, 0].
similarity = cosine_similarity(vector_a, vector_b)
print("余弦相似度:", similarity[0, 0])
下面我们从可视化的角度,直观地感受一下余弦相似度(Cosine Similarity)。
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Three 2D example vectors to draw and compare.
A = [1, 4]
B = [3, 3]
C = [2, 3]

# One colour per vector so the arrows can be told apart visually.
colors = ['r', 'g', 'b']

# V stacks A, B and C into one NumPy array, one vector per row.
V = np.array([A, B, C])

# The arrow start points are a 2x3 block of zeros, so that every
# vector starts at the point (0, 0).
origin = np.zeros((2, 3), dtype=int)  # origin point

plt.figure(figsize=(9, 9))

# Draw the vectors with plt.quiver, starting from the origin.
# scale_units='xy' together with scale=1 keeps the arrows at a 1:1
# ratio in the x and y directions.
plt.quiver(
    *origin, V[:, 0], V[:, 1],
    color=colors, angles='xy', scale_units='xy', scale=1,
)

# Put each vector's letter at the midpoint of its arrow, in the same colour.
for name, tip_x, tip_y, shade in zip('ABC', V[:, 0], V[:, 1], colors):
    plt.text(
        tip_x / 2., tip_y / 2., name,
        fontdict={'size': 20, 'color': shade, 'weight': 'bold'},
        verticalalignment='top',
    )

plt.xlim(0, 5)
plt.ylim(0, 5)
plt.grid()
plt.show()
我们关注的是向量所指的方向。两个向量之间的夹角衡量了它们方向上的差异,因此我们可以使用这个夹角的余弦值:
from mpl_toolkits.axes_grid1 import make_axes_locatable
from sklearn.metrics.pairwise import cosine_similarity

# Pairwise cosine similarities between the vectors A, B and C.
sims = cosine_similarity(X=[A, B, C])

# The heatmap helpers that follow are adapted from
# https://matplotlib.org/3.1.1/gallery/images_contours_and_fields/image_annotated_heatmap.html
def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel="", **kwargs):
    """
    Create a heatmap from a numpy array and two lists of labels.

    Parameters
    ----------
    data
        A 2D numpy array of shape (N, M).
    row_labels
        A list or array of length N with the labels for the rows.
    col_labels
        A list or array of length M with the labels for the columns.
    ax
        A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If
        not provided, the current axes (`plt.gca()`) are used. Optional.
    cbar_kw
        Accepted for signature compatibility only; this function does not
        create a colorbar, so the value is not used. Optional.
    cbarlabel
        Accepted for signature compatibility only; not used. Optional.
    **kwargs
        All other arguments are forwarded to `imshow`.

    Returns
    -------
    The object returned by `ax.imshow`. No colorbar is created here; the
    caller adds one if it wants it.
    """
    # Compare against None explicitly instead of relying on truthiness.
    if ax is None:
        ax = plt.gca()

    # Plot the heatmap.
    im = ax.imshow(data, **kwargs)

    n_rows = data.shape[0]
    n_cols = data.shape[1]
    label_font = {'size': 20, 'weight': 'bold'}

    # Show one major tick per cell ...
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    # ... and label the ticks with the respective list entries.
    ax.set_xticklabels(col_labels, fontdict=label_font)
    ax.set_yticklabels(row_labels, fontdict=label_font)

    # Let the horizontal axis labels appear on top.
    ax.tick_params(top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # Minor ticks sit half a cell off the major ticks, i.e. on the cell
    # borders, and the grid is drawn only on those minor ticks.
    ax.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - .5, minor=True)
    ax.grid(which="minor", linestyle='-', linewidth=1)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im
def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None, **textkw):
    """
    Annotate a heatmap by writing each cell's value into the cell.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data
        Data used to annotate, as a list or numpy array. If anything else
        (e.g. None) is given, the image's own data is used. Optional.
    valfmt
        The format of the annotations inside the heatmap. This should either
        use the string format method, e.g. "$ {x:.2f}", or be a
        `matplotlib.ticker.Formatter` (any callable taking ``(value, pos)``).
        Optional.
    textcolors
        A sequence of two color specifications. The first is used for
        values below a threshold, the second for those above. Optional.
    threshold
        Value in data units according to which the colors from textcolors are
        applied. If None (the default), half of the normalized data maximum
        is used as the separation. Optional.
    **textkw
        All other arguments are forwarded to each call to `text` used to
        create the text labels.

    Returns
    -------
    A list with the text object created for every cell, in row-major order.
    """
    if isinstance(data, list):
        # A plain list has neither .max() nor .shape, both of which are
        # needed below, so turn it into an array first.
        data = np.asarray(data)
    elif not isinstance(data, np.ndarray):
        data = im.get_array()

    # Normalize the threshold to the image's color range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max()) / 2.

    # Set default alignment to center, but allow it to be
    # overwritten by textkw.
    kw = dict(horizontalalignment="center",
              verticalalignment="center")
    kw.update(textkw)

    # Build the formatter in case a format string was supplied.
    if isinstance(valfmt, str):
        # Import the submodule explicitly rather than relying on another
        # import having loaded `matplotlib.ticker` as a side effect.
        from matplotlib.ticker import StrMethodFormatter
        valfmt = StrMethodFormatter(valfmt)

    # Loop over the data and create a text label for each "pixel".
    # The text color depends on which side of the threshold the value is.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            above = int(im.norm(data[i, j]) > threshold)
            kw["color"] = textcolors[above]
            # imshow puts columns on x and rows on y, hence (j, i).
            texts.append(im.axes.text(j, i, valfmt(data[i, j], None), **kw))

    return texts
heatmap 函数根据一个 2D numpy 数组以及行、列两个标签列表创建热图。
data:形状为 (N, M) 的 2D numpy 数组,其中 N 是行数,M 是列数。row_labels 和 col_labels 分别是长度为 N 和 M 的行标签与列标签列表。
热图本身由 ax.imshow(data, **kwargs) 绘制。
fig, ax = plt.subplots(figsize=(9, 9))

# Draw the similarity matrix, using the letters A, B, C as both row and
# column labels and a fixed 0..1 colour range.
im = heatmap(
    sims,
    list('ABC'),
    list('ABC'),
    ax=ax,
    cmap="Greens",
    cbarlabel="Cosine Similarity",
    vmin=0.0,
    vmax=1.0,
)

# Write each similarity value into its cell with two decimals.
texts = annotate_heatmap(im, valfmt="{x:.2f}")

# Add a separate axes on the right of the heatmap and draw the colorbar in it.
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(im, cax=cax)

fig.tight_layout()
plt.show()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。