赞
踩
"""Fine-tune an ImageNet-pretrained EfficientNet-B0 on MNIST.

Each epoch prints in-place batch progress and an epoch summary. After
training, the weights are saved to ``mnist_efficientnetb0.pth`` and the
training-loss / test-accuracy curves are plotted.
"""
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import models
from torchvision.datasets import MNIST

# Hyperparameters
batch_size = 240
learning_rate = 0.001
num_epochs = 10
# The original hard-coded 32 DataLoader workers. Cap that at the machine's CPU
# count so smaller machines are not oversubscribed.
num_workers = min(32, os.cpu_count() or 1)

# Preprocessing: resize the images to suit EfficientNet-B0, replicate the
# single grey channel into three channels, and normalize with the ImageNet
# mean/std.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def _build_loaders():
    """Load (downloading if needed) MNIST and return (train_loader, test_loader)."""
    train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
    test_dataset = MNIST(root='./data', train=False, transform=transform, download=True)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                             shuffle=False, num_workers=num_workers)
    return train_loader, test_loader


def _build_model(device):
    """Return a pretrained EfficientNet-B0 with a 10-class head, moved to *device*."""
    # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
    # ``weights=`` enum. Fall back to the legacy flag on older versions.
    weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
    if weights_enum is not None:
        model = models.efficientnet_b0(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.efficientnet_b0(pretrained=True)
    # Replace the final Linear layer: MNIST has 10 classes.
    model.classifier[1] = nn.Linear(model.classifier[1].in_features, 10)
    return model.to(device)


def _train_one_epoch(model, loader, criterion, optimizer, device, epoch):
    """Run one training epoch and return the sample-weighted mean loss."""
    model.train()
    running_loss = 0.0
    num_batches = len(loader)
    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()  # single host sync per batch
        # Weight by batch size so a smaller final batch is not over-counted.
        running_loss += batch_loss * target.size(0)
        # end="" keeps the "\r" progress on one line instead of one line per batch.
        print(f"\rEpoch {epoch + 1}/{num_epochs}, Batch {batch_idx + 1}/{num_batches}, "
              f"Loss: {batch_loss:.4f}", end="", flush=True)
    print()  # finish the in-place progress line
    return running_loss / len(loader.dataset)


def _evaluate(model, loader, device):
    """Return the top-1 accuracy of *model* on *loader*, in percent."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    return 100 * correct / total if total else 0.0


def _plot_history(train_losses, test_accuracies):
    """Plot per-epoch training loss and test accuracy side by side."""
    epochs = range(1, len(train_losses) + 1)  # epochs are reported 1-based
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label='Training Loss')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(epochs, test_accuracies, label='Test Accuracy')
    plt.title('Test Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    plt.show()


def main():
    """Train, evaluate each epoch, save the weights, and plot the curves."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_loader, test_loader = _build_loaders()
    model = _build_model(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    train_losses = []
    test_accuracies = []
    for epoch in range(num_epochs):
        avg_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device, epoch)
        train_losses.append(avg_loss)

        accuracy = _evaluate(model, test_loader, device)
        test_accuracies.append(accuracy)
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%')

    torch.save(model.state_dict(), 'mnist_efficientnetb0.pth')
    _plot_history(train_losses, test_accuracies)


# Required when DataLoader uses worker processes: on spawn-based platforms
# (Windows/macOS) workers re-import this module, and without the guard they
# would re-run the whole training script.
if __name__ == "__main__":
    main()
训练 10 轮后,测试准确率非常高:
Epoch 10/10, Loss: 0.0087, Test Accuracy: 99.60%
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。