当前位置:   article > 正文

Pytorch学习基础——LeNet从训练到测试_for step, (b_x, b_y) in enumerate(train_loader):

for step, (b_x, b_y) in enumerate(train_loader):

在上一篇Pytorch学习基础——CNN基本结构搭建中介绍了如何使用Pytorch.nn类搭建网络模型,结合MNIST数据集进行训练测试。

实现步骤:

  • 导入必要的包并设置超参数:
  1. import torch
  2. import torchvision
  3. import torch.nn as nn
  4. from torch.autograd import Variable
  5. import torchvision.datasets as dsets
  6. import torchvision.transforms as transforms
  7. import matplotlib.pyplot as plt
  8. #define hyperparameter
  9. EPOCH = 1
  10. BATCH_SIZE = 64
  11. TIME_STEP = 28 #time_step / image_height
  12. INPUT_SIZE = 28 #input_step / image_width
  13. LR = 0.01
  14. DOWNLOAD = False
  • 获取并加载数据集:下载MNIST到当前目录下,转换数据为Tensor张量;使用DataLoader转换为torch批次加载的形式;
  1. train_data = dsets.MNIST(root='./', train=True, transform=torchvision.transforms.ToTensor(), download=True)
  2. test_data = dsets.MNIST(root='./', train=False, transform=torchvision.transforms.ToTensor())
  3. test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1),volatile = True).type(torch.FloatTensor)[:2000]/255
  4. test_y = test_data.test_labels[:2000]
  5. #use dataloader to batch input dateset
  6. train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
  • 定义并实例化网络模型:
  1. #define the RNN class
  2. class LeNet(nn.Module):
  3. #overload __init__() method
  4. def __init__(self):
  5. super(LeNet, self).__init__()
  6. self.layer1 = nn.Sequential(
  7. nn.Conv2d(1, 25, kernel_size=3),
  8. nn.BatchNorm2d(25),
  9. nn.ReLU(True),
  10. nn.MaxPool2d(kernel_size=2, stride=2),
  11. )
  12. self.layer2 = nn.Sequential(
  13. nn.Conv2d(25, 50, kernel_size=3),
  14. nn.BatchNorm2d(50),
  15. nn.ReLU(True),
  16. nn.MaxPool2d(kernel_size=2, stride=2),
  17. )
  18. self.classifier = nn.Sequential(
  19. nn.Linear(50*5*5, 1024),
  20. nn.ReLU(True),
  21. nn.Linear(1024, 128),
  22. nn.ReLU(True),
  23. nn.Linear(128, 10),
  24. )
  25. #overload forward() method
  26. def forward(self, x):
  27. out = self.layer1(x)
  28. out = self.layer2(out)
  29. out = out.view(out.size(0), -1)
  30. out = self.classifier(out)
  31. return out
  32. cnn = LeNet()
  33. print(cnn)
  • 定义优化器和损失函数:
  1. #define optimizer with Adam optim
  2. optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
  3. #define cross entropy loss function
  4. loss_func = nn.CrossEntropyLoss()
  • 训练模型:
  1. epoch = 0
  2. #training and testing
  3. for epoch in range(EPOCH):
  4. for step, (b_x, b_y) in enumerate(train_loader):
  5. b_x = Variable(b_x)
  6. b_y = Variable(b_y)
  7. output = cnn(b_x)
  8. loss = loss_func(output, b_y)
  9. optimizer.zero_grad()
  10. loss.backward()
  11. optimizer.step()
  12. if step % 50 == 0:
  13. test_output = cnn(test_x)
  14. pred_y = torch.max(test_output, 1)[1].data.squeeze()
  15. acc = float((pred_y == test_y).sum()) / float(test_y.size(0))
  16. print('Epoch: ', epoch, '| train loss: %.3f' % loss.data.numpy(), '| test accuracy: %.3f' % acc)
  17. print('Training ending')
  • 验证模型:
  1. # print 100 predictions from test data
  2. numTest = 100
  3. test_output = cnn(test_x[:numTest])
  4. pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
  5. print(pred_y, 'prediction number')
  6. print(test_y[:numTest], 'real number')
  7. ErrorCount = 0.0
  8. for i in pred_y:
  9. if pred_y[i] != test_y[i]:
  10. ErrorCount += 1
  11. print('ErrorRate : %.3f'%(ErrorCount / numTest))

实验结果:

可以看到,对于简单的MNIST手写数字数据集,LeNet在较低训练时间内即能完成准确识别,从而证实了神经网络的高效识别能力,为大型数据集的识别分类提供了参考和借鉴。 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/101859
推荐阅读
相关标签
  

闽ICP备14008679号