当前位置:   article > 正文

KNN算法简单实战_nrknn

nrknn

KNN算法简介

KNN算法又称K近邻(knn,k-NearestNeighbor)分类算法,K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。KNN通过测量不同特征值之间的距离进行分类。

KNN算法思路

取样本在特征空间中的k个最近的邻居,判断它们的分类情况,KNN算法将样本预测为k个最近邻居中大多数的分类情况。

KNN算法实战

参考博客
https://www.cnblogs.com/xuyiqing/p/8762066.html

数据集

使用mnist特征提取后的数据集,对于每一个mnist图像,通过CNN网络提取出1024个特征,每个图像的1024个特征保存在一行,每行最后一个数字是图像的标签,即数字[0,9]。我总共准备了1000个数据集。
数据集可从此下载,提取码: ttnr
https://pan.baidu.com/s/1Xei_yIK1BOY0cuFC4Lp0MA

具体代码

import csv
import random
import math
import operator

# 训练集
trainingSet = []
# 测试集
testSet = []


def loadDataset(filename):
    with open(filename, 'r') as file:  # 打开文件,文件句柄为file
        lines = csv.reader(file)  # 将每行数据当做列表返回
        dataset = list(lines)
#        trainSetnum = ((len(dataset) - 1) * 2) / 3;
#        print(len(dataset))
        for x in range(len(dataset) - 1):  # 遍历读取的每一行列表
            for y in range(1024):
                dataset[x][y] = float(dataset[x][y])  #因为一开始读取出来是str类型的,所以要将其转化为float类型
            if random.random() < 0.7:  # 随机赋值给训练集和测试集
                trainingSet.append(dataset[x])  # x表示的是行数
            else:
                testSet.append(dataset[x])


def euclideanDistance(instance1, instance2, length):  # 扩展到多维的欧式距离,根号(x1^2+...xi^2)
    distance = 0
    for x in range(length):
        distance += pow((instance1[x] - instance2[x]), 2)
    return math.sqrt(distance)


def getNeighbors(k, testInstance):  # 计算训练集中每一项与该实例的欧氏距离,取最近的k个
    distances = []
    length = len(testInstance) - 1;  # 这个是条目的维数,之所以要减1是因为最后一位是标签
    for x in range(len(trainingSet)):  # 遍历所有的训练集
        dist = euclideanDistance(testInstance, trainingSet[x], length)  # 计算每一个条目与当前条目的欧氏距离
        distances.append((trainingSet[x], dist))  # 把那个对应的条目和欧式距离压入数组
    distances.sort(key=operator.itemgetter(
        1))  # operator模块提供的itemgetter函数用于获取对象的哪些维的数据,opearator.itemgetter是定义了一个函数,这里是以条目的第二个域来进行排序
    neighbors = []
    for x in range(k):      #取欧拉距离最近的k个训练集
        neighbors.append(distances[x][0])
    return neighbors


def getResponse(neighbors):          # 统计类别次数,然后返回次数最多的类别作为最终的结果
    classVotes = {}
    for x in range(len(neighbors)):
        response = neighbors[x][-1]  # 取每个邻居的最后一维,即标签
#        print(response,end=" ")
        if response in classVotes:   # 如果标间存在classVotes列表中,则对应条目加1,否则新建条目
            classVotes[response] += 1
        else:
            classVotes[response] = 1
    sortedVotes = sorted(classVotes.items(), key=operator.itemgetter(1), reverse=True)  # 按照classVotes的第二个域来排序,classVotes的结构是一个个键值对构成的列表,比如“5”:1,表示邻近k个元素中标签为5的元素有1个
#    print(classVotes)
    return sortedVotes[0][0]


def getAccuracy(predictions):  # 验证精确度,将测试级中预测的类别和测试集中正确的类别对比,得出精确百分比
    correct = 0
    for x in range(len(testSet)):
        if testSet[x][-1] == predictions[x]:
            correct += 1
    return (correct / float(len(testSet))) * 100.0


def main():
    loadDataset(r'data.txt')
    print('Train set:' + repr(len(trainingSet)))
    print('Test Set:' + repr(len(testSet)))
    predictions = []
    k = 3
    for x in range(len(testSet)):
        neighbors = getNeighbors(k, testSet[x])#首先获得距测试集第x个元素距离前k的邻居
        result = getResponse(neighbors)        #返回预测的结果
        predictions.append(result)
        print('>predicted=' + repr(result) + ',actual=' + repr(testSet[x][-1]))#输出预测的结果和真实的结果
    accuracy = getAccuracy(predictions)
    print('Accuracy:' + repr(accuracy) + '%')

if __name__ == '__main__':
    main()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86

运行结果

在这里插入图片描述在这里插入图片描述
取了两次测试的结果,可以看到当数据集够大的时候效果还是不错的,之前只有100个数据集的时候,准确率只有百分60左右,代码原理基本都有注释,如还有不懂可以看下我参考的那篇博客,也可以私信我一起讨论。

总结

KNN算法应该是最简单的一类分类算法了,个人觉得当做分类算法的入门学习还是很不错的,如有问题,请指正。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/442593
推荐阅读
相关标签
  

闽ICP备14008679号