赞
踩
import os
import sys
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm
from model import AlexNet  # assumes the model definition lives in model.py


def _safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _per_class_metrics(confusion):
    """Compute per-class precision, recall and F1 from a confusion matrix.

    Args:
        confusion: square tensor where ``confusion[t, p]`` is the number of
            samples whose true class is ``t`` and predicted class is ``p``.

    Returns:
        Tuple ``(precision, recall, f1)`` of lists with one float per class.
        A ratio whose denominator is zero is reported as 0.0 instead of raising.
    """
    true_pos = confusion.diag().tolist()
    predicted_totals = confusion.sum(dim=0).tolist()  # column sums: times each class was predicted
    actual_totals = confusion.sum(dim=1).tolist()     # row sums: true samples of each class
    precision = [_safe_div(tp, n) for tp, n in zip(true_pos, predicted_totals)]
    recall = [_safe_div(tp, n) for tp, n in zip(true_pos, actual_totals)]
    f1 = [_safe_div(2 * p * r, p + r) for p, r in zip(precision, recall)]
    return precision, recall, f1


def validate(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` and print overall and per-class metrics.

    Args:
        model: classifier producing one output column per class listed in
            ``dataloader.dataset.classes``.
        dataloader: yields ``(inputs, labels)`` batches; its dataset must
            expose a ``classes`` sequence (``datasets.ImageFolder`` does).
        device: device the inputs and labels are moved to before the forward pass.

    Returns:
        Dict with ``accuracy``, macro-averaged ``precision``/``recall``/``f1``
        and the per-class lists ``class_precision``/``class_recall``/``class_f1``.

    Raises:
        ValueError: if the model's number of outputs differs from the number
            of dataset classes (the confusion matrix could not hold them).
    """
    model.eval()
    classes = dataloader.dataset.classes
    num_classes = len(classes)
    # confusion[t, p] counts samples of true class t predicted as class p.
    confusion = torch.zeros(num_classes, num_classes, dtype=torch.long)

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            if outputs.size(1) != num_classes:
                raise ValueError(
                    "model produces {} outputs but the dataset has {} classes".format(
                        outputs.size(1), num_classes))
            _, predicted = torch.max(outputs, 1)
            # Encode each (true, predicted) pair as one flat index so a single
            # bincount updates the whole matrix. No .squeeze(), so a final
            # batch of size 1 is handled like any other batch.
            flat = labels.view(-1).cpu() * num_classes + predicted.view(-1).cpu()
            confusion += torch.bincount(
                flat, minlength=num_classes * num_classes).view(num_classes, num_classes)

    total_samples = int(confusion.sum())
    total_correct = int(confusion.diag().sum())
    class_precision, class_recall, class_f1 = _per_class_metrics(confusion)

    # Overall figures are macro averages: unweighted means over all classes.
    overall_precision = _safe_div(sum(class_precision), num_classes)
    overall_recall = _safe_div(sum(class_recall), num_classes)
    overall_f1 = _safe_div(sum(class_f1), num_classes)
    overall_accuracy = _safe_div(total_correct, total_samples)

    print("Overall accuracy: {:.2f}%".format(overall_accuracy * 100))
    print("Overall precision: {:.2f}".format(overall_precision))
    print("Overall recall: {:.2f}".format(overall_recall))
    print("Overall F1-Score: {:.2f}".format(overall_f1))

    # Per-class precision, recall and F1-Score, all shown as percentages.
    for name, precision, recall, f1 in zip(classes, class_precision, class_recall, class_f1):
        print('Class: %5s  Precision: %.2f%%  Recall: %.2f%%  F1-Score: %.2f%%'
              % (name, precision * 100, recall * 100, f1 * 100))

    return {
        "accuracy": overall_accuracy,
        "precision": overall_precision,
        "recall": overall_recall,
        "f1": overall_f1,
        "class_precision": class_precision,
        "class_recall": class_recall,
        "class_f1": class_f1,
    }


def main():
    """Load the trained AlexNet checkpoint and report metrics on the ``val`` split."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using {} device.".format(device))

    data_transform = {
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    image_path = "E:\\wafer_data\\wafer_27"  # change to your dataset path
    # Explicit check instead of assert, which is stripped under `python -O`.
    if not os.path.exists(image_path):
        raise FileNotFoundError("{} path does not exist.".format(image_path))

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False)

    # Load the model. init_weights must be False here because the trained
    # weights are loaded from the checkpoint right after.
    net = AlexNet(num_classes=8, init_weights=False)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    net.load_state_dict(torch.load('AlexNet.pth', map_location=device))  # change to your model path
    net.to(device)

    # Evaluate per-class classification performance on the validation set.
    print("Validation accuracy for each class:")
    validate(net, validate_loader, device)


if __name__ == '__main__':
    main()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。