当前位置:   article > 正文

机器学习 --- kNN算法_头歌实践平台:机器学习之knn算法

头歌实践平台:机器学习之knn算法

第1关:实现kNN算法

  1. #encoding=utf8
  2. import numpy as np
  3. class kNNClassifier(object):
  4. def __init__(self, k):
  5. '''
  6. 初始化函数
  7. :param k:kNN算法中的k
  8. '''
  9. self.k = k
  10. # 用来存放训练数据,类型为ndarray
  11. self.train_feature = None
  12. # 用来存放训练标签,类型为ndarray
  13. self.train_label = None
  14. def fit(self, feature, label):
  15. '''
  16. kNN算法的训练过程
  17. :param feature: 训练集数据,类型为ndarray
  18. :param label: 训练集标签,类型为ndarray
  19. :return: 无返回
  20. '''
  21. #********* Begin *********#
  22. self.train_feature = np.array(feature)
  23. self.train_label = np.array(label)
  24. #********* End *********#
  25. def predict(self, feature):
  26. '''
  27. kNN算法的预测过程
  28. :param feature: 测试集数据,类型为ndarray
  29. :return: 预测结果,类型为ndarray或list
  30. '''
  31. #********* Begin *********#
  32. def _predict(test_data):
  33. distances = [np.sqrt(np.sum((test_data - vec) ** 2)) for vec in self.train_feature]
  34. nearest = np.argsort(distances)
  35. topK = [self.train_label[i] for i in nearest[:self.k]]
  36. votes = {}
  37. result = None
  38. max_count = 0
  39. for label in topK:
  40. if label in votes.keys():
  41. votes[label] += 1
  42. if votes[label] > max_count:
  43. max_count = votes[label]
  44. result = label
  45. else:
  46. votes[label] = 1
  47. if votes[label] > max_count:
  48. max_count = votes[label]
  49. result = label
  50. return result
  51. predict_result = [_predict(test_data) for test_data in feature]
  52. return predict_result
  53. #********* End *********#

第2关:红酒分类

  1. from sklearn.neighbors import KNeighborsClassifier
  2. from sklearn.preprocessing import StandardScaler
  3. def classification(train_feature, train_label, test_feature):
  4. '''
  5. 对test_feature进行红酒分类
  6. :param train_feature: 训练集数据,类型为ndarray
  7. :param train_label: 训练集标签,类型为ndarray
  8. :param test_feature: 测试集数据,类型为ndarray
  9. :return: 测试集数据的分类结果
  10. '''
  11. #********* Begin *********#
  12. scaler = StandardScaler()
  13. train_feature = scaler.fit_transform(train_feature)
  14. test_feature = scaler.transform(test_feature)
  15. clf = KNeighborsClassifier()
  16. clf.fit(train_feature, train_label)
  17. return clf.predict(test_feature)
  18. #********* End **********#

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/神奇cpp/article/detail/827448
推荐阅读
相关标签
  

闽ICP备14008679号