当前位置:   article > 正文

sklearn包中K近邻分类器 KNeighborsClassifier的使用_from sklearn.neighbors import kneighborsclassifier

from sklearn.neighbors import kneighborsclassifier

1. KNN算法

K近邻(k-Nearest Neighbor,KNN)分类算法的核心思想是如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别KNN算法可用于多分类,KNN算法不仅可以用于分类,还可以用于回归。通过找出一个样本的k个最近邻居,将这些邻居的属性的平均值赋给该样本,作为预测值。

KNeighborsClassifier在scikit-learn 在sklearn.neighbors包之中。KNeighborsClassifier使用很简单,三步:

1)创建KNeighborsClassifier对象,

2)调用fit函数,

3)调用predict函数进行预测。

以下代码说明了用法。

例子一:

[python]  view plain  copy
  1. from sklearn.neighbors import KNeighborsClassifier  
  2.   
  3. X = [[0], [1], [2], [3],[4], [5],[6],[7],[8]]  
  4. y = [000111222]  
  5.   
  6. neigh = KNeighborsClassifier(n_neighbors=3)  
  7. neigh.fit(X, y)  
  8.   
  9. print(neigh.predict([[1.1]]))   #结果[0]
  10. print(neigh.predict([[1.6]]))   #结果[0]
  11. print(neigh.predict([[5.2]]))   #结果[1]
  12. print(neigh.predict([[5.8]]))   #结果[2]
  13. print(neigh.predict([[6.2]]))   #结果[3]

例子二:

  1. from sklearn import datasets
  2. from sklearn import *
  3. # from sklearn.neighbors import KNeighborsClassifier
  4. # from sklearn.cross_validation import train_test_split
  5. iris=datasets.load_iris()
  6. iris_X=iris.data
  7. iris_Y=iris.target
  8. X_train,X_test,Y_train,Y_test = train_test_split(iris_X,iris_Y,test_size=0.3)
  9. knn=KNeighborsClassifier()
  10. knn.fit(X_train,Y_train)
  11. print(knn.predict(X_test))
  12. print(Y_test)

2. 实例

1)小麦种子数据集 (seeds)

七个特征,面积、周长、紧密度、谷粒的长度、谷粒的宽度、偏度系数和谷粒槽长度。数据格式如下:

[plain]  view plain  copy
  1. 15.26   14.84   0.871   5.763   3.312   2.221   5.22    Kama  
  2. 14.88   14.57   0.8811  5.554   3.333   1.018   4.956   Kama  
  3. 14.29   14.09   0.905   5.291   3.337   2.699   4.825   Kama  
  4. 13.84   13.94   0.8955  5.324   3.379   2.259   4.805   Kama  
  5. 16.14   14.99   0.9034  5.658   3.562   1.355   5.175   Kama  
  6. 14.38   14.21   0.8951  5.386   3.312   2.462   4.956   Kama  
  7. 14.69   14.49   0.8799  5.563   3.259   3.586   5.219   Kama  
  8. 14.11   14.1    0.8911  5.42    3.302   2.7     5.0     Kama  
  9. 16.63   15.46   0.8747  6.053   3.465   2.04    5.877   Kama  

2)代码

[python]  view plain  copy
  1. # -*- coding:utf-8 -*-  
  2. import numpy as np  
  3. from matplotlib import pyplot as plt  
  4. from matplotlib.colors import ListedColormap  
  5. from sklearn.neighbors import KNeighborsClassifier  
  6. from sklearn.cross_validation import KFold, cross_val_score  
  7.   
  8. feature_names = [  
  9.     'area',  
  10.     'perimeter',  
  11.     'compactness',  
  12.     'length of kernel',  
  13.     'width of kernel',  
  14.     'asymmetry coefficien',  
  15.     'length of kernel groove',  
  16. ]  
  17.   
  18. COLOUR_FIGURE = False  
  19.   
  20.   
  21. def plot_decision(features, labels, num_neighbors=3):  
  22.     y_min, y_max = features[:, 2].min() * .9, features[:, 2].max() * 1.1  
  23.     x_min, x_max = features[:, 0].min() * .9, features[:, 0].max() * 1.1  
  24.     X, Y = np.meshgrid(np.linspace(x_min, x_max, 1000), np.linspace(y_min, y_max, 1000))  
  25.   
  26.     model = KNeighborsClassifier(num_neighbors)  
  27.     model.fit(features[:, (0,2)], labels)  
  28.     C = model.predict(np.vstack([X.ravel(), Y.ravel()]).T).reshape(X.shape)  
  29.     if COLOUR_FIGURE:  
  30.         cmap = ListedColormap([(1., .7, .7), (.71., .7), (.7, .71.)])  
  31.     else:  
  32.         cmap = ListedColormap([(1.1.1.), (.2, .2, .2), (.6, .6, .6)])  
  33.     fig,ax = plt.subplots()  
  34.     ax.set_xlim(x_min, x_max)  
  35.     ax.set_ylim(y_min, y_max)  
  36.     ax.set_xlabel(feature_names[0])  
  37.     ax.set_ylabel(feature_names[2])  
  38.     ax.pcolormesh(X, Y, C, cmap=cmap)  
  39.     if COLOUR_FIGURE:  
  40.         cmap = ListedColormap([(1., .0, .0), (.1, .6, .1), (.0, .01.)])  
  41.         ax.scatter(features[:, 0], features[:, 2], c=labels, cmap=cmap)  
  42.     else:  
  43.         for lab, ma in zip(range(3), "Do^"):  
  44.             ax.plot(features[labels == lab, 0],  
  45.                     features[labels == lab, 2],  
  46.                     ma,  
  47.                     c=(1.1.1.),  
  48.                     ms=6)  
  49.     return fig, ax  
  50.   
  51.   
  52. def load_csv_data(filename):  
  53.     data = []  
  54.     labels = []  
  55.     datafile = open(filename)  
  56.     for line in datafile:  
  57.         fields = line.strip().split('\t')  
  58.         data.append([float(field) for field in fields[:-1]])  
  59.         labels.append(fields[-1])  
  60.     data = np.array(data)  
  61.     labels = np.array(labels)  
  62.     return data, labels  
  63.   
  64.   
  65. def accuracy(test_labels, pred_lables):  
  66.     correct = np.sum(test_labels == pred_lables)  
  67.     n = len(test_labels)  
  68.     return float(correct) / n  
  69.   
  70.   
  71. if __name__ == '__main__':  
  72.     opt = input("raw_inputp[1 or 2]: ")  
  73.     features, labels = load_csv_data('data/seeds.tsv')  
  74.     if opt == '1':  
  75.         knn = KNeighborsClassifier(n_neighbors=5)  
  76.         kf = KFold(len(features), n_folds=3, shuffle=True)  
  77.         result_set = [(knn.fit(features[train], labels[train]).predict(features[test]), test) for train, test in kf]  
  78.         score = [accuracy(labels[result[1]], result[0]) for result in result_set]  
  79.         print(score)  
  80.     elif opt == '2':  
  81.         names = sorted(set(labels))  
  82.         labels = np.array([names.index(ell) for ell in labels])  
  83.         fig, ax = plot_decision(features, labels)  
  84.         plt.show()  
  85.     else:  
  86.         print('input 1 or 2 !')  

代码简要说明 

load_csv_data 从数据文件,读取数据。

accuracy 计算预测的准确度。

plot_decision 画决策边界图,挑两个特征。这个函数要注意pcolormesh。

主程序:输入1进行预测,输入2画图。第一个选项中,

a)首先生成分类器

b)调用KFold来生产学习数据和测试数据,

3)训练和预测,

4)计算精度。

这里充分利用了“列表解析”和“向量”使代码简洁。



声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号