当前位置:   article > 正文

混淆矩阵的生成(python实现,含机器学习方法)_python混淆矩阵

python混淆矩阵

混淆矩阵(Confusion Matrix)是用于评估分类模型性能的一种表格形式。它显示了在分类问题中模型的预测结果与实际标签之间的各种组合情况。

混淆矩阵通常用于二分类问题,但也可以扩展到多分类问题。对于二分类问题,它由四个重要的指标组成:

真正例(True Positive, TP):模型预测为正例,并且实际上是正例的数量。
真反例(True Negative, TN):模型预测为反例,并且实际上是反例的数量。
假正例(False Positive, FP):模型预测为正例,但实际上是反例的数量。也称为"误报"。
假反例(False Negative, FN):模型预测为反例,但实际上是正例的数量。也称为"漏报"。

混淆矩阵的一般形式如下:
在这里插入图片描述

使用混淆矩阵可以计算多个衡量分类器性能的指标,如准确率(Accuracy)、精确率(Precision)、召回率(Recall,也称为敏感度或真正例率)和 F1 值等。这些指标可以通过混淆矩阵中的各个元素计算得出:

准确率(Accuracy):分类器预测正确的样本占总样本数的比例,计算公式为 (TP + TN) / (TP + TN + FP + FN) 。
精确率(Precision):正例预测正确的比例,计算公式为 TP / (TP + FP) 。
召回率(Recall):正例被正确预测为正例的比例,计算公式为 TP / (TP + FN) 。
F1 值:综合考虑了精确率和召回率的指标,计算公式为 2 (Precision Recall) / (Precision + Recall) 。

混淆矩阵提供了更详细和全面地评估分类模型性能的能力,帮助我们了解预测中的误报和漏报情况。通过分析混淆矩阵,我们可以获得对分类器在每个类别上的表现有关的宝贵见解,并对分类结果进行优化。

废话不多数,上代码:

def draw_confusion_matrix(label_true, label_pred, label_name, normlize, title="Confusion Matrix", pdf_save_path=None, dpi=100):
    """

    @param label_true: 真实标签,比如[0,1,2,7,4,5,...]
    @param label_pred: 预测标签,比如[0,5,4,2,1,4,...]
    @param label_name: 标签名字,比如['cat','dog','flower',...]
    @param normlize: 是否设元素为百分比形式
    @param title: 图标题
    @param pdf_save_path: 是否保存,是则为保存路径pdf_save_path=xxx.png | xxx.pdf | ...等其他plt.savefig支持的保存格式
    @param dpi: 保存到文件的分辨率,论文一般要求至少300dpi
    @return:

    example:
            draw_confusion_matrix(label_true=y_gt,
                          label_pred=y_pred,
                          label_name=["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"],
                          normlize=True,
                          title="Confusion Matrix on Fer2013",
                          pdf_save_path="Confusion_Matrix_on_Fer2013.png",
                          dpi=300)

    """
    cm1=confusion_matrix(label_true, label_pred)
    cm = confusion_matrix(label_true, label_pred)
    if normlize:
        row_sums = np.sum(cm, axis=1)
        cm = cm / row_sums[:, np.newaxis]
    cm=cm.T
    cm1=cm1.T
    plt.imshow(cm, cmap='Blues')
    plt.title(title)
    plt.xlabel("Predict label")
    plt.ylabel("Truth label")
    plt.yticks(range(label_name.__len__()), label_name)
    plt.xticks(range(label_name.__len__()), label_name, rotation=45)

    plt.tight_layout()

    plt.colorbar()

    for i in range(label_name.__len__()):
        for j in range(label_name.__len__()):
            color = (1, 1, 1) if i == j else (0, 0, 0)	# 对角线字体白色,其他黑色
            value = float(format('%.1f' % (cm[i, j]*100)))
            value1=str(value)+'%\n'+str(cm1[i, j])
            plt.text(i, j, value1, verticalalignment='center', horizontalalignment='center', color=color)

    # plt.show()
    if not pdf_save_path is None:
        plt.savefig(pdf_save_path, bbox_inches='tight',dpi=dpi)



labels_name = ['bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']

y_gt=[]
y_pred=[]

model_weight_path = "./best_CBAM_model.pth"
models = Xception(num_classes = 4)
models.load_state_dict(torch.load(model_weight_path))




models.eval()
for index, (imgs, labels) in enumerate(test_dl):
    labels_pd = models(imgs)
    predict_np = np.argmax(labels_pd.cpu().detach().numpy(), axis=-1).tolist()
    labels_np = labels.numpy().tolist()

    y_pred.extend(predict_np)
    y_gt.extend(labels_np)
print("预测标签为:", y_pred)
print("真实标签为", y_gt)



draw_confusion_matrix(label_true=y_gt,
                      label_pred=y_pred,
                      label_name=labels_name,
                      normlize=True,
                      title="Confusion Matrix",
                      pdf_save_path="Confusion_Matrix.jpg",
                      dpi=300)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85

结果如下:
在这里插入图片描述

更新

这里大佬给我提供了一种更加简单的混淆矩阵生成方法,是基于机器学习sklearn里面的库confusion_matrix和seaborn 库。原文参考深度学习100例-卷积神经网络(CNN)识别眼睛状态 | 第17天

Seaborn是一个基于Matplotlib的Python数据可视化库,它提供了一种高层次的接口,用于绘制具有吸引力和信息丰富的统计图形。Seaborn的设计目标是使得数据可视化更加简单,同时也更具吸引力,以便更好地理解和传达数据的含义。它具有许多内置的图表类型和样式,可用于探索数据分布、比较多个变量之间的关系、绘制分类数据以及在统计模型中的可视化等。它还提供了许多自定义选项,使您能够根据自己的需求进行图形的修改和美化。

而sklearn.metrics是Scikit-learn库中的一个模块,用于评估机器学习模型的性能和预测结果。这个模块提供了各种用于计算模型性能指标(如准确度、精确度、召回率、F1值等)的函数,以及用于绘制混淆矩阵、ROC曲线、学习曲线等的工具函数。

混淆矩阵是衡量分类模型性能的一种方法,它以矩阵形式表示了模型预测结果与真实标签之间的差异。混淆矩阵的行表示真实标签,列表示预测标签,每个单元格中的值表示对应标签的样本数量。通过分析混淆矩阵,我们可以得出模型的准确性、错误类型和偏差等信息。

通过Seaborn库的heatmap函数,我们可以将混淆矩阵可视化为一个热力图,更直观地展示模型预测结果的分布情况。热力图的每个单元格的颜色深浅表示对应标签的样本数量或其他统计指标。

from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd

# 定义一个绘制混淆矩阵图的函数
def plot_cm(labels, predictions):
    
    # 生成混淆矩阵
    conf_numpy = confusion_matrix(labels, predictions)
    # 将矩阵转化为 DataFrame
    conf_df = pd.DataFrame(conf_numpy, index=class_names ,columns=class_names)  
    
    plt.figure(figsize=(8,7))
    
    sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")
    
    plt.title('混淆矩阵',fontsize=15)
    plt.ylabel('真实值',fontsize=14)
    plt.xlabel('预测值',fontsize=14)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
val_pre   = []
val_label = []

for images, labels in val_ds:#这里可以取部分验证数据(.take(1))生成混淆矩阵
    for image, label in zip(images, labels):
        # 需要给图片增加一个维度
        img_array = tf.expand_dims(image, 0) 
        # 使用模型预测图片中的人物
        prediction = model.predict(img_array)

        val_pre.append(class_names[np.argmax(prediction)])
        val_label.append(class_names[label])

plot_cm(val_label, val_pre)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/151741?site
推荐阅读
相关标签
  

闽ICP备14008679号