赞
踩
邻近算法,或者说K最近邻(KNN,K-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是K个最近的邻居的意思,说的是每个样本都可以用它最接近的K个邻近值来代表。近邻算法就是将数据集合中每一个记录进行分类的方法 。
KNN算法的核心思想是,如果一个样本在特征空间中的K个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。KNN方法在类别决策时,只与极少量的相邻样本有关。由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。
总体来说,KNN分类算法包括以下4个步骤:
①准备数据,对数据进行预处理 [4] 。
②计算测试样本点(也就是待分类点)到其他每个样本点的距离 。
③对每个距离进行排序,然后选择出距离最小的K个点 。
④对K个点所属的类别进行比较,根据少数服从多数的原则,将测试样本点归入在K个点中占比最高的那一类。
优点: KNN方法思路简单,易于理解,易于实现,无需估计参数,无需训练
缺点: 该算法在分类时有个主要的不足是,当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数 。该方法的另一个不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点 。
引用: knn最邻近结点算法-百度百科
现有一个表 iris.csv ,部分数据如下所示:
id 花茎长度 花茎宽度 花蕊长度 花蕊宽度 种类
需求: iris.csv 总共150条数据,使用120条数据作为训练数据,使用花茎长度、花茎宽度、花蕊长度、花蕊宽度四个属性值,推导花的种类 Species ,使用剩余30条数据验证knn算法准确率
linux启动 pypark ,打开 jupyter 网页端口:http://hadoop-single:8888/
代码演示截图如下:
数据可视化:
导入matplotlib包:
import matplotlib as mpl
import matplotlib.pyplot as plt
设置相关配置:
# matplotlib不支持中文, 需要配置一下 , 设置一个中文字体
mpl.rcParams["font.family"] = "SimHei"
# 能够显示 中文, 正常显示 “-”
mpl.rcParams["axes.unicode_minus"] = False
训练数据集可视化展示:
#设置画布大小
plt.figure(figsize=(9,9))
#绘制点图,需要提供x,y轴坐标;x:花蕊的长度 y:花瓣的长度
plt.scatter(x=t0["SepalLengthCm"][:40], y=t0["PetalLengthCm"][:40], color="r", label="virginica")
plt.scatter(x=t1["SepalLengthCm"][:40], y=t1["PetalLengthCm"][:40], color="g", label="setosa")
plt.scatter(x=t2["SepalLengthCm"][:40], y=t2["PetalLengthCm"][:40], color="b", label="versicolor")
如下图所示:
测试数据集:
# 测试数据集 test_y 待测试数据集真实数据 result 用KNN分类算法计算出来的数据集
right = test_X[test_y ==result]
wrong = test_X[test_y !=result]
plt.figure(figsize=(9,9))
plt.scatter(x=right["SepalLengthCm"], y=right["PetalLengthCm"], color="c", label="right",marker="x")
如下图所示:
测试knn算法展示:
plt.figure(figsize=(9,9))
plt.legend(loc="best")
plt.xlabel('花萼')
plt.ylabel('花瓣')
plt.title('KNN分类算法显示')
plt.scatter(x=t0["SepalLengthCm"][:40], y=t0["PetalLengthCm"][:40], color="r", label="virginica")
plt.scatter(x=t1["SepalLengthCm"][:40], y=t1["PetalLengthCm"][:40], color="g", label="setosa")
plt.scatter(x=t2["SepalLengthCm"][:40], y=t2["PetalLengthCm"][:40], color="b", label="versicolor")
plt.scatter(x=right["SepalLengthCm"], y=right["PetalLengthCm"], color="c", label="right",marker="x")
plt.scatter(x=wrong["SepalLengthCm"], y=wrong["PetalLengthCm"], color="m", label="wrong",marker=">")
如下图所示:
表数据及算法测试源码链接:https://pan.baidu.com/s/1RI4k72ZUDPBvbiEjJzaDRg
提取码:rssc
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。