赞
踩
- # -*- coding: utf-8 -*-
- import numpy as np
- import matplotlib.pyplot as plt
- import matplotlib.animation as animation
-
-
def kmeans(data, center_ids, max_err=0.0001, max_round=30):
    """Run k-means clustering, yielding the state after every round.

    Implemented as a generator so callers can observe (e.g. animate) how
    the clusters evolve from one round to the next.

    Args:
        data: Array-like of shape (n_samples, n_features). It is converted
            to a float array, so integer input is accepted.
        center_ids: Row indices into ``data`` used as the initial centers.
            Their count is the number of clusters.
        max_err: Stop once the summed Euclidean shift of all centers in one
            round is no greater than this value.
        max_round: Upper bound on the number of rounds.

    Yields:
        ``(clusters, centers, rounds)`` after each round:
        ``clusters[i]`` is the list of sample indices assigned to cluster
        ``i``, ``centers[i]`` is that cluster's updated center (1-D array),
        and ``rounds`` is the 1-based round number. A cluster that receives
        no samples keeps its previous center instead of becoming NaN.

    Raises:
        ValueError: If ``center_ids`` is empty.
    """
    data = np.asarray(data, dtype=float)
    center_ids = list(center_ids)
    if not center_ids:
        raise ValueError("center_ids must contain at least one index")
    n = len(center_ids)
    # Fancy indexing copies the rows, so updating centers never touches data.
    centers = data[center_ids, :]
    # Start at +inf so at least one round always runs, whatever max_err is.
    error, rounds = float("inf"), 0
    while error > max_err and rounds < max_round:
        rounds += 1
        # Squared distance from every sample to every center, shape
        # (n_samples, n_clusters). sqrt is monotonic, so argmin over the
        # squared values picks the same nearest center; argmin returns the
        # first index on ties, matching a stable sort.
        sq_dist = ((data[:, np.newaxis, :] - centers[np.newaxis, :, :]) ** 2).sum(axis=2)
        labels = np.argmin(sq_dist, axis=1)
        clusters = [np.flatnonzero(labels == i).tolist() for i in range(n)]

        # Build a fresh array each round so centers yielded earlier are
        # never mutated afterwards.
        new_centers = centers.copy()
        for i, members in enumerate(clusters):
            if members:  # an empty cluster keeps its previous center
                new_centers[i] = data[members, :].mean(axis=0)
        error = float(np.linalg.norm(new_centers - centers, axis=1).sum())
        centers = new_centers
        # Yield after every round so the caller gets each intermediate state.
        yield clusters, list(centers), rounds
-
-
# 30 two-feature samples: column 0 is density, column 1 is sugar (see the
# axis labels below).
data = np.array([
    [0.697, 0.460], [0.774, 0.376], [0.634, 0.264], [0.608, 0.318], [0.556, 0.215],
    [0.403, 0.237], [0.481, 0.149], [0.437, 0.211], [0.666, 0.091], [0.243, 0.267],
    [0.245, 0.057], [0.343, 0.099], [0.639, 0.161], [0.657, 0.198], [0.360, 0.370],
    [0.593, 0.042], [0.719, 0.103], [0.359, 0.188], [0.339, 0.241], [0.282, 0.257],
    [0.748, 0.232], [0.714, 0.346], [0.483, 0.312], [0.478, 0.437], [0.525, 0.369],
    [0.751, 0.489], [0.532, 0.472], [0.473, 0.376], [0.725, 0.445], [0.446, 0.459]])
# Row ids of the samples chosen as initial centers; their count is also the
# number of clusters.
init_centers = [12, 22]
# One color per cluster, cycled when there are more clusters than colors.
dye = ['red', 'orange', 'green', 'blue', 'pink']

fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.set_xlim(0, 1)
ax.set_ylim(0, 0.6)
ax.set_ylabel('sugar')
ax.set_xlabel('density')


def _frame_title(text):
    """Return a title-like text artist that belongs to a single frame."""
    # ax.set_title is not a frame artist and would show the same (last)
    # title on every frame, so each frame carries its own text.
    return ax.text(0.5, 1.02, text, transform=ax.transAxes,
                   ha='center', va='bottom')


# First frame: the raw, unclustered data.
imgs = [[ax.scatter(data[:, 0], data[:, 1], c='k'), _frame_title('initial data')]]
# One frame per k-means round: each cluster in its color plus its center.
for cluster, center, rounds in kmeans(data, init_centers):
    pics = [_frame_title('clusters in %s rounds' % rounds)]
    for i, li in enumerate(cluster):
        pics.append(ax.scatter(data[li, 0], data[li, 1], c=dye[i % len(dye)]))
        pics.append(ax.scatter(center[i][0], center[i][1], s=45, c='gray', marker='s'))
    imgs.append(pics)

# blit=False so the per-frame titles, which sit outside the axes area, are
# redrawn correctly.
A = animation.ArtistAnimation(fig, imgs, interval=1000, blit=False, repeat_delay=500)
# Save before show(): saving after the window has been closed is unreliable.
# The 'pillow' writer ships with Matplotlib; 'imagemagick' needs an external
# install. fps sets the GIF's frames per second.
A.save('3point.gif', fps=2, writer='pillow')
plt.show()
K-means算法的2类聚类:
K-means算法的3类聚类:
K-means算法的4类聚类:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。