当前位置:   article > 正文

AlexNet：计算每个类别的精确率、召回率和 F1-Score 的代码

每个类别的召回率
import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet  # 假设你的模型定义在model.py中


def _safe_div(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def validate(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` and print per-class precision, recall and F1.

    Per-class metrics are derived from true positives (TP), false positives
    (FP) and false negatives (FN):

    * precision = TP / (TP + FP) -- of the samples predicted as class k, the share labelled k
    * recall    = TP / (TP + FN) -- of the samples labelled k, the share predicted as k
    * F1        = 2 * P * R / (P + R)

    Any ratio with a zero denominator is reported as 0. The "overall"
    precision, recall and F1 are unweighted (macro) averages over all classes.

    Args:
        model: Module called as ``model(inputs)``; its output is assumed to be
            ``(batch, num_classes)`` scores, and the argmax along dim 1 is taken
            as the predicted class.
        dataloader: Iterable of ``(inputs, labels)`` batches whose ``dataset``
            attribute exposes ``classes`` (e.g. ``torchvision.datasets.ImageFolder``).
        device: Device the inputs and labels are moved to before the forward pass.

    Returns:
        dict with keys ``"accuracy"``, ``"precision"``, ``"recall"``, ``"f1"``
        (floats in [0, 1]; the last three are macro averages) and
        ``"per_class"``, which maps each class index to a dict with
        ``"precision"``, ``"recall"`` and ``"f1"``.
    """
    model.eval()

    class_names = dataloader.dataset.classes
    num_classes = len(class_names)

    true_positive = [0] * num_classes    # predicted k and labelled k
    predicted_count = [0] * num_classes  # predicted k            (TP + FP)
    support = [0] * num_classes          # labelled k             (TP + FN)
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            # tolist() yields plain ints, so a batch of size 1 works like any other
            # (squeezing a 1-element tensor would give an unindexable 0-d tensor).
            for label, pred in zip(labels.tolist(), predicted.tolist()):
                support[label] += 1
                total_samples += 1
                # Ignore predictions outside the dataset's class range rather than crash.
                if 0 <= pred < num_classes:
                    predicted_count[pred] += 1
                if pred == label:
                    true_positive[label] += 1
                    total_correct += 1

    # Per-class precision, recall and F1-Score.
    per_class = {}
    for k in range(num_classes):
        precision = _safe_div(true_positive[k], predicted_count[k])
        recall = _safe_div(true_positive[k], support[k])
        f1 = _safe_div(2 * precision * recall, precision + recall)
        per_class[k] = {"precision": precision, "recall": recall, "f1": f1}

    # Overall (macro-averaged) precision, recall and F1-Score, plus plain accuracy.
    overall_precision = _safe_div(sum(m["precision"] for m in per_class.values()), num_classes)
    overall_recall = _safe_div(sum(m["recall"] for m in per_class.values()), num_classes)
    overall_f1 = _safe_div(sum(m["f1"] for m in per_class.values()), num_classes)
    overall_accuracy = _safe_div(total_correct, total_samples)

    print("Overall accuracy: {:.2f}%".format(overall_accuracy * 100))
    print("Overall precision: {:.2f}".format(overall_precision))
    print("Overall recall: {:.2f}".format(overall_recall))
    print("Overall F1-Score: {:.2f}".format(overall_f1))

    # Per-class report; all three values are shown as percentages.
    for k, name in enumerate(class_names):
        metrics = per_class[k]
        print('Class: %5s Precision: %.2f%% Recall: %.2f%% F1-Score: %.2f%%' %
              (name, metrics["precision"] * 100,
               metrics["recall"] * 100, metrics["f1"] * 100))

    return {
        "accuracy": overall_accuracy,
        "precision": overall_precision,
        "recall": overall_recall,
        "f1": overall_f1,
        "per_class": per_class,
    }





def main():
    """Load a trained AlexNet checkpoint and report per-class metrics on the validation split.

    Expects ``<image_path>/val`` to be laid out for ``torchvision.datasets.ImageFolder``
    (one sub-directory per class) and the weights file ``AlexNet.pth`` in the
    working directory. Both paths are placeholders to edit for your setup.

    Raises:
        FileNotFoundError: if ``image_path`` does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using {} device.".format(device))

    data_transform = {
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    image_path = "E:\\wafer_data\\wafer_27"  # change to your dataset path
    # Raise explicitly: an assert would be stripped when running under ``python -O``.
    if not os.path.exists(image_path):
        raise FileNotFoundError("{} path does not exist.".format(image_path))

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False)

    # Build the model. init_weights=False because trained weights are loaded right after.
    # NOTE(review): num_classes must match the checkpoint's output size (8 here).
    net = AlexNet(num_classes=8, init_weights=False)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    net.load_state_dict(torch.load('AlexNet.pth', map_location=device))  # change to your model path
    net.to(device)

    # Evaluate per-class metrics on the validation set.
    print("Validation accuracy for each class:")
    validate(net, validate_loader, device)


if __name__ == "__main__":
    # Run the validation only when this file is executed as a script, not on import.
    main()
本文内容由网友自发贡献，转载请注明出处：【wpsshop博客】
推荐阅读
  

闽ICP备14008679号