
TensorFlow: evaluation metrics for binary and multiclass classification


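I am using TensorFlow 2.5. Most tutorials I found online say to use the APIs in tf.keras.metrics, but in my experiments I could not get any of them to work; TensorFlow 2.5 may no longer support those APIs.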

So I used functions from the sklearn library to compute the evaluation metrics for binary and multiclass classification problems.

The functions used are f1_score, precision_score, recall_score and accuracy_score.
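To get a feel for these four functions before plugging in a real model, here is a minimal sketch on hand-made labels. The lists `y_true` and `y_pred` are hypothetical toy data, not output from the model in this post; class 1 is treated as the positive class, which is sklearn's default for binary labels.

```python
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# Hypothetical toy data: 8 samples, 1 is the positive class.
y_true = [0, 1, 1, 0, 1, 0, 1, 1]
y_pred = [0, 1, 0, 0, 1, 1, 1, 1]

# Here TP = 4, FP = 1, FN = 1, TN = 2.
accuracy = accuracy_score(y_true, y_pred)    # (TP + TN) / all samples
precision = precision_score(y_true, y_pred)  # TP / (TP + FP)
recall = recall_score(y_true, y_pred)        # TP / (TP + FN)
f1 = f1_score(y_true, y_pred)                # harmonic mean of precision and recall

print('accuracy: ', accuracy)
print('precision: ', precision)
print('recall: ', recall)
print('f1: ', f1)
```

All four functions take the true labels first and the predicted labels second, and they expect hard class labels rather than probabilities.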

Binary classification

    # Binary classification
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    from sklearn.metrics import roc_curve, confusion_matrix, f1_score, precision_score, recall_score, accuracy_score

    # Plot the ROC curve. pos_label=0 treats class 0 as the positive class,
    # so `predictions` must be scores that are higher for class 0.
    def plot_roc(name, labels, predictions, **kwargs):
        fp, tp, _ = roc_curve(labels, predictions, pos_label=0)
        plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
        plt.xlabel('False positives [%]')
        plt.ylabel('True positives [%]')
        plt.xlim([-0.5, 100.5])
        plt.ylim([-0.5, 100.5])
        plt.grid(True)
        ax = plt.gca()
        ax.set_aspect('equal')
        plt.legend(loc='lower right')
        plt.savefig("./img/multi_roc.png")  # the ./img directory must already exist

    # `model`, `x_test` and `y_test` come from the training code (not shown).
    # Unpacking two values assumes the model was compiled with metrics=['accuracy'].
    loss, acc = model.evaluate(x_test, y_test)
    test_predictions = model.predict(x_test).ravel()  # flatten (N, 1) to (N,)
    true_labels = y_test.astype('uint8')
    # Min-max normalise the outputs, then invert them so that a higher
    # score means "more likely class 0", matching pos_label=0 above.
    test_scores = 1 - (test_predictions - test_predictions.min()) / (test_predictions.max() - test_predictions.min())

    mpl.rcParams['figure.figsize'] = (12, 10)
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    plot_roc("My Model", true_labels, test_scores, color=colors[0])

    # Round the outputs to hard 0/1 labels before computing the metrics
    pred_labels = test_predictions.round()
    recall = recall_score(true_labels, pred_labels)
    f1 = f1_score(true_labels, pred_labels)
    precision = precision_score(true_labels, pred_labels)
    print('accuracy: ', acc)
    print('loss: ', loss)
    print('recall: ', recall)
    print('precision: ', precision)
    print('f1: ', f1)
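The binary listing imports confusion_matrix without using it. As a minimal sketch of how it relates to precision and recall, the example below thresholds some made-up sigmoid outputs and reads the four counts off the matrix. The arrays `probs` and `y_true` are hypothetical stand-ins for `model.predict(x_test)` and `y_test`, and the 0.5 threshold is an assumption.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score

# Hypothetical sigmoid outputs with shape (N, 1), like model.predict returns
probs = np.array([[0.9], [0.2], [0.7], [0.4], [0.6], [0.1]])
y_true = np.array([1, 0, 1, 1, 0, 0])

# Turn probabilities into hard labels with an explicit 0.5 threshold
pred_labels = (probs.ravel() >= 0.5).astype(int)

# For binary labels the matrix is [[TN, FP], [FN, TP]]
tn, fp, fn, tp = confusion_matrix(y_true, pred_labels).ravel()

precision = precision_score(y_true, pred_labels)  # equals tp / (tp + fp)
recall = recall_score(y_true, pred_labels)        # equals tp / (tp + fn)

print('TN, FP, FN, TP: ', tn, fp, fn, tp)
print('precision: ', precision)
print('recall: ', recall)
```

An explicit threshold is used here instead of `round()` because NumPy rounds exactly 0.5 down to 0 (round half to even), which can be surprising at the decision boundary.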

Multiclass classification

    # Multiclass classification
    import numpy as np
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    from sklearn.metrics import roc_curve, confusion_matrix, f1_score, precision_score, recall_score, accuracy_score

    # Plot the ROC curve of class 0 against all other classes (pos_label=0)
    def plot_roc(name, labels, predictions, **kwargs):
        fp, tp, _ = roc_curve(labels, predictions, pos_label=0)
        plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
        plt.xlabel('False positives [%]')
        plt.ylabel('True positives [%]')
        plt.xlim([-0.5, 100.5])
        plt.ylim([-0.5, 100.5])
        plt.grid(True)
        ax = plt.gca()
        ax.set_aspect('equal')
        plt.legend(loc='lower right')
        plt.savefig("./img/multi_roc.png")  # the ./img directory must already exist

    # `model`, `x_test` and `y_test` come from the training code (not shown).
    # y_test is assumed to hold integer class labels, not one-hot vectors.
    y_pred = model.predict(x_test)                # class probabilities, shape (N, num_classes)
    test_predictions = np.argmax(y_pred, axis=1)  # predicted class index per sample
    y_true = y_test.astype('uint8')
    # Score for the class-0-vs-rest ROC curve: the predicted probability of class 0
    test_scores = y_pred[:, 0]

    mpl.rcParams['figure.figsize'] = (12, 10)
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    plot_roc("My Model", y_true, test_scores, color=colors[0])

    # With average='micro' on single-label multiclass data, precision, recall
    # and F1 all equal accuracy; use average='macro' for a per-class average.
    acc = accuracy_score(y_true, test_predictions)
    recall = recall_score(y_true, test_predictions, average='micro')
    precision = precision_score(y_true, test_predictions, average='micro')
    f1 = f1_score(y_true, test_predictions, average='micro')
    print('accuracy: ', acc)
    print('recall: ', recall)
    print('precision: ', precision)
    print('f1: ', f1)
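One thing worth knowing about `average='micro'`: when every sample has exactly one label, micro-averaged precision, recall and F1 are all identical to accuracy, so the four printed numbers carry the same information. The sketch below shows this on hypothetical toy labels (not this post's data) and compares against `average='macro'`, which averages the per-class scores with equal weight.

```python
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# Hypothetical toy data with three classes (0, 1, 2)
y_true = [0, 0, 0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 0, 1, 1, 2, 2, 2]

acc = accuracy_score(y_true, y_pred)

# Micro averaging pools TP/FP/FN over all classes before dividing
micro_precision = precision_score(y_true, y_pred, average='micro')
micro_recall = recall_score(y_true, y_pred, average='micro')
micro_f1 = f1_score(y_true, y_pred, average='micro')

# Macro averaging computes the score per class, then takes the plain mean
macro_f1 = f1_score(y_true, y_pred, average='macro')

print('accuracy: ', acc)
print('micro precision / recall / f1: ', micro_precision, micro_recall, micro_f1)
print('macro f1: ', macro_f1)
```

If the classes are imbalanced and the minority classes matter, macro (or `average=None` for per-class values) is usually the more informative choice; which one fits depends on the task.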
