当前位置:   article > 正文

《机器学习》周志华课后习题答案——第九章 (1-4完结)_k均值算法能否找到最优解

k均值算法能否找到最优解

机器学习》周志华课后习题答案——第九章 (1-4完结)

1.如图

请添加图片描述
请添加图片描述

2.如图

请添加图片描述

请添加图片描述

3.试析k均值算法能否找到最小化式(9.24)的最优解.

不能,因为k均值本身是NP问题,且9.24是非凸的(具体证明不太懂.),容易陷入局部最优是k均值的一个缺点吧,所以在使用k均值时常常多次随机初始化中心点,然后挑选结果最好的一个。

4.试编程实现k均值算法,设置三组不同的k值、三组不同初始中心点,在西瓜数据集4.0上进行实验比较,并讨论什么样的初始中心有利于取得好结果.

import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import ConvexHull


class KMeans(object):

    def __init__(self, k):
        self.k = k

    def fit(self, X, initial_centroid_index=None, max_iters=10, seed=16, plt_process=False):
        m, n = X.shape

        # 没有指定中心点时,随机初始化中心点
        if initial_centroid_index is None:
            np.random.seed(seed)
            initial_centroid_index = np.random.randint(0, m, self.k)

        centroid = X[initial_centroid_index, :]

        idx = None

        # 打开交互模式
        plt.ion()
        for i in range(max_iters):
            # 按照中心点给样本分类
            idx = self.find_closest_centroids(X, centroid)

            if plt_process:
                self.plot_converge(X, idx, initial_centroid_index)

            # 重新计算中心点
            centroid = self.compute_centroids(X, idx)

        # 关闭交互模式
        plt.ioff()
        plt.show()

        return centroid, idx

    def find_closest_centroids(self, X, centroid):

        # 这种方式利用 numpy 的广播机制,直接计算样本到各中心的距离,不用循环,速度比较快,但是在样本比较大时,更消耗内存
        distance = np.sum((X[:, np.newaxis, :] - centroid) ** 2, axis=2)
        idx = distance.argmin(axis=1)
        return idx

    def compute_centroids(self, X, idx):
        centroids = np.zeros((self.k, X.shape[1]))

        for i in range(self.k):
            centroids[i, :] = np.mean(X[idx == i], axis=0)

        return centroids

    def plot_converge(self, X, idx, initial_idx):
        plt.cla()  # 清除原有图像

        plt.title("k-meas converge process")
        plt.xlabel('density')
        plt.ylabel('sugar content')

        plt.scatter(X[:, 0], X[:, 1], c='lightcoral')
        # 标记初始化中心点
        plt.scatter(X[initial_idx, 0], X[initial_idx, 1], label='initial center', c='k')

        # 画出每个簇的凸包
        for i in range(self.k):
            X_i = X[idx == i]

            # 获取当前簇的凸包索引
            hull = ConvexHull(X_i).vertices.tolist()
            hull.append(hull[0])
            plt.plot(X_i[hull, 0], X_i[hull, 1], 'c--')

        plt.legend()
        plt.pause(0.5)


if __name__ == '__main__':

    data = np.loadtxt('..\data\watermelon4_0_Ch.txt', delimiter=', ')
    centroid, idx = KMeans(3).fit(data, plt_process=True, seed=24)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83

请添加图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/350103
推荐阅读
相关标签
  

闽ICP备14008679号