
Model evaluation functions in the sklearn.metrics module

sklearn evaluation functions

scikit-learn provides three different APIs for evaluating the quality of a model's predictions:

Estimator score method: every estimator has a score method that provides a default evaluation criterion for the problem it is designed to solve.
Scoring parameter: model evaluation tools that use cross-validation (such as model_selection.cross_val_score and model_selection.GridSearchCV) rely on an internal scoring strategy.
Metric functions: the sklearn.metrics module implements functions that assess prediction error for specific purposes.
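
The snippet below is a minimal sketch contrasting the three API styles; the toy dataset, the LogisticRegression estimator, and the 'accuracy' scoring choice are illustrative assumptions rather than part of the original examples.

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import accuracy_score

# Toy binary classification problem (assumed for illustration)
X, y = make_classification(n_classes=2, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# 1) Estimator score method: the estimator's default metric (mean accuracy for classifiers)
print(clf.score(X_test, y_test))

# 2) scoring parameter: cross-validation tools accept a scoring strategy by name
print(cross_val_score(clf, X, y, cv=5, scoring='accuracy'))

# 3) Metric functions: compute a chosen metric from true and predicted labels
print(accuracy_score(y_test, clf.predict(X_test)))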

Examples of model evaluation functions:

from sklearn import metrics
# List the functions available in the module
print(dir(metrics))
### 1. Accuracy score
import numpy as np
from sklearn.metrics import accuracy_score
y_pred = [0, 2, 1, 3]
y_true = [0, 1, 2, 3]
print(accuracy_score(y_true, y_pred))
print(accuracy_score(y_true, y_pred, normalize=False))  # number of correct predictions
### 2. Top-k accuracy score
# top_k_accuracy_score is a generalization of accuracy_score.
# The difference is that a prediction is considered correct as long as the true
# label is among the k highest predicted scores.
# accuracy_score is the special case of k=1.
import numpy as np
from sklearn.metrics import top_k_accuracy_score
y_true = np.array([0, 1, 2, 2])
y_score = np.array([[0.5, 0.2, 0.2],
                    [0.3, 0.4, 0.2],
                    [0.2, 0.4, 0.3],
                    [0.7, 0.2, 0.1]])
print(top_k_accuracy_score(y_true, y_score, k=2))
# Not normalizing gives the number of "correctly" classified samples
print(top_k_accuracy_score(y_true, y_score, k=2, normalize=False))
### 3. confusion_matrix
from sklearn import datasets
from sklearn.svm import LinearSVC
from sklearn.model_selection import cross_validate
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
cm = confusion_matrix(y_true, y_pred)
print(confusion_matrix(y_true, y_pred))
print(confusion_matrix(y_true, y_pred, normalize='all'))
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot()
plt.show()
# Binary classification: ravel() unpacks true negatives, false positives,
# false negatives and true positives
y_true = [0, 0, 0, 1, 1, 1, 1, 1]
y_pred = [0, 1, 0, 1, 0, 1, 0, 1]
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
print(tn, fp, fn, tp)
# A sample toy binary classification dataset
X, y = datasets.make_classification(n_classes=2, random_state=0)
svm = LinearSVC(random_state=0)
# A custom scorer may return a dict; cross_validate exposes each entry as a test score
def confusion_matrix_scorer(clf, X, y):
    y_pred = clf.predict(X)
    cm = confusion_matrix(y, y_pred)
    return {'tn': cm[0, 0], 'fp': cm[0, 1],
            'fn': cm[1, 0], 'tp': cm[1, 1]}
cv_results = cross_validate(svm, X, y, cv=5,
                            scoring=confusion_matrix_scorer)
# Test-set true positive counts per fold
print(cv_results['test_tp'])
# Test-set false negative counts per fold
print(cv_results['test_fn'])
print(cv_results['test_tn'])
print(cv_results['test_fp'])
### 4. classification_report
from sklearn.metrics import classification_report
y_true = [0, 1, 2, 2, 0]
y_pred = [0, 0, 2, 1, 0]
target_names = ['class 0', 'class 1', 'class 2']
print(classification_report(y_true, y_pred, target_names=target_names))
### 5. hamming_loss
from sklearn.metrics import hamming_loss
y_pred = [1, 2, 3, 4]
y_true = [2, 2, 3, 4]
print(hamming_loss(y_true, y_pred))
### 6. Precision, recall and F-measures
from sklearn import metrics
y_pred = [0, 1, 0, 0]
y_true = [0, 1, 0, 1]
print(metrics.precision_score(y_true, y_pred))
print(metrics.recall_score(y_true, y_pred))
print(metrics.f1_score(y_true, y_pred))
print(metrics.fbeta_score(y_true, y_pred, beta=0.5))
print(metrics.fbeta_score(y_true, y_pred, beta=1))
print(metrics.fbeta_score(y_true, y_pred, beta=2))
print(metrics.precision_recall_fscore_support(y_true, y_pred, beta=0.5))
import numpy as np
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
y_true = np.array([0, 0, 1, 1])
y_scores = np.array([0.1, 0.4, 0.35, 0.8])
precision, recall, threshold = precision_recall_curve(y_true, y_scores)
print(precision)
print(recall)
print(threshold)
print(average_precision_score(y_true, y_scores))
## Multiclass
from sklearn import metrics
y_true = [0, 1, 2, 0, 1, 2]
y_pred = [0, 2, 1, 0, 0, 1]
print(metrics.precision_score(y_true, y_pred, average='macro'))
print(metrics.recall_score(y_true, y_pred, average='micro'))
print(metrics.f1_score(y_true, y_pred, average='weighted'))
print(metrics.fbeta_score(y_true, y_pred, average='macro', beta=0.5))
print(metrics.precision_recall_fscore_support(y_true, y_pred, beta=0.5, average=None))
### 7. r2_score for regression
# The r2_score function computes the coefficient of determination, usually denoted R².
# It represents the proportion of the variance of y that is explained by the
# independent variables in the model. It provides an indication of goodness of fit,
# and therefore a measure of how well unseen samples are likely to be predicted
# by the model, via the proportion of explained variance.
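# For reference, a worked form of this definition (added note, not in the original post):
#     R² = 1 - Σᵢ (yᵢ - ŷᵢ)² / Σᵢ (yᵢ - ȳ)²
# where ŷᵢ is the prediction for sample i and ȳ is the mean of the true values.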
from sklearn.metrics import r2_score
y_true = [3, -0.5, 2, 7]
y_pred = [2.5, 0.0, 2, 8]
print(r2_score(y_true, y_pred))
### 8. mean_absolute_error for regression
from sklearn.metrics import mean_absolute_error
y_true = [3, -0.5, 2, 7]
y_pred = [2.5, 0.0, 2, 8]
print(mean_absolute_error(y_true, y_pred))
### 9. mean_squared_error for regression
from sklearn.metrics import mean_squared_error
y_true = [3, -0.5, 2, 7]
y_pred = [2.5, 0.0, 2, 8]
print(mean_squared_error(y_true, y_pred))
### 10. mean_squared_log_error for regression
from sklearn.metrics import mean_squared_log_error
y_true = [3, 5, 2.5, 7]
y_pred = [2.5, 5, 4, 8]
print(mean_squared_log_error(y_true, y_pred))
### 11. Silhouette Coefficient for unsupervised clustering
from sklearn import metrics
from sklearn import datasets
import numpy as np
X, y = datasets.load_iris(return_X_y=True)
from sklearn.cluster import KMeans
kmeans_model = KMeans(n_clusters=3, random_state=1).fit(X)
labels = kmeans_model.labels_
print(metrics.silhouette_score(X, labels, metric='euclidean'))

References:

https://scikit-learn.org/stable/modules/model_evaluation.html#multilabel-ranking-metrics

https://scikit-learn.org/stable/modules/clustering.html#clustering-evaluation
