当前位置:   article > 正文

K-means算法实战项目(Python实现)(对比简化版)_mderank

mderank

https://blog.csdn.net/wzk4869/article/details/126040866?spm=1001.2014.3001.5501

我的这篇博客中给出了比较完整的k-means算法的实现过程,但是我们可以继续简化一下:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.io import loadmat
"""
导入数据的函数
"""
def load_dataset():
    path='./data/ex7data2.mat'
    # 字典格式 : <class 'dict'>
    data=loadmat(path)
    # data.keys() : dict_keys(['__header__', '__version__', '__globals__', 'X'])
    dataset = pd.DataFrame(data.get('X'), columns=['X1', 'X2'])
    return data,dataset
data,dataset=load_dataset()
print('data=\n',data)
"""
绘制散点图
"""
def plot_scatter():
    data,dataset=load_dataset()
    plt.figure(figsize=(12,8))
    plt.scatter(dataset['X1'],dataset['X2'],cmap=['b'])
    plt.show()

def find_centroids(X,centroids):
    idx=[]
    for i in range(X.shape[0]):
        dist=np.linalg.norm((X[i]-centroids),axis=1)
        idx_i=np.argmin(dist)
        idx.append(idx_i)
    return np.array(idx)

centroids_0=np.array([[3,3],[6,2],[8,5]])
idx_0=find_centroids(data.get('X'),centroids_0)

def compute_centroids(X,idx,k):
    centroids=[]
    for i in range(k):
        centroids_i=np.mean(X[idx==i],axis=0)
        centroids.append(centroids_i)
    return np.array(centroids)

centroids=compute_centroids(data.get('X'),idx_0,3)

def k_means(X,centroids,max_iters):
    k=len(centroids)
    centroids_all=[]
    centroids_all.append(centroids_0)
    centroids_i=centroids
    for i in range(max_iters):
        idx_1=find_centroids(X,centroids_i)
        centroids_i=compute_centroids(X,idx_1,k)
        centroids_all.append(centroids_i)
    return np.array(idx_1),np.array(centroids_all)


def plot_data(X,centroids_all,idx):
    plt.figure(figsize=(12,8))
    plt.scatter(X[:,0],X[:,1],c=idx,cmap='rainbow')
    plt.plot(centroids_all[:,:,0],centroids_all[:,:,1],'kx--')
    plt.show()

idx,centroids_all=k_means(data.get('X'),centroids,max_iters=10)
plot_data(data.get('X'),centroids_all,idx)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65

比起我们之前的做法,在函数上减少了不少步骤。

我们看一下最终的结果:

当然,颜色也可以自定义,这里就不再赘述,直接使用内置的颜色。

有兴趣的小伙伴可以看我之前发的文章:

https://blog.csdn.net/wzk4869/article/details/126036397?spm=1001.2014.3001.5501

初始化聚类中心的步骤:

def init_centroids(X,k):
    index=np.random.choice(len(X),k)
    return X[index]

print(init_centroids(data.get('X'),k=3))
  • 1
  • 2
  • 3
  • 4
  • 5
[[7.24694794 2.96877424]
 [2.95818429 1.01887096]
 [2.09517296 1.14010491]]
  • 1
  • 2
  • 3

可视化结果为:

for i in range(5):
    idx,centroids_all=k_means(data.get('X'),init_centroids(data.get('X'),k=3),max_iters=10)
    plot_data(data.get('X'),centroids_all,idx)
  • 1
  • 2
  • 3





完整版代码如下:

"""
给定一个二维的数据集,使用k-means算法进行聚类
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.io import loadmat
"""
导入数据的函数
"""
def load_dataset():
    path='./data/ex7data2.mat'
    # 字典格式 : <class 'dict'>
    data=loadmat(path)
    # data.keys() : dict_keys(['__header__', '__version__', '__globals__', 'X'])
    dataset = pd.DataFrame(data.get('X'), columns=['X1', 'X2'])
    return data,dataset
data,dataset=load_dataset()
print('data=\n',data)
"""
绘制散点图
"""
def plot_scatter():
    data,dataset=load_dataset()
    plt.figure(figsize=(12,8))
    plt.scatter(dataset['X1'],dataset['X2'],cmap=['b'])
    plt.show()

def find_centroids(X,centroids):
    idx=[]
    for i in range(X.shape[0]):
        dist=np.linalg.norm((X[i]-centroids),axis=1)
        idx_i=np.argmin(dist)
        idx.append(idx_i)
    return np.array(idx)

centroids_0=np.array([[3,3],[6,2],[8,5]])
idx_0=find_centroids(data.get('X'),centroids_0)

def compute_centroids(X,idx,k):
    centroids=[]
    for i in range(k):
        centroids_i=np.mean(X[idx==i],axis=0)
        centroids.append(centroids_i)
    return np.array(centroids)

centroids=compute_centroids(data.get('X'),idx_0,3)

def k_means(X,centroids,max_iters):
    k=len(centroids)
    centroids_all=[]
    centroids_all.append(centroids_0)
    centroids_i=centroids
    for i in range(max_iters):
        idx_1=find_centroids(X,centroids_i)
        centroids_i=compute_centroids(X,idx_1,k)
        centroids_all.append(centroids_i)
    return np.array(idx_1),np.array(centroids_all)


def plot_data(X,centroids_all,idx):
    plt.figure(figsize=(12,8))
    plt.scatter(X[:,0],X[:,1],c=idx,cmap='rainbow')
    plt.plot(centroids_all[:,:,0],centroids_all[:,:,1],'kx--')
    plt.show()

idx,centroids_all=k_means(data.get('X'),centroids,max_iters=10)
plot_data(data.get('X'),centroids_all,idx)

def init_centroids(X,k):
    index=np.random.choice(len(X),k)
    return X[index]

print(init_centroids(data.get('X'),k=3))
for i in range(5):
    idx,centroids_all=k_means(data.get('X'),init_centroids(data.get('X'),k=3),max_iters=10)
    plot_data(data.get('X'),centroids_all,idx)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/365258
推荐阅读
相关标签
  

闽ICP备14008679号