当前位置:   article > 正文

KNN算法_knn算法距离计算

knn算法距离计算

一. 概述

       如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别,其中K通常是不大于20的整数。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。

二.算法描述

1.距离计算公式

      在KNN中,通过计算对象间距离来作为各个对象之间的非相似性指标,避免了对象之间的匹配问题,在这里距离一般使用欧氏距离或曼哈顿距离:

 2.描述:

       就是在训练集中数据和标签已知的情况下,输入测试数据,将测试数据的特征与训练集中对应的特征进行相互比较,找到训练集中与之最为相似的前K个数据,则该测试数据对应的类别就是K个数据中出现次数最多的那个分类,其算法的描述为:

1)计算测试数据与各个训练数据之间的距离;

2)按照距离的递增关系进行排序;

3)选取距离最小的K个点;

4)确定前K个点所在类别的出现频率;

5)返回前K个点中出现频率最高的类别作为测试数据的预测分类。

三.python实现

  1. from numpy import *
  2. import operator
  3. #给出训练数据以及对应的类别
  4. def createDataSet():
  5. group = array([[1.0, 2.0], [1.2, 0.1], [0.1, 1.4], [0.3, 3.5]])
  6. labels = ['A', 'A', 'B', 'B']
  7. return group, labels
  8. ###通过KNN进行分类
  9. def classify(input, dataSet, labels, k):
  10. dataSize=dataSet.shape[0]
  11. ####计算欧式距离
  12. diff = tile(input, (dataSize, 1)) - dataSet
  13. sqdiff = diff ** 2
  14. squareDist = sum(sqdiff, axis=1) ###行向量分别相加,从而得到新的一个行向量
  15. dist = squareDist ** 0.5
  16. ##对距离进行排序
  17. sortedDistIndex = argsort(dist) ##argsort()根据元素的值从大到小对元素进行排序,返回下标
  18. classCount = {}
  19. for i in range(k):
  20. voteLabel = labels[sortedDistIndex[i]]
  21. ###对选取的K个样本所属的类别个数进行统计
  22. classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
  23. ###选取出现的类别次数最多的类别
  24. maxCount = 0
  25. for key, value in classCount.items():
  26. if value > maxCount:
  27. maxCount = value
  28. classes = key
  29. return classes
  30. if __name__=='__main__':
  31. dataset,labels = createDataSet()
  32. input = array([1.1,0.3])
  33. k = 3
  34. output =classify(input,dataset,labels,k)
  35. print("测试数据为:",input,"分类结果为:",output)

运行结果如下:

  1. E:\Python\python.exe E:/work/PycharmProjects/LuoDemo/KNN.py
  2. 测试数据为: [1.1 0.3] 分类结果为: A
  3. Process finished with exit code 0

四.matlab实现

  1. %% KNN
  2. clear all
  3. clc
  4. %% data
  5. trainData = [1.0,2.0;1.2,0.1;0.1,1.4;0.3,3.5];
  6. trainClass = [1,1,2,2];
  7. testData = [0.5,2.3];
  8. k = 3;
  9. %% distance
  10. row = size(trainData,1);
  11. col = size(trainData,2);
  12. test = repmat(testData,row,1);
  13. dis = zeros(1,row);
  14. for i = 1:row
  15. diff = 0;
  16. for j = 1:col
  17. diff = diff + (test(i,j) - trainData(i,j)).^2;
  18. end
  19. dis(1,i) = diff.^0.5;
  20. end
  21. %% sort
  22. jointDis = [dis;trainClass];
  23. sortDis= sortrows(jointDis');
  24. sortDisClass = sortDis';
  25. %% find
  26. class = sort(2:1:k);
  27. member = unique(class);
  28. num = size(member);
  29. max = 0;
  30. for i = 1:num
  31. count = find(class == member(i));
  32. if count > max
  33. max = count;
  34. label = member(i);
  35. end
  36. end
  37. disp('最终的分类结果为:');
  38. fprintf('%d\n',label)

结果如下:

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/703156
推荐阅读
相关标签
  

闽ICP备14008679号