
Building a Convolutional Neural Network (CNN) in PyTorch for Handwritten Digit Recognition


1. Introduction to Convolutional Neural Networks

A convolutional neural network (CNN) is a type of feedforward neural network that performs convolution operations and has a deep structure. It is one of the representative algorithms of deep learning. CNNs are capable of representation learning, and their hierarchical structure lets them classify inputs in a shift-invariant way. For this reason they are also called "Shift-Invariant Artificial Neural Networks (SIANN)".
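The shift-invariance mentioned above starts at the convolution layer, which is translation-equivariant: shifting the input shifts the feature map by the same amount. The pooling layers and the classifier head then turn this into approximate invariance. The sketch below checks the equivariance on a toy 8×8 image with a randomly initialized nn.Conv2d. The image size, channel count and seed are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
conv = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, bias=False)

# A toy 8x8 "image" with a single bright pixel at (2, 2)
x = torch.zeros(1, 1, 8, 8)
x[0, 0, 2, 2] = 1.0

# The same image shifted one pixel down and one pixel right
x_shifted = torch.zeros(1, 1, 8, 8)
x_shifted[0, 0, 3, 3] = 1.0

with torch.no_grad():
    y = conv(x)                  # 8 - 3 + 1 = 6, so shape (1, 4, 6, 6)
    y_shifted = conv(x_shifted)

# The feature map moves by the same (1, 1) offset as the input
equivariant = torch.allclose(y_shifted[..., 1:, 1:], y[..., :-1, :-1])
print(y.shape, equivariant)  # torch.Size([1, 4, 6, 6]) True
```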

2. CNN Architecture

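A CNN is built mainly from three kinds of layers: convolutional layers, subsampling (pooling) layers, which usually use max pooling, and fully connected (FC) layers.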
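The sketch below shows one way these three kinds of layers are usually chained. The channel count of 8 and the 28×28 input are illustrative assumptions, not the network used later in this article. The feature maps must be flattened into a vector before they reach the fully connected layer.

```python
import torch
import torch.nn as nn

# Convolution -> max pooling -> flatten -> fully connected layer
net = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3),   # (N, 1, 28, 28) -> (N, 8, 26, 26)
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),      # (N, 8, 26, 26) -> (N, 8, 13, 13)
    nn.Flatten(),                     # (N, 8, 13, 13) -> (N, 8 * 13 * 13)
    nn.Linear(8 * 13 * 13, 10),       # one score per digit class
)

x = torch.randn(4, 1, 28, 28)  # a fake batch of four grayscale images
scores = net(x)
print(scores.shape)  # torch.Size([4, 10])
```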

3. Implementing a CNN in PyTorch

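  • Convolutional layer: nn.Conv2d()

Its main parameters are in_channels (number of input channels), out_channels (number of output channels, i.e. the number of kernels), kernel_size, stride (default 1), padding (default 0), dilation (default 1), groups (default 1) and bias (default True).

  • Pooling layer: nn.MaxPool2d()

Its main parameters are kernel_size, stride (defaults to kernel_size), padding (default 0), dilation (default 1), return_indices (default False) and ceil_mode (default False).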
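To show what these parameters do, the sketch below prints the output shapes. The input size and channel counts match the MNIST example later in the article. The padding=1 and ceil_mode=True variants are extra illustrations and are not used in the article's model.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 1, 28, 28)  # (batch, channels, height, width)

# 3x3 convolution from 1 input channel to 25 output channels, stride 1, no padding
conv = nn.Conv2d(in_channels=1, out_channels=25, kernel_size=3, stride=1, padding=0)
# With padding=1, a 3x3 kernel keeps the 28x28 spatial size
conv_same = nn.Conv2d(in_channels=1, out_channels=25, kernel_size=3, padding=1)
# 2x2 max pooling; stride defaults to kernel_size, so height and width are halved
pool = nn.MaxPool2d(kernel_size=2)
# ceil_mode=True rounds the output size up instead of down
pool_ceil = nn.MaxPool2d(kernel_size=2, ceil_mode=True)

with torch.no_grad():
    y = conv(x)
    y_same = conv_same(x)
    y_pool = pool(y)
    odd = torch.randn(1, 1, 11, 11)  # an odd-sized map, like the 11x11 one later in the network
    odd_floor, odd_ceil = pool(odd), pool_ceil(odd)

print(y.shape)             # torch.Size([1, 25, 26, 26])
print(y_same.shape)        # torch.Size([1, 25, 28, 28])
print(y_pool.shape)        # torch.Size([1, 25, 13, 13])
print(conv.weight.shape)   # torch.Size([25, 1, 3, 3]): one 1x3x3 kernel per output channel
print(odd_floor.shape[-1], odd_ceil.shape[-1])  # 5 6
```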

4. MNIST Handwritten Digit Recognition

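The network has five layers in total: two convolutional layers, two pooling layers, and a final fully connected block that produces the classification output. Its structure is: convolution (1→25 channels, 3×3 kernel) + BatchNorm + ReLU → 2×2 max pooling → convolution (25→50 channels, 3×3 kernel) + BatchNorm + ReLU → 2×2 max pooling → fully connected layers 1250 → 1024 → 128 → 10.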
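One quick way to confirm this structure is to push a dummy batch through the same layers and print the shape after each one. The sketch below rebuilds layer1 to layer4 with the settings from the CNN class in section 5. The fake input and the printed labels are assumptions for illustration.

```python
import torch
import torch.nn as nn

# Same layer settings as the CNN class in section 5, rebuilt so this sketch runs on its own
layers = [
    ("layer1: conv 3x3 + BN + ReLU",
     nn.Sequential(nn.Conv2d(1, 25, kernel_size=3), nn.BatchNorm2d(25), nn.ReLU())),
    ("layer2: max pool 2x2", nn.MaxPool2d(kernel_size=2, stride=2)),
    ("layer3: conv 3x3 + BN + ReLU",
     nn.Sequential(nn.Conv2d(25, 50, kernel_size=3), nn.BatchNorm2d(50), nn.ReLU())),
    ("layer4: max pool 2x2", nn.MaxPool2d(kernel_size=2, stride=2)),
]

x = torch.randn(2, 1, 28, 28)  # a fake batch of two MNIST-sized images
shapes = []
with torch.no_grad():
    for name, layer in layers:
        layer.eval()  # use BatchNorm running statistics; fine for a shape check
        x = layer(x)
        shapes.append(tuple(x.shape))
        print(f"{name:30s} -> {tuple(x.shape)}")

print("flattened features:", x.flatten(1).shape[1])  # 50 * 5 * 5 = 1250
```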

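For a 28×28 input, the feature-map size changes as follows: 26×26 after the first convolution, 13×13 after the first pooling, 11×11 after the second convolution, and 5×5 after the second pooling. As a result, 50 × 5 × 5 = 1250 features enter the fully connected block.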
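These sizes follow from the output-size formula in the PyTorch documentation for nn.Conv2d and nn.MaxPool2d. The formula is applied per spatial dimension, with padding p, dilation d, kernel size k and stride s. In this network p = 0 and d = 1 everywhere, so it reduces to the simpler second form:

```latex
H_{\text{out}} = \left\lfloor \frac{H_{\text{in}} + 2p - d\,(k - 1) - 1}{s} + 1 \right\rfloor
\;\overset{p=0,\,d=1}{=}\; \left\lfloor \frac{H_{\text{in}} - k}{s} + 1 \right\rfloor

\begin{aligned}
\text{conv1 } (k=3,\ s=1):&\quad \left\lfloor \tfrac{28 - 3}{1} + 1 \right\rfloor = 26 \\
\text{pool1 } (k=2,\ s=2):&\quad \left\lfloor \tfrac{26 - 2}{2} + 1 \right\rfloor = 13 \\
\text{conv2 } (k=3,\ s=1):&\quad \left\lfloor \tfrac{13 - 3}{1} + 1 \right\rfloor = 11 \\
\text{pool2 } (k=2,\ s=2):&\quad \left\lfloor \tfrac{11 - 2}{2} + 1 \right\rfloor = \lfloor 5.5 \rfloor = 5 \\
\text{FC input}:&\quad 50 \times 5 \times 5 = 1250
\end{aligned}
```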

5. Code Implementation

```python
import torch
from torchvision import transforms  # common image transformation utilities
from torchvision import datasets
from torch.utils.data import DataLoader

batch_size = 64
transform = transforms.Compose([
    transforms.ToTensor(),  # convert the PIL image to a tensor with values in [0, 1]
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean 0.1307, standard deviation 0.3081
])

train_dataset = datasets.MNIST(root='../dataset/mnist', train=True,
                               download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist', train=False,
                              download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)  # no need to shuffle for evaluation


class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 1x28x28 -> 25x26x26
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 25, kernel_size=3),
            torch.nn.BatchNorm2d(25),
            torch.nn.ReLU(inplace=True)
        )
        # 25x26x26 -> 25x13x13
        self.layer2 = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # 25x13x13 -> 50x11x11
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(25, 50, kernel_size=3),
            torch.nn.BatchNorm2d(50),
            torch.nn.ReLU(inplace=True)
        )
        # 50x11x11 -> 50x5x5
        self.layer4 = torch.nn.Sequential(
            torch.nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # 1250 -> 1024 -> 128 -> 10 class scores (raw logits; CrossEntropyLoss applies softmax internally)
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(50 * 5 * 5, 1024),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(1024, 128),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 1250) before the fully connected layers
        x = self.fc(x)
        return x


model = CNN()
# Use the first GPU ("cuda:0") if one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)  # move the model parameters to the chosen device
criterion = torch.nn.CrossEntropyLoss()  # cross-entropy loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.5)  # momentum accumulates past gradients to help push through local minima


def train(epoch):
    model.train()  # BatchNorm uses per-batch statistics during training
    running_loss = 0.0
    for batch_idx, (inputs, target) in enumerate(train_loader, 0):
        inputs, target = inputs.to(device), target.to(device)  # move the batch to the same device as the model
        optimizer.zero_grad()
        # forward pass + backward pass + parameter update
        outputs = model(inputs)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 300 == 299:  # print the average loss every 300 mini-batches instead of every batch
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0
    torch.save(model, 'model_{}.pth'.format(epoch))  # save a checkpoint after each epoch


def test():
    model.eval()  # BatchNorm uses its running statistics during evaluation
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients are computed inside this block
        for inputs, target in test_loader:
            inputs, target = inputs.to(device), target.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, dim=1)  # _ is each row's maximum, predicted is its index (the predicted class)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print('Accuracy on test set: %d %%' % (100 * correct / total))


if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test()
```
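The training loop above saves the whole model object with torch.save(model, ...). This pickles the object and ties the file to this exact class definition. Recent PyTorch releases (2.6 and later) also default torch.load to weights_only=True, which refuses to unpickle a full model object. A common alternative is to save only the state_dict and load it into a freshly constructed model. The sketch below shows this round trip with a small stand-in model. TinyNet and the file name tiny_cnn.pth are hypothetical; the same calls apply to the CNN class.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the CNN class; any nn.Module works the same way."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, kernel_size=3)
        self.bn = nn.BatchNorm2d(4)
        self.fc = nn.Linear(4 * 26 * 26, 10)

    def forward(self, x):
        x = torch.relu(self.bn(self.conv(x)))
        return self.fc(x.flatten(1))


model = TinyNet()
path = os.path.join(tempfile.gettempdir(), "tiny_cnn.pth")  # hypothetical file name

# Save only the learned tensors (weights, biases, BatchNorm running statistics)
torch.save(model.state_dict(), path)

# Later: rebuild the architecture in code, then load the tensors into it.
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
restored = TinyNet()
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # switch BatchNorm to inference mode before predicting
model.eval()

x = torch.randn(1, 1, 28, 28)  # a fake single image
with torch.no_grad():
    same = torch.allclose(model(x), restored(x))
    predicted_digit = restored(x).argmax(dim=1).item()
print(same)  # True
```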

6. Results

Disclaimer: This article was contributed by a user. Please credit the source when reposting: [wpsshop]