赞
踩
MNIST数据集是一个广泛使用的手写数字数据集,由美国国家标准与技术研究所(NIST)发起并整理。这个数据集包含了来自250个不同的人手写数字的图片,其中一半是高中生,另一半来自人口普查局的工作人员。主要目的是通过算法实现对手写数字的识别。
MNIST数据集一共包含了70000张图像,其中60000张用于训练,10000张用于测试。每张图像都是28×28像素的灰度图像,代表一个手写数字,范围从0到9。每张图像都附带一个标签,表示该图像上写的是哪个数字。
可以看到,效果相当好,比之前的线性模型好很多。
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score

# 1. Data loading and preprocessing: ToTensor scales pixels to [0, 1],
# Normalize((0.5,), (0.5,)) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

trainset = datasets.MNIST('./MNIST_data/', download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testset = datasets.MNIST('./MNIST_data/', download=True, train=False, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False)


# 2. Define the convolutional neural network.
class Net(nn.Module):
    """Small CNN for 28x28 grayscale MNIST digits.

    Two 3x3 conv layers, one 2x2 max-pool, two dropout layers and two
    fully connected layers. forward() returns log-probabilities over the
    10 digit classes, which pairs with nn.NLLLoss below.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)    # 28x28 -> 26x26, 32 channels
        self.conv2 = nn.Conv2d(32, 64, 3, 1)   # 26x26 -> 24x24, 64 channels
        self.dropout1 = nn.Dropout2d(0.25)     # channel-wise dropout on the 4-D feature map
        # FIX: this dropout is applied to the flat (N, 128) tensor after fc1,
        # so it must be element-wise nn.Dropout; nn.Dropout2d expects spatial
        # (4-D) input and is deprecated for 2-D tensors.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)        # 64 channels * 12 * 12 = 9216 after pooling
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a (N, 1, 28, 28) batch to (N, 10) log-probabilities."""
        x = self.conv1(x)
        x = nn.functional.relu(x)
        x = self.conv2(x)
        x = nn.functional.relu(x)
        x = nn.functional.max_pool2d(x, 2)     # 24x24 -> 12x12
        x = self.dropout1(x)
        x = torch.flatten(x, 1)                # (N, 9216)
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = nn.functional.log_softmax(x, dim=1)
        return output


# Initialize the network.
net = Net()

# Loss and optimizer: NLLLoss expects the log_softmax output produced above.
criterion = nn.NLLLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
# 3. Train the network.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = net.to(device)

for epoch in range(5):  # loop over the dataset multiple times
    # FIX: explicitly enable training mode so dropout is active here and can be
    # turned off with net.eval() before testing below.
    net.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        # FIX: one epoch has only ~938 mini-batches (60000 / 64), so the
        # original "every 2000 mini-batches" condition never fired and no loss
        # was ever reported; report every 200 mini-batches instead.
        running_loss += loss.item()
        if i % 200 == 199:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 200))
            running_loss = 0.0

print('Finished Training')

# 4. Evaluate on the test set.
# FIX: switch to eval mode so dropout is disabled during inference; the
# original evaluated with dropout still active, which makes predictions
# stochastic and measurably worse.
net.eval()

# FIX: the original ran inference over the whole test set twice (one pass for
# accuracy, a second identical pass for precision/recall); a single pass
# collects the predictions once and derives all metrics from them.
test_labels = []
test_preds = []
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = net(images)
        # predicted class = argmax over the 10 log-probabilities
        _, predicted = torch.max(outputs, 1)
        test_labels.extend(labels.cpu().numpy())
        test_preds.extend(predicted.cpu().numpy())

# Convert collected labels/predictions to numpy arrays for metric computation.
test_labels = np.array(test_labels)
test_preds = np.array(test_preds)

# Accuracy over the full test set.
correct = int((test_preds == test_labels).sum())
total = len(test_labels)
accuracy = 100 * correct / total
print('Accuracy of the network on the 10000 test images: %d %%' % (accuracy))

# Weighted precision and recall across the 10 classes.
precision = precision_score(test_labels, test_preds, average='weighted')
recall = recall_score(test_labels, test_preds, average='weighted')
print('Precision: %.3f' % precision)
print('Recall: %.3f' % recall)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。