当前位置:   article > 正文

scikit-learn KNN实现糖尿病预测

scikit-learn KNN实现糖尿病预测

随书代码,阅读笔记。

KNN是一种有监督的机器学习算法,可以解决分类问题,也可以解决回归问题。

算法优点:准确性高,对异常值和噪声有较高的容忍度。

算法缺点:计算量大,内存消耗也比较大。

针对算法计算量大的问题,有一些改进的数据结构可以避免重复计算,例如 K-D Tree 和 Ball Tree。

算法变种:一种是根据邻居的距离分配不同的权重(距离越近,权重越高);另一种是指定半径,用半径范围内的样本代替最近的 K 个样本。

  • KNN进行分类
  # KNN classification demo: fit a k-nearest-neighbours classifier on three
  # synthetic clusters, classify one query point and draw its neighbours.
  #
  # Notebook users: run the `%matplotlib inline` magic first if figures are
  # not rendered inline; the magic is not valid in a plain Python script.
  import matplotlib.pyplot as plt
  import numpy as np
  import pandas as pd
  # `sklearn.datasets.samples_generator` has been removed from scikit-learn;
  # `make_blobs` is exported by `sklearn.datasets` directly.
  from sklearn.datasets import make_blobs
  from sklearn.neighbors import KNeighborsClassifier

  # Generate the data: 60 samples scattered around three centers.
  centers = [[-2, 2], [2, 2], [0, 4]]
  X, y = make_blobs(n_samples=60, centers=centers, random_state=0, cluster_std=0.60)

  # Plot the data.
  plt.figure(figsize=(16, 10), dpi=144)
  c = np.array(centers)
  plt.scatter(X[:, 0], X[:, 1], c=y, s=100, cmap='cool')  # samples
  plt.scatter(c[:, 0], c[:, 1], s=100, marker='^', c='orange')  # centers

  # Train the model.
  k = 5
  clf = KNeighborsClassifier(n_neighbors=k)
  clf.fit(X, y)

  # Predict. `predict` and `kneighbors` require a 2-D array of shape
  # (n_samples, n_features), so the single query point is wrapped in a row.
  X_sample = np.array([[0, 2]])
  y_sample = clf.predict(X_sample)
  neighbors = clf.kneighbors(X_sample, return_distance=False)

  # Draw the illustration.
  plt.figure(figsize=(16, 10), dpi=144)
  plt.scatter(X[:, 0], X[:, 1], c=y, s=100, cmap='cool')  # samples
  plt.scatter(c[:, 0], c[:, 1], s=100, marker='^', c='k')  # centers
  # Pin vmin/vmax to the label range: with a single colour value the colormap
  # would otherwise be normalised on that value alone, so the marker would not
  # share the colour of its predicted class.
  plt.scatter(X_sample[:, 0], X_sample[:, 1], marker="x",
              c=y_sample, s=100, cmap='cool',
              vmin=y.min(), vmax=y.max())  # the point being predicted
  # Connect the query point to each of its k nearest training samples.
  for i in neighbors[0]:
      plt.plot([X[i][0], X_sample[0, 0]], [X[i][1], X_sample[0, 1]],
               'k--', linewidth=0.6)
  plt.show()

  • KNN进行回归拟合
  # KNN regression demo: fit a k-nearest-neighbours regressor to noisy samples
  # of cos(x) and plot the fitted curve.
  #
  # Notebook users: run the `%matplotlib inline` magic first if figures are
  # not rendered inline; the magic is not valid in a plain Python script.
  import matplotlib.pyplot as plt
  import numpy as np
  from sklearn.neighbors import KNeighborsRegressor

  # Generate the training samples: 40 points drawn uniformly from [0, 5).
  n_dots = 40
  X = 5 * np.random.rand(n_dots, 1)
  y = np.cos(X).ravel()

  # Add uniform noise in [-0.1, 0.1).
  y += 0.2 * np.random.rand(n_dots) - 0.1

  # Train the model.
  k = 5
  knn = KNeighborsRegressor(n_neighbors=k)
  knn.fit(X, y)

  # Generate a dense grid of points and predict on it.
  T = np.linspace(0, 5, 500)[:, np.newaxis]
  y_pred = knn.predict(T)

  # Score on the training data. The data is not seeded, so the value changes
  # from run to run; the original article reported 0.98579189493611052.
  print(knn.score(X, y))

  # Plot the fitted curve.
  plt.figure(figsize=(16, 10), dpi=144)
  plt.scatter(X, y, c='g', label='data', s=100)  # training samples
  plt.plot(T, y_pred, c='k', label='prediction', lw=4)  # fitted curve
  plt.axis('tight')
  plt.title("KNeighborsRegressor (k = %i)" % k)
  plt.legend()  # without this call the `label=` values are never displayed
  plt.show()

  • KNN 实现糖尿病预测
    # KNN diabetes prediction: compare three neighbour-based classifiers on the
    # Pima Indians diabetes CSV, plot a learning curve, then repeat the
    # comparison using only the two best features.
    #
    # Notebook users: run the `%matplotlib inline` magic first if figures are
    # not rendered inline; the magic is not valid in a plain Python script.
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from sklearn.feature_selection import SelectKBest
    from sklearn.model_selection import KFold
    from sklearn.model_selection import ShuffleSplit
    from sklearn.model_selection import cross_val_score
    from sklearn.model_selection import train_test_split
    from sklearn.neighbors import KNeighborsClassifier, RadiusNeighborsClassifier
    # Not part of scikit-learn: imported from a local `common` package,
    # presumably the helper module shipped with the book's code -- confirm it
    # is on the import path.
    from common.utils import plot_learning_curve


    def print_cv_scores(models, features, target, n_splits=10):
        """Print the mean K-fold cross-validation score of each model.

        models: iterable of (name, estimator) pairs.
        features, target: data passed straight to `cross_val_score`.
        n_splits: number of folds; the mean over all folds is reported.
        """
        for name, model in models:
            kfold = KFold(n_splits=n_splits)
            cv_result = cross_val_score(model, features, target, cv=kfold)
            print("name: {}; cross val score: {}".format(name, cv_result.mean()))


    # Load the data.
    data = pd.read_csv('datasets/pima-indians-diabetes/diabetes.csv')
    print('dataset shape {}'.format(data.shape))
    print(data.head())
    print(data.groupby("Outcome").size())
    # Output reported in the original article:
    # Outcome
    # 0    500   (no diabetes)
    # 1    268   (diabetes)
    # dtype: int64

    # The first 8 columns are used as features and the 9th as the label.
    X = data.iloc[:, 0:8]
    Y = data.iloc[:, 8]
    print('shape of X {}; shape of Y {}'.format(X.shape, Y.shape))

    # No random_state is given, so the split (and every score that depends on
    # it) differs between runs.
    X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2)

    models = []
    models.append(("KNN", KNeighborsClassifier(n_neighbors=2)))
    models.append(("KNN with weights", KNeighborsClassifier(
        n_neighbors=2, weights="distance")))
    # RadiusNeighborsClassifier selects neighbours by radius only; it has no
    # `n_neighbors` parameter and current scikit-learn rejects it.
    models.append(("Radius Neighbors", RadiusNeighborsClassifier(
        radius=500.0)))

    results = []
    for name, model in models:
        model.fit(X_train, Y_train)
        results.append((name, model.score(X_test, Y_test)))
    for name, score in results:
        print("name: {}; score: {}".format(name, score))
    # Output reported in the original article (varies with the random split):
    # name: KNN; score: 0.681818181818
    # name: KNN with weights; score: 0.636363636364
    # name: Radius Neighbors; score: 0.62987012987

    # K-fold: train 10 times and report the mean accuracy over the 10 folds.
    print_cv_scores(models, X, Y)
    # Output reported in the original article:
    # name: KNN; cross val score: 0.714764183185
    # name: KNN with weights; cross val score: 0.677050580998
    # name: Radius Neighbors; cross val score: 0.6497265892

    # Train a single model and compare train/test accuracy.
    knn = KNeighborsClassifier(n_neighbors=2)
    knn.fit(X_train, Y_train)
    train_score = knn.score(X_train, Y_train)
    test_score = knn.score(X_test, Y_test)
    print("train score: {}; test score: {}".format(train_score, test_score))

    # Plot the learning curve.
    knn = KNeighborsClassifier(n_neighbors=2)
    cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)
    plt.figure(figsize=(10, 6), dpi=200)
    plot_learning_curve(plt, knn, "Learn Curve for KNN Diabetes",
                        X, Y, ylim=(0.0, 1.01), cv=cv)

    # Data visualisation: keep the 2 highest-scoring of the 8 features.
    selector = SelectKBest(k=2)
    X_new = selector.fit_transform(X, Y)
    print(X_new[0:5])

    # Repeat the cross-validation using only the two selected features.
    print_cv_scores(models, X_new, Y)

    # Plot the data. Axis labels are taken from the selector so that they always
    # match the plotted columns (the original article labelled them "Glucose"
    # on x and "BMI" on y).
    selected = X.columns[selector.get_support()]
    negative = (Y == 0).values
    positive = (Y == 1).values
    plt.figure(figsize=(10, 6), dpi=200)
    plt.xlabel(selected[0])
    plt.ylabel(selected[1])
    plt.scatter(X_new[negative][:, 0], X_new[negative][:, 1],
                c='r', s=20, marker='o')  # samples with Outcome == 0
    plt.scatter(X_new[positive][:, 0], X_new[positive][:, 1],
                c='g', s=20, marker='^')  # samples with Outcome == 1
    plt.show()
    # Author's conclusion: 2 features give about the same result as all 8, so
    # the classification quality has hit a ceiling.
    83. #2个特征和8个特征得到的结果差不多。分类效果达到了瓶颈

使用KNN对糖尿病数据集进行预测,无法得到比较高的预测准确性。

扩展阅读

本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/一键难忘520/article/detail/878083
推荐阅读
  

闽ICP备14008679号