赞
踩
在进行讲解之前,先放出数据集:
链接:https://pan.baidu.com/s/1y2hUqXm5-LsLpDqrMK3bIQ
提取码:6666
K-means
算法为一种无监督学习算法,可以实现算法自动区分数据集中不同簇。
无监督学习区别于监督学习算法,无监督学习中没有标签y
,算法需要根据输入的数据集直接将其进行区分为各簇,每簇数据有其聚类中心。
K-means 算法是一种自动将相似数据示例聚类在一起的方法。具体来说,给你一个训练集 x ( 1 ) ; ⋯ ; x ( m ) x^{(1)};\cdots; x^{(m)} x(1);⋯;x(m) (其中 x ( i ) ∈ R n x^{(i) }\in R^n x(i)∈Rn),并且想要将数据分组到几个有凝聚力的“簇”中。
K-means 背后是一个迭代过程,首先猜测初始聚类中心,然后通过重复将示例分配给它们最接近的聚类中心,然后根据分配重新计算聚类中心来重新猜测。
K-means算法的伪代码如下:
算法步骤总结如下:
1.随机选取k个聚类中心点
2.遍历所有数据,将数据划分到最近的那个聚类中心点
3.计算所有类的平均值,作为新的聚类中心点
4.重复步骤2和步骤3,直到聚类中心点不再发生变化,或者达到设定的迭代次数
计算每个样本点到聚类中心的距离为:
计算新的聚类中心点:
导入运算包:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.io import loadmat
import matplotlib as mpl
导入需要的数据集
def load_dataset():
path='./data/ex7data2.mat'
# 字典格式 : <class 'dict'>
data=loadmat(path)
# data.keys() : dict_keys(['__header__', '__version__', '__globals__', 'X'])
dataset = pd.DataFrame(data.get('X'), columns=['X1', 'X2'])
return data,dataset
如同上面所讲,我们此时导入的数据是字典格式,关于如何使用Python导入mat格式的数据并整理成DataFrame格式的问题,我在之前的文章中也有讲解,欢迎大家学习:
https://blog.csdn.net/wzk4869/article/details/126018725?spm=1001.2014.3001.5501
我们可以简单看一下输出结果:
首先是data
{'__header__': b'MATLAB 5.0 MAT-file, Platform: GLNXA64, Created on: Wed Nov 16 00:48:22 2011', '__version__': '1.0', '__globals__': [], 'X': array([[ 1.84207953, 4.6075716 ],
[ 5.65858312, 4.79996405],
[ 6.35257892, 3.2908545 ],
[ 2.90401653, 4.61220411],
[ 3.23197916, 4.93989405],
[ 1.24792268, 4.93267846],
......中间省略部分数据集
[ 4.8255341 , 2.77961664],
[ 6.11768055, 2.85475655],
[ 0.94048944, 5.71556802]])}
接下来是dataset
X1 X2
0 1.842080 4.607572
1 5.658583 4.799964
2 6.352579 3.290854
3 2.904017 4.612204
4 3.231979 4.939894
.. ... ...
295 7.302787 3.380160
296 6.991984 2.987067
297 4.825534 2.779617
298 6.117681 2.854757
299 0.940489 5.715568
数据可视化
拿到数据集后,我们可以可视化一下,看一下数据集的样子:
def plot_scatter():
data,dataset=load_dataset()
plt.figure(figsize=(12,8))
plt.scatter(dataset['X1'],dataset['X2'],cmap=['b'])
plt.show()
寻找数据点的最近的聚类中心
我们这一步的目标是定义函数,用于寻找数据点的最近的聚类中心。
输入参数为数据集X、初始聚类中心,返回一个一维数组,其长度与X的数据点个数相同,每个索引对应的值为该点对应的聚类中心。
def get_near_cluster_centroids(X,centroids): """ :param X: 我们的数据集 :param centroids: 聚类中心点的初始位置 :return: """ m = X.shape[0] #数据的行数 k = centroids.shape[0] #聚类中心的行数,即个数 idx = np.zeros(m) # 一维向量idx,大小为数据集中的点的个数,用于保存每一个X的数据点最小距离点的是哪个聚类中心 for i in range(m): min_distance = 1000000 for j in range(k): distance = np.sum((X[i, :] - centroids[j, :]) ** 2) # 计算数据点到聚类中心距离代价的公式,X中每个点都要和每个聚类中心计算 if distance < min_distance: min_distance = distance idx[i] = j # idx中索引为i,表示第i个X数据集中的数据点距离最近的聚类中心的索引 return idx # 返回的是X数据集中每个数据点距离最近的聚类中心
编写计算聚类中心函数
传入数据集X、聚类中心索引idx、聚类中心个数k。
算法根据idx将X分为k个组,每个组有多个X内的数据集,计算其平均值,得到求平均值之后的聚类中心。返回值为求一次平均值之后的聚类中心。
本函数需要传入idx数组,故不能用于直接初始化聚类中心,只起到优化聚类中心的作用。
def compute_centroids(X, idx, k):
"""
:param X: 数据集
:param idx: 每个样本所属的类别
:param k: 类别总数
:return:
"""
m, n = X.shape
centroids = np.zeros((k, n)) # 初始化为k行n列的二维数组,值均为0,k为聚类中心个数,n为数据列数
for i in range(k):
indices = np.where(idx == i) # 输出的是索引位置
centroids[i, :] = (np.sum(X[indices, :], axis=1) / len(indices[0])).ravel()
return centroids
K-means迭代算法
传入参数数据集X、初始化后的聚类中心、迭代次数。
每次迭代使用之前的两个函数,先计算出idx索引数组,再计算出聚类中心。
返回迭代多次后的idx和cendroids。
def k_means(X, initial_centroids, max_iters): """ :param X: :param initial_centroids: 初始聚类中心 :param max_iters: 迭代次数 :return: """ m, n = X.shape k = initial_centroids.shape[0] idx = np.zeros(m) centroids_all = [] centroids_all.append(initial_centroids) centroids = initial_centroids for i in range(max_iters): idx = get_near_cluster_centroids(X, centroids) centroids = compute_centroids(X, idx, k) centroids_all.append(centroids) return idx, np.array(centroids_all),centroids
首先,我们初始化centroids
为:
centroids_0=np.array([[3, 3], [6, 2], [8, 5]])
接着我们画出聚类图像:
def plot_classify_data(X,idx):
cluster1 = X[np.where(idx == 0)[0], :]
cluster2 = X[np.where(idx == 1)[0], :]
cluster3 = X[np.where(idx == 2)[0], :]
fig, ax = plt.subplots(figsize=(12, 8))
ax.scatter(cluster1[:, 0], cluster1[:, 1], s=30, color='r', label='Cluster 1')
ax.scatter(cluster2[:, 0], cluster2[:, 1], s=30, color='g', label='Cluster 2')
ax.scatter(cluster3[:, 0], cluster3[:, 1], s=30, color='b', label='Cluster 3')
ax.legend()
plt.show()
idx, centroids_all,centroids=k_means(X,centroids_0,10)
plot_classify_data(X,idx)
结果为:
我们想看一下聚类中心的变化情况:
def plot_data(centroids_all,idx):
plt.figure(figsize=(12,8))
cm=mpl.colors.ListedColormap(['r','g','b'])
plt.scatter(dataset['X1'],dataset['X2'], c=idx,cmap=cm)
plt.plot(centroids_all[:,:,0],centroids_all[:,:,1], 'kx--')
plt.show()
结果为:
综上所述,结果清晰明了。
在以上操作中,使用的初始聚类中心均为自己找的聚类中心,直接以数组形式进行定义。
实际上,聚类中心应在数据集X中随机选择k个样本带你作为聚类中心。
def init_centroids(X, k):
m, n = X.shape
init_centroids = np.zeros((k, n))
idx = np.random.randint(0, m, k)
for i in range(k):
init_centroids[i, :] = X[idx[i], :]
return init_centroids
我们展示一下结果:
print(init_centroids(X, 3))
[[3.12635184 1.2806893 ]
[3.12405123 0.67821757]
[2.84734459 0.26759253]]
可视化图像为:
for i in range(5):
idx, centroids_all, centroids = k_means(X, init_centroids(X, 3), 10)
plot_data(centroids_all,idx)
需要注意的是,每次运行结果都不相同,具有随机性!
""" 给定一个二维的数据集,使用k-means算法进行聚类 """ import numpy as np import pandas as pd import matplotlib.pyplot as plt from scipy.io import loadmat import matplotlib as mpl """ 导入数据的函数 """ def load_dataset(): path='./data/ex7data2.mat' # 字典格式 : <class 'dict'> data=loadmat(path) # data.keys() : dict_keys(['__header__', '__version__', '__globals__', 'X']) dataset = pd.DataFrame(data.get('X'), columns=['X1', 'X2']) return data,dataset """ 绘制散点图 """ def plot_scatter(): data,dataset=load_dataset() plt.figure(figsize=(12,8)) plt.scatter(dataset['X1'],dataset['X2'],cmap=['b']) plt.show() """ 获得每个样本所属的类别 """ def get_near_cluster_centroids(X,centroids): """ :param X: 我们的数据集 :param centroids: 聚类中心点的初始位置 :return: """ m = X.shape[0] #数据的行数 k = centroids.shape[0] #聚类中心的行数,即个数 idx = np.zeros(m) # 一维向量idx,大小为数据集中的点的个数,用于保存每一个X的数据点最小距离点的是哪个聚类中心 for i in range(m): min_distance = 1000000 for j in range(k): distance = np.sum((X[i, :] - centroids[j, :]) ** 2) # 计算数据点到聚类中心距离代价的公式,X中每个点都要和每个聚类中心计算 if distance < min_distance: min_distance = distance idx[i] = j # idx中索引为i,表示第i个X数据集中的数据点距离最近的聚类中心的索引 return idx # 返回的是X数据集中每个数据点距离最近的聚类中心 def compute_centroids(X, idx, k): """ :param X: 数据集 :param idx: 每个样本所属的类别 :param k: 类别总数 :return: """ m, n = X.shape centroids = np.zeros((k, n)) # 初始化为k行n列的二维数组,值均为0,k为聚类中心个数,n为数据列数 for i in range(k): indices = np.where(idx == i) # 输出的是索引位置 centroids[i, :] = (np.sum(X[indices, :], axis=1) / len(indices[0])).ravel() return centroids def k_means(X, initial_centroids, max_iters): """ :param X: :param initial_centroids: 初始聚类中心 :param max_iters: 迭代次数 :return: """ m, n = X.shape k = initial_centroids.shape[0] idx = np.zeros(m) centroids_all = [] centroids_all.append(initial_centroids) centroids = initial_centroids for i in range(max_iters): idx = get_near_cluster_centroids(X, centroids) centroids = compute_centroids(X, idx, k) centroids_all.append(centroids) return idx, np.array(centroids_all),centroids def plot_data(centroids_all,idx): plt.figure(figsize=(12,8)) cm=mpl.colors.ListedColormap(['r','g','b']) plt.scatter(dataset['X1'],dataset['X2'], c=idx,cmap=cm) plt.plot(centroids_all[:,:,0],centroids_all[:,:,1], 'kx--') plt.show() def plot_classify_data(X,idx): cluster1 = X[np.where(idx == 0)[0], :] cluster2 = X[np.where(idx == 1)[0], :] cluster3 = X[np.where(idx == 2)[0], :] fig, ax = plt.subplots(figsize=(12, 8)) ax.scatter(cluster1[:, 0], cluster1[:, 1], s=30, color='r', label='Cluster 1') ax.scatter(cluster2[:, 0], cluster2[:, 1], s=30, color='g', label='Cluster 2') ax.scatter(cluster3[:, 0], cluster3[:, 1], s=30, color='b', label='Cluster 3') ax.legend() plt.show() def init_centroids(X, k): m, n = X.shape init_centroids = np.zeros((k, n)) idx = np.random.randint(0, m, k) for i in range(k): init_centroids[i, :] = X[idx[i], :] return init_centroids if __name__=='__main__': data,dataset = load_dataset() plot_scatter() X=data.get('X') centroids_0=np.array([[3, 3], [6, 2], [8, 5]]) idx, centroids_all,centroids=k_means(X,centroids_0,10) plot_data(centroids_all, idx) plot_classify_data(X,idx) print(init_centroids(X, 3)) for i in range(5): idx, centroids_all, centroids = k_means(X, init_centroids(X, 3), 10) plot_data(centroids_all,idx)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。