赞
踩
如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别,其中K通常是不大于20的整数。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。
在KNN中,通过计算对象间距离来作为各个对象之间的非相似性指标,避免了对象之间的匹配问题,在这里距离一般使用欧氏距离或曼哈顿距离:
就是在训练集中数据和标签已知的情况下,输入测试数据,将测试数据的特征与训练集中对应的特征进行相互比较,找到训练集中与之最为相似的前K个数据,则该测试数据对应的类别就是K个数据中出现次数最多的那个分类,其算法的描述为:
1)计算测试数据与各个训练数据之间的距离;
2)按照距离的递增关系进行排序;
3)选取距离最小的K个点;
4)确定前K个点所在类别的出现频率;
5)返回前K个点中出现频率最高的类别作为测试数据的预测分类。
- from numpy import *
- import operator
- #给出训练数据以及对应的类别
- def createDataSet():
- group = array([[1.0, 2.0], [1.2, 0.1], [0.1, 1.4], [0.3, 3.5]])
- labels = ['A', 'A', 'B', 'B']
- return group, labels
- ###通过KNN进行分类
- def classify(input, dataSet, labels, k):
- dataSize=dataSet.shape[0]
- ####计算欧式距离
- diff = tile(input, (dataSize, 1)) - dataSet
- sqdiff = diff ** 2
- squareDist = sum(sqdiff, axis=1) ###行向量分别相加,从而得到新的一个行向量
- dist = squareDist ** 0.5
- ##对距离进行排序
- sortedDistIndex = argsort(dist) ##argsort()根据元素的值从大到小对元素进行排序,返回下标
- classCount = {}
- for i in range(k):
- voteLabel = labels[sortedDistIndex[i]]
- ###对选取的K个样本所属的类别个数进行统计
- classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
- ###选取出现的类别次数最多的类别
- maxCount = 0
- for key, value in classCount.items():
- if value > maxCount:
- maxCount = value
- classes = key
- return classes
- if __name__=='__main__':
- dataset,labels = createDataSet()
- input = array([1.1,0.3])
- k = 3
- output =classify(input,dataset,labels,k)
- print("测试数据为:",input,"分类结果为:",output)
- E:\Python\python.exe E:/work/PycharmProjects/LuoDemo/KNN.py
- 测试数据为: [1.1 0.3] 分类结果为: A
-
- Process finished with exit code 0
- %% KNN
- clear all
- clc
- %% data
- trainData = [1.0,2.0;1.2,0.1;0.1,1.4;0.3,3.5];
- trainClass = [1,1,2,2];
- testData = [0.5,2.3];
- k = 3;
-
- %% distance
- row = size(trainData,1);
- col = size(trainData,2);
- test = repmat(testData,row,1);
- dis = zeros(1,row);
- for i = 1:row
- diff = 0;
- for j = 1:col
- diff = diff + (test(i,j) - trainData(i,j)).^2;
- end
- dis(1,i) = diff.^0.5;
- end
-
- %% sort
- jointDis = [dis;trainClass];
- sortDis= sortrows(jointDis');
- sortDisClass = sortDis';
-
- %% find
- class = sort(2:1:k);
- member = unique(class);
- num = size(member);
-
- max = 0;
- for i = 1:num
- count = find(class == member(i));
- if count > max
- max = count;
- label = member(i);
- end
- end
-
- disp('最终的分类结果为:');
- fprintf('%d\n',label)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。