赞
踩
import matplotlib.pyplot as plt
import numpy as np
# Render a fixed 4-class confusion matrix as an annotated heat map and save it
# to 'confusion_matrix2.png'.
classes = ['ang', 'hap', 'neu', 'sad']  # class labels, used for both axes
# Rows are plotted on the y axis ('Real label'), columns on the x axis
# ('Prediction'); row/column order matches `classes`.
confusion_matrix = np.array([
    [91, 1, 4, 2],
    [6, 92, 2, 2],
    [2, 3, 92, 3],
    [8, 13, 4, 90],
])

# 'nearest' interpolation draws each matrix entry as one flat-coloured cell.
plt.imshow(confusion_matrix, interpolation='nearest', cmap=plt.cm.Oranges)
plt.title('confusion_matrix')
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=0)  # horizontal x labels
plt.yticks(tick_marks, classes)

# Cells above half the maximum sit at the dark end of the colormap; draw their
# numbers in white so they stay legible, and the rest in black.
thresh = confusion_matrix.max() / 2.
# Walk every (row, col) cell of the matrix, whatever its size, rather than
# assuming a 4x4 layout.
for i, j in np.ndindex(confusion_matrix.shape):
    value = confusion_matrix[i, j]
    # plt.text takes (x, y), i.e. (column, row).
    plt.text(j, i, format(value), va='center', ha='center',
             color='white' if value > thresh else 'black')

plt.ylabel('Real label')
plt.xlabel('Prediction')
plt.tight_layout()
plt.savefig('confusion_matrix2.png', format='png')
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。