
Machine Learning Experiment 1: KNN Algorithm on the Handwritten Digits Dataset (Using Hamming Distance) (2)


KNN: handwritten digits dataset

      Task: use the KNN toolkit in sklearn (KNeighborsClassifier) to replace the hand-built classifier, noting that the distance used is the Hamming distance.
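The listing in this article computes the distances by hand. The block below is a minimal sketch of the sklearn route the task describes, assuming NumPy and scikit-learn are installed; `strings_to_array` and `fit_hamming_knn` are hypothetical helper names, not part of the original code. One detail worth knowing: sklearn's `'hamming'` metric returns the fraction of differing positions rather than their count, which for equal-length images only rescales the distances.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier


def strings_to_array(bit_strings):
    """Hypothetical helper: turn flattened '0'/'1' image strings into an (n, d) integer array."""
    return np.array([[int(c) for c in s] for s in bit_strings], dtype=np.uint8)


def fit_hamming_knn(train_strings, train_labels, k=3):
    """Hypothetical helper: fit a k-NN classifier that uses Hamming distance."""
    # sklearn's 'hamming' metric is the *fraction* of differing positions.
    # For equal-length vectors this is the raw count divided by a constant,
    # so neighbours are ranked the same way (up to how ties are broken).
    clf = KNeighborsClassifier(n_neighbors=k, metric='hamming', algorithm='brute')
    clf.fit(strings_to_array(train_strings), list(train_labels))
    return clf
```

With `train` and `test` as returned by `get_train()` and `get_test()` in the listing, `fit_hamming_knn(train['img'], train['labels']).score(strings_to_array(test['img']), list(test['labels']))` would give the test accuracy. Brute-force search is chosen here because tree indexes are unlikely to help much on high-dimensional binary vectors.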

 Run result (takes roughly 4 minutes to run):

Code:

```python
import os

import pandas as pd


def hamming(str1, str2):
    """Hamming distance: number of positions at which two equal-length strings differ."""
    if len(str1) != len(str2):
        raise ValueError("The two strings have different lengths")
    return sum(c1 != c2 for c1, c2 in zip(str1, str2))


def load_digits(path):
    """Read each text image in `path` as one flattened '0'/'1' string plus its label."""
    # Skip hidden files such as macOS '._*' metadata files; stripping their
    # '._' prefix instead would load the real file a second time
    file_list = [f for f in os.listdir(path) if not f.startswith('.')]
    img = []
    labels = []
    for filename in file_list:
        with open(os.path.join(path, filename), 'r') as f:
            img.append(f.read().replace('\n', ''))
        # File names look like '<digit>_<n>.txt'; the prefix is the label
        labels.append(filename.split('_')[0])
    return pd.DataFrame({'img': img, 'labels': labels})


def get_train():
    return load_digits('digits/trainingDigits')


def get_test():
    return load_digits('digits/testDigits')


def handwritingClass(train, test, k):
    result = []
    for test_img in test['img']:
        # Keep distances as integers so they sort numerically;
        # as strings, '10' would sort before '9'
        dist = [hamming(train_img, test_img) for train_img in train['img']]
        dist_l = pd.DataFrame({'dist': dist, 'labels': train['labels']})
        # Majority vote among the k nearest training images
        dr = dist_l.sort_values(by='dist').head(k)
        result.append(dr['labels'].value_counts().index[0])
    test['predict'] = result
    acc = (test['predict'] == test['labels']).mean()
    print(f'Model prediction accuracy: {acc:.5f}')
    return test


# Load the training and test sets
train = get_train()
test = get_test()
# Classify with k = 3
handwritingClass(train, test, 3)
```
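Most of the roughly four-minute runtime is likely spent in the pure-Python double loop that calls `hamming` once per (test, training) pair. The block below is a minimal sketch of a vectorized alternative, assuming NumPy is available; `to_bits` and `knn_hamming_predict` are hypothetical names. For each test image, one broadcast comparison against all training rows yields every Hamming distance at once.

```python
import numpy as np


def to_bits(bit_string):
    """Hypothetical helper: convert a flattened '0'/'1' string into a uint8 array of 0s and 1s."""
    return np.frombuffer(bit_string.encode('ascii'), dtype=np.uint8) - ord('0')


def knn_hamming_predict(train_bits, train_labels, test_bits, k=3):
    """k-NN with Hamming distance (hypothetical sketch).

    train_bits: (n, d) 0/1 array, train_labels: (n,) array, test_bits: (m, d) 0/1 array.
    """
    preds = []
    for x in test_bits:
        # Broadcasting compares x with every training row at once;
        # counting mismatches per row gives n Hamming distances
        dist = np.count_nonzero(train_bits != x, axis=1)
        # Indices of the k smallest distances (stable sort keeps training order on ties)
        nearest = np.argsort(dist, kind='stable')[:k]
        # Majority vote; np.unique sorts labels, so a tied vote goes to the smallest label
        labels, counts = np.unique(train_labels[nearest], return_counts=True)
        preds.append(labels[np.argmax(counts)])
    return np.array(preds)
```

With the DataFrames from `get_train()` and `get_test()`, the inputs would be `np.stack([to_bits(s) for s in train['img']])` and `train['labels'].to_numpy()`, and accuracy `(pred == test['labels'].to_numpy()).mean()`. Because tied votes are broken differently than with pandas `value_counts`, a few predictions may differ from the listing's.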

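Disclaimer: this article was contributed by a user and does not represent the position of the [wpsshop Blog]; copyright belongs to the original author, and this site assumes no legal liability. If you find infringing content, please contact us. When reposting, please credit the source: https://www.wpsshop.cn/w/盐析白兔/article/detail/703090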