当前位置:   article > 正文

机器学习算法-KNN代码实现_knn算法代码

knn算法代码

一、KNN算法初步理解

统计学习方法书上的解释:给定一个训练数据集,对于新的输入实例,在训练数据集中找到与该实例最邻近的k个实例,这k个实例的多数属于某个类,就把该输入实例分为这个类。

二、代码实现

1.数据集处理

数据集是使用的是鸢尾花数据集,在代码中直接从sklearn中导入即可。
要对数据集进行处理,必须得先知道数据集的特点,鸢尾花共有150个样本,类别数为3.
0~50个样本label=0,50-100样本label=1,100-150个样本label=2。并且鸢尾花特征数=4。
1.使用pandas展示数据集

from sklearn.datasets import load_iris
import pandas as pd
#导入数据集
iris = load_iris()
#iris.data得到数据,columns为列即4个特征
df = pd.DataFrame(iris.data, columns=iris.feature_names)
#添加标签label列
df['label'] = iris.target
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

注:iris.data得到样本特征,iris.target得到标签。
panda中dataframe用法
鸢尾花数据集分布:

在这里插入图片描述
2.数据处理
取出数据中前100个样本,并且将特征1(“speal length”)和特征2(“speal width”)作为特征
特征数目等于2。
用matplotlib中将样本可视化
plt.scatter()用法


plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')
plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

在这里插入图片描述
准备数据集


#取前100个样本,然后特征1和特征2和label
data=np.array(df.iloc[:100,[0,1,-1]])
#数据集
X,y=data[:,:-1],data[:,-1]
#train、test的划分train/test=4:
X_train,X_test,Y_train,Y_test=train_test_split(X,y,test_size=0.2)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

2.创建model


#model
class KNN():
    #初始化,设置neighbors=3,使用L2范数求距离
    def __init__(self,x_train,y_train,n_neighbors=3,p=2):
        self.x_train=x_train
        self.y_train=y_train
        self.n=n_neighbors
        self.p=p
    def predict(self,x):
        #创建一个列表
        knn_list=[]
        for i in range(self.n):
            #求距离
            dist=np.linalg.norm(x-self.x_train[i],ord=self.p)
            #向列表knn_list添加前三个训练样本的距离和label,为元组(dist,label)
            #比如knn_list=[(dist0,0),(dist1,0),(dist2,1)]
            knn_list.append((dist,self.y_train[i]))
        #从第四个样本开始
        for i in range(self.n,len(self.x_train)):
            #找到knn_list中最大值索引
            max_index=knn_list.index(max(knn_list,key=lambda x:x[0]))
            dist = np.linalg.norm(x - self.x_train[i], ord=self.p)
            #与最大值比较,要把最大值踢出去,就是不断缩小预测样本与训练样本之间的距离
            if knn_list[max_index][0]>dist:
                knn_list[max_index]=(dist,self.y_train[i])
        #统计-看看有没有误判的类别,计算损失
        #knn列表存储了三个label值
        knn=[k[-1] for k in knn_list]
        #用key-value的形式记录label=0或1的1个数有多少个
        count_pairs=Counter(knn)
        #[0:1,1:2],得到最大计数值是label,少数服从多数原则
        max_count=sorted(count_pairs.items(),key=lambda x:x[-1])[-1][0]
        return max_count
    def score(self,x_test,y_test):
        right_count=0
        n=10
        for x,y in zip(x_test,y_test):
            #调用了KNN.predict
            label=self.predict(y)
            if label==y:
                right_count+=1
        return right_count/len(x_test)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
'
运行

最后实例化类,再送入训练数据训练模型。测试集得分。


#创建model
model=KNN(X_train,Y_train)
# model.predict(X_train)
print(model.score(X_test,Y_test))
test_point=[6.0,3.0]
print('Test_point:{}'.format(model.predict(test_point)))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

输出:

0.4
Test_point:1.0
  • 1
  • 2
'
运行

3.可视化


#可视化
plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')
plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')
plt.plot(test_point[0], test_point[1], 'bo', label='test_point')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

在这里插入图片描述

总结

1.KNN这种训练的过程和一般意义的训练过程不一样,KNN训练时需要给一个新的数据,让这个数据与训练集数据计算距离,并最终以少数服从多数的原则判断给定数据属于哪一类。因此KNN不具备显式的学习过程。
2.没有之前在感知机里的权重参数w和b,只有超参数k。
3.比较赞同书上的这句话:KNN算法中当训练集、距离度量、k值及分类决策规则确定后,对于任何一个新的输入实例,它所属于的类已经确定好了。就相当于把特征空间划分为一些子空间,确定子空间的每个点所属于的类,点落在这个子空间内即这个点属于这个类。
知识点:
1.k值的过大或过小会带来什么影响?
k值过大,那就意味着算法不需要花太多力气给一个数据我给它归属于一个类,也能取得很好的效果。
k值过小,意味者对数据太过敏感,模型变得复杂,容易过拟合。考虑周围有噪声等情况。

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号