赞
踩
Cover和Hart在1968年提出了最初的邻近算法。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。KNN是一种分类(classification)算法,它输入基于实例的学习(instance-based learning),属于懒惰学习(lazy learning)即KNN没有显式的学习过程,也就是说没有训练阶段,数据集事先已有了分类和特征值,待收到新样本后直接进行处理。KNN是通过测量不同特征值之间的距离进行分类。
KNN算法的思路是:如果一个样本在特征空间中的k个最邻近的样本中的大多数属于某一个类别,则该样本也划分为这个类别。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。
1)计算测试数据与各个训练数据之间的距离;
2)按照距离的递增关系进行排序;
3)选取距离最小的K个点
4)确定前K个点所在类别的出现频率
5)返回前K个点中出现频率最高的类别作为测试数据的预测分类
from sklearn import datasets from sklearn.model_selection import train_test_split from sklearn.neighbors import KNeighborsClassifier from sklearn.model_selection import cross_val_score#引入交叉验证 import matplotlib.pyplot as plt ###引入数据### iris=datasets.load_iris() X=iris.data y=iris.target ###设置n_neighbors的值为1到30,通过绘图来看训练分数### k_range=range(1,31) k_score=[] for k in k_range: knn=KNeighborsClassifier(n_neighbors=k) scores=cross_val_score(knn, X, y, cv=10, scoring='accuracy')#for classfication k_score.append(scores.mean()) plt.figure() plt.plot(k_range,k_score) plt.xlabel('Value of k for KNN') plt.ylabel('CrossValidation accuracy') plt.show()
由图可以看出K=13的时候取的最高的得分。
4. KNN算法中的距离
# 基础结构.py # import numpy as np from sklearn import neighbors, datasets, preprocessing from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score from sklearn.model_selection import cross_val_score np.random.RandomState(0) # 加载数据 iris = datasets.load_iris() # 划分训练集与测试集 x, y = iris.data, iris.target x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3) # 数据预处理 scaler = preprocessing.StandardScaler().fit(x_train) x_train = scaler.transform(x_train) x_test = scaler.transform(x_test) # 创建模型 knn = neighbors.KNeighborsClassifier(n_neighbors=5) # 模型拟合 knn.fit(x_train, y_train) # 交叉验证 scores = cross_val_score(knn, x_train, y_train, cv=5, scoring='accuracy') print(scores) # 每组的评分结果 print(scores.mean()) # 预测 y_pred = knn.predict(x_test) # 评估 print(accuracy_score(y_test, y_pred))
import csv import random import math import operator # 数据集模块定义 def loadDataset(filename, split, trainingSet=[], testSet=[]): with open(filename, 'r') as csvfile: lines = csv.reader(csvfile) dataset = list(lines) for x in range(len(dataset)-1): for y in range(5): dataset[x][y] = float(dataset[x][y]) if random.random() < split: trainingSet.append(dataset[x]) else: testSet.append(dataset[x]) # 计算欧式距离Euclidean Distance def euclideanDistance(instance1, instance2, lengths): distance = 0 for x in range(lengths): distance += math.pow((float(instance1[x]) - float(instance2[2])), 2) return math.sqrt(distance) # 最邻近算法,找出测试集中与训练集最邻近的K个值 def getNeighbors(trainingSet, testInstance, k): distances = [] length = len(testInstance) -1 for x in range(len(trainingSet)): dist = euclideanDistance(testInstance, trainingSet[x], length) distances.append((trainingSet[x], dist)) # 将距离从小到大排序,然后取前n个邻居 distances.sort(key=operator.itemgetter(1)) neighbors = [] for x in range(k): neighbors.append(distances[x][0]) return neighbors # 对找出的K个值进行打分,比较哪个类别出现的次数最多 def getResponse(neighbors): classVotes = {} # 统计类别出现的次数 for x in range(len(neighbors)): response = neighbors[x][-1] if response in classVotes: classVotes[response] += 1 else: classVotes[response] = 1 # classVotes.items()将字典以元组的形式显示,key=operator.itemgetter(1) 按第一个(索引是1)元素排序。 sortedVotes = sorted(classVotes.items(), key=operator.itemgetter(1), reverse=True) return sortedVotes[0][0] # 计算精确度 def getAccuary(testSet, predictions): correct = 0 for x in range(len(testSet)): if testSet[x][-1] == predictions[x]: correct += 1 return (correct / float(len(testSet))) * 100.0 def main(): trainingSet = [] testSet =[] # 需要将 1/2的数据划分为训练集 split = 0.50 loadDataset(r"./iris.csv", split, trainingSet, testSet) print("train set:" + repr(len(trainingSet))) print("Test set: " + repr(len(testSet))) predictions = [] k = 11 for x in range(len(testSet)): neighbors = getNeighbors(trainingSet, testSet[x], k) result = getResponse(neighbors) predictions.append(result) print('> predicted = ' + repr(result) + ', actual=' + repr(testSet[x][-1])) accuracy = getAccuary(testSet, predictions) print('Accuracy:' + repr(accuracy) + '%') main()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。