
Implementing AlexNet in PyTorch: training on MNIST, evaluating with precision, recall, and related metrics, and plotting PR and ROC curves

I. Import the required modules

    import torch
    import prettytable
    import numpy as np
    import matplotlib.pyplot as plt
    from torch import nn
    from torch.utils import data
    from torchvision import datasets, transforms
    from torch.nn import functional as F
    from torchsummary import summary
    from torch.utils.tensorboard import SummaryWriter

II. Prepare the dataset

    # Load the dataset; resize the 28x28 MNIST images to 224x224 for AlexNet
    trans = transforms.Compose([transforms.Resize(224), transforms.ToTensor()])
    batch_size = 256
    training_data = datasets.MNIST(
        root="./data",
        train=True,
        download=True,
        transform=trans
    )
    test_data = datasets.MNIST(
        root="./data",
        train=False,
        download=True,
        transform=trans
    )
    train_iter = data.DataLoader(training_data, batch_size, shuffle=True,
                                 num_workers=2)
    test_iter = data.DataLoader(test_data, batch_size, shuffle=False,
                                num_workers=2)
    train_features, train_labels = next(iter(train_iter))
    print(f"Feature batch shape: {train_features.size()}")
    print(f"Labels batch shape: {train_labels.size()}")
    print(f"Number of batches: {len(train_iter)}")

Output:

    Feature batch shape: torch.Size([256, 1, 224, 224])
    Labels batch shape: torch.Size([256])
    Number of batches: 235
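Note that 235 is the number of mini-batches per epoch: 60,000 training samples divided by a batch size of 256 gives 234.375, and the DataLoader rounds up to 235 because the final, smaller batch is kept by default.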

III. Visualize the dataset

    # Randomly display nine images from the training set
    figure = plt.figure(figsize=(8, 8))
    sample_idx = torch.randint(len(training_data), size=(9,))
    for i, pict_index in enumerate(sample_idx):
        img, label = training_data[pict_index.item()]
        figure.add_subplot(3, 3, i + 1)
        plt.title(str(label))
        plt.imshow(img.squeeze(), cmap="gray")
    plt.show()

Output: a 3x3 grid of randomly sampled training digits with their labels (image not reproduced here).

IV. Define AlexNet

The network below follows the architecture from Alex Krizhevsky's paper "ImageNet Classification with Deep Convolutional Neural Networks", adapted in two places for MNIST: the first convolution takes 1 input channel instead of 3, and the last linear layer outputs 10 classes instead of 1000.

Of course, if you would rather save the effort, you can also import AlexNet directly from torchvision, as sketched below.
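A minimal sketch of the torchvision route (the adaptation choices here are mine, not from the original post): torchvision's AlexNet expects 3-channel ImageNet input and emits 1000 classes, so its first convolution and final classifier layer are swapped out to match MNIST:

    from torchvision import models

    tv_model = models.alexnet()  # random weights; pass weights=... to load pretrained ones
    # Adapt to MNIST: 1 input channel in, 10 classes out
    tv_model.features[0] = nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2)
    tv_model.classifier[6] = nn.Linear(4096, 10)

The hand-written version used in the rest of this post: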

    class AlexNet(nn.Module):
        def __init__(self):
            super(AlexNet, self).__init__()
            self.conv1 = nn.Conv2d(1, 96, kernel_size=(11, 11), stride=4, padding=2)
            self.maxpool1 = nn.MaxPool2d(kernel_size=(3, 3), stride=2)
            self.conv2 = nn.Conv2d(96, 256, kernel_size=(5, 5), stride=1, padding=2)
            self.maxpool2 = nn.MaxPool2d(kernel_size=(3, 3), stride=2)
            self.conv3 = nn.Conv2d(256, 384, kernel_size=(3, 3), padding=1)
            self.conv4 = nn.Conv2d(384, 384, kernel_size=(3, 3), padding=1)
            self.conv5 = nn.Conv2d(384, 256, kernel_size=(3, 3), padding=1)
            self.maxpool3 = nn.MaxPool2d(kernel_size=(3, 3), stride=2)
            self.flatten = nn.Flatten()
            self.linear1 = nn.Linear(9216, 4096)  # 256 channels * 6 * 6 = 9216
            self.dropout1 = nn.Dropout(0.5)
            self.linear2 = nn.Linear(4096, 4096)
            self.dropout2 = nn.Dropout(0.5)
            self.linear3 = nn.Linear(4096, 10)

        def forward(self, x):
            out_conv1 = F.relu(self.conv1(x))
            out_pool1 = self.maxpool1(out_conv1)
            out_conv2 = F.relu(self.conv2(out_pool1))
            out_pool2 = self.maxpool2(out_conv2)
            out_conv3 = F.relu(self.conv3(out_pool2))
            out_conv4 = F.relu(self.conv4(out_conv3))
            out_conv5 = F.relu(self.conv5(out_conv4))
            out_pool3 = self.maxpool3(out_conv5)
            flatten_x = self.flatten(out_pool3)
            out_linear1 = F.relu(self.linear1(flatten_x))
            out_dropout1 = self.dropout1(out_linear1)
            out_linear2 = F.relu(self.linear2(out_dropout1))
            out_dropout2 = self.dropout2(out_linear2)  # dropout after the second FC layer
            out_linear3 = self.linear3(out_dropout2)   # raw logits; CrossEntropyLoss applies softmax itself
            return out_linear3
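With the class defined, torchsummary (imported in section I) gives a quick sanity check that a 1x224x224 input reaches the first linear layer as 256 x 6 x 6 = 9216 features. A small sketch; it also creates the model instance that is trained below (device='cpu' so the check runs before the model is moved to the GPU):

    model = AlexNet()
    summary(model, (1, 224, 224), device='cpu')  # prints every layer's output shape and parameter count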

V. Define the training loop, test loop, and hyperparameters

    # Hyperparameters; SGD is used as the optimizer
    # (`model` was instantiated at the end of section IV)
    learning_rate = 0.001
    batch_size = 256
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    loss_fn = nn.CrossEntropyLoss()
    model.to(device)
    loss_list = []
    acc_list = []
    epoch_num = []

    def init_weights(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)

    # Training loop and test loop
    def train_loop(dataloader, model, loss_fn, optimizer, num_epochs):
        size = len(dataloader.dataset)
        for t in range(num_epochs):
            print(f"Epoch {t+1}\n-------------------------------")
            running_loss = 0
            for batch, (X, y) in enumerate(dataloader):
                X, y = X.to(device), y.to(device)
                pred = model(X)
                loss = loss_fn(pred, y)
                running_loss += loss.item()  # .item() so the computation graph is not kept alive
                # Backpropagation
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if batch % 50 == 49:  # log every 50 batches
                    # `writer` is the SummaryWriter created in section VI;
                    # the global step counts batches across epochs, using the current epoch t
                    writer.add_scalar('training loss',
                                      running_loss / 50,
                                      t * len(dataloader) + batch + 1)
                    loss, current = loss.item(), (batch + 1) * len(X)
                    loss_list.append(loss), epoch_num.append(t + current / size)
                    print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
                    running_loss = 0
            test_loop(test_iter, model, loss_fn)

    def test_loop(dataloader, model, loss_fn):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        test_loss /= num_batches
        correct /= size
        acc_list.append(correct)
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}")

VI. With the model built, start training!

    model.apply(init_weights)
    writer = SummaryWriter()
    train_loop(train_iter, model, loss_fn, optimizer, 30)
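While this runs, the loss logged through the writer can be watched live by launching tensorboard --logdir=runs in the project directory (a plain SummaryWriter() writes to ./runs by default).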

Output:

    Epoch 1
    -------------------------------
    loss: 2.303341 [12800/60000]
    loss: 2.303362 [25600/60000]
    loss: 2.300716 [38400/60000]
    loss: 2.300808 [51200/60000]
    Test Error:
     Accuracy: 11.5%, Avg loss: 2.300705
    .........
    .........
    .........
    Epoch 30
    -------------------------------
    loss: 0.075750 [12800/60000]
    loss: 0.073634 [25600/60000]
    loss: 0.110787 [38400/60000]
    loss: 0.061658 [51200/60000]
    Test Error:
     Accuracy: 97.4%, Avg loss: 0.081114

VII. Model evaluation and visualization

1. Loss and accuracy curves

    # Save the model
    torch.save(model.state_dict(), 'MnistOnAlexNet_epoch30.pkl')
    # Plot the loss and accuracy curves
    plt.title('Loss and Accuracy')
    plt.xlabel('epoch')
    plt.plot(epoch_num, loss_list, 'yellow')
    plt.plot(range(1, 31), acc_list, 'cyan')  # one accuracy value per completed epoch
    plt.legend(['Loss', 'Accuracy'])
    plt.show()

Result: the combined loss and accuracy plot over the 30 epochs (image not reproduced here).

2. Per-class precision and recall

    # Evaluate the model on the test set
    model.eval()
    model.to('cpu')
    pred_list = torch.tensor([])
    with torch.no_grad():
        for X, y in test_iter:
            pred = model(X)
            pred_list = torch.cat([pred_list, pred])
    # Load the whole test set in a single batch so `labels` lines up with `pred_list`
    test_iter1 = data.DataLoader(test_data, batch_size=10000, shuffle=False,
                                 num_workers=2)
    features, labels = next(iter(test_iter1))
    print(labels.shape)
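This prints torch.Size([10000]): the MNIST test split has 10,000 samples, and since this loader uses batch_size=10000 with shuffle=False, the single batch covers the whole test set in order, so labels[i] corresponds to pred_list[i].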
    # Per-class precision and recall via a confusion matrix:
    # rows = true class, columns = predicted class
    confusion = np.zeros((10, 10), dtype=int)
    for i in range(len(test_data)):
        confusion[labels[i]][np.argmax(pred_list[i])] += 1
    result_table = prettytable.PrettyTable()
    result_table.field_names = ['Type', 'Precision', 'Recall', 'F1_Score']
    class_names = ['Zero', 'One', 'Two', 'Three', 'Four', 'Five', 'Six', 'Seven', 'Eight', 'Nine']
    for i in range(10):
        precision = confusion[i][i] / confusion.sum(axis=0)[i]  # column sum: all samples predicted as class i
        recall = confusion[i][i] / confusion.sum(axis=1)[i]     # row sum: all samples truly of class i
        result_table.add_row([class_names[i], np.round(precision, 3), np.round(recall, 3),
                              np.round(precision * recall * 2 / (precision + recall), 3)])  # F1 = 2PR/(P+R)
    print(result_table)

Result:

    +-------+-----------+--------+----------+
    |  Type | Precision | Recall | F1_Score |
    +-------+-----------+--------+----------+
    |  Zero |   0.972   | 0.993  |  0.982   |
    |  One  |   0.991   | 0.985  |  0.988   |
    |  Two  |   0.983   | 0.976  |   0.98   |
    | Three |   0.966   | 0.984  |  0.975   |
    |  Four |   0.994   | 0.966  |   0.98   |
    |  Five |   0.994   |  0.97  |  0.982   |
    |  Six  |   0.988   | 0.981  |  0.985   |
    | Seven |   0.982   | 0.965  |  0.974   |
    | Eight |   0.954   | 0.983  |  0.968   |
    |  Nine |   0.953   | 0.972  |  0.963   |
    +-------+-----------+--------+----------+
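As a cross-check (a sketch that is not part of the original post), scikit-learn's classification_report produces the same per-class precision, recall, and F1 straight from the labels and the argmax predictions:

    from sklearn.metrics import classification_report

    print(classification_report(labels.numpy(),
                                pred_list.argmax(1).numpy(),
                                target_names=class_names,
                                digits=3))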

3. PR and ROC curves for each of the ten classes

    # Use scikit-learn to plot PR and ROC curves for each of the 10 classes
    from sklearn.metrics import precision_recall_curve, roc_curve

    # Convert the raw logits in pred_list to class probabilities
    pred_probabilities = F.softmax(pred_list, dim=1)
    for i in range(10):
        # One-vs-rest: class i is the positive class, all other classes negative
        temp_true = []
        temp_probabilities = []
        for j in range(len(labels)):
            temp_true.append(1 if labels[j] == i else 0)
            temp_probabilities.append(pred_probabilities[j][i].item())
        precision, recall, thresholds = precision_recall_curve(temp_true, temp_probabilities)
        fpr, tpr, thresholds = roc_curve(temp_true, temp_probabilities)
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
        plt.xlabel('Precision')
        plt.ylabel('Recall')
        plt.title(f'Precision & Recall Curve (class:{i})')
        plt.plot(precision, recall, 'yellow')
        plt.subplot(1, 2, 2)
        plt.xlabel('Fpr')
        plt.ylabel('Tpr')
        plt.title(f'Roc Curve (class:{i})')
        plt.plot(fpr, tpr, 'cyan')
        plt.show()

Result:

PR and ROC curves for class 1 (digit 1) (image not reproduced here).

As you can see, the curves look practically perfect!

The other nine classes (0 and 2-9) behave the same way; each class gets its own pair of PR and ROC plots, which are omitted here for space.
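If you want a single number per class instead of eyeballing twenty plots, the area under each ROC curve can be computed as well; a short sketch (not part of the original post), reusing labels and pred_probabilities from above:

    from sklearn.metrics import roc_auc_score

    probs = pred_probabilities.numpy()
    y_true = labels.numpy()
    for i in range(10):
        # Binary one-vs-rest AUC for class i
        auc_i = roc_auc_score((y_true == i).astype(int), probs[:, i])
        print(f'class {i}: AUC = {auc_i:.4f}')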

The complete code is available on GitHub; the dataset and pretrained weights can be found under the repository's releases:

https://github.com/tortorish/Pytorch_AlexNet_Mnist
