
Handwritten Digit Recognition with PyTorch: Testing Your Own Handwritten Digits

Most examples online only evaluate on the MNIST test set; this post goes one step further and recognizes a digit you have drawn yourself.

 

First, train a model on the MNIST dataset, using LeNet as the network backbone.

1. Import the required modules and select the device (GPU if available)

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import cv2

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
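
As a quick sanity check (my addition, not part of the original post), you can print which device was actually selected:

# Minimal sketch: report the selected device
print(device)                                  # e.g. "cuda" or "cpu"
if device.type == "cuda":
    print(torch.cuda.get_device_name(0))       # name of the first CUDA GPU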

2. Define the network structure

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(                 # input size: (1, 28, 28)
            nn.Conv2d(1, 6, 5, 1, 2),               # padding=2 keeps the spatial size unchanged
            nn.ReLU(),                              # output size: (6, 28, 28)
            nn.MaxPool2d(kernel_size=2, stride=2),  # output size: (6, 14, 14)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),                              # output size: (16, 10, 10)
            nn.MaxPool2d(2, 2)                      # output size: (16, 5, 5)
        )
        self.fc1 = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.ReLU()
        )
        self.fc3 = nn.Linear(84, 10)

    # Forward pass; x is the input batch
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # nn.Linear expects 2-D input (batch, features), so flatten the feature maps
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
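
As a quick check (my addition, not from the original), you can push a dummy batch through the network and confirm that it produces one score per digit class:

# Minimal sketch: verify the forward pass on a random 28x28 "image"
dummy = torch.randn(1, 1, 28, 28)      # (batch, channels, height, width)
print(LeNet()(dummy).shape)            # expected: torch.Size([1, 10])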

3. Set the hyperparameters and define the training and test data loaders

# Hyperparameters
EPOCH = 10          # number of passes over the training set
BATCH_SIZE = 256    # mini-batch size
LR = 0.001          # learning rate

# Data preprocessing: convert PIL images to tensors scaled to [0, 1]
transform = transforms.ToTensor()

# Training set (set download=True on the first run if ./data/ is empty)
trainset = tv.datasets.MNIST(
    root='./data/',
    train=True,
    download=False,
    transform=transform)

# Training data loader
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Test set
testset = tv.datasets.MNIST(
    root='./data/',
    train=False,
    download=False,
    transform=transform)

# Test data loader (used by train() below to report accuracy after each epoch)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)
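
A small optional check (my addition, not from the original) to confirm what one batch from the loader looks like:

# Minimal sketch: inspect a single training batch
images, labels = next(iter(trainloader))
print(images.shape, labels.shape)                 # expected: torch.Size([256, 1, 28, 28]) torch.Size([256])
print(images.min().item(), images.max().item())   # ToTensor scales pixel values to [0, 1]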

4. Define the training function

def train():
    # Build the network, the loss function, and the optimizer (SGD with momentum)
    net = LeNet().to(device)
    criterion = nn.CrossEntropyLoss()   # cross-entropy loss, the usual choice for multi-class classification
    optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    for epoch in range(EPOCH):
        sum_loss = 0.0
        # Iterate over the training batches
        for i, data in enumerate(trainloader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            # Reset the gradients
            optimizer.zero_grad()
            # Forward + backward + parameter update
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Print the average loss every 100 batches
            sum_loss += loss.item()
            if i % 100 == 99:
                print('[%d, %d] loss: %.03f'
                      % (epoch + 1, i + 1, sum_loss / 100))
                sum_loss = 0.0
        # Measure test-set accuracy after every epoch
        with torch.no_grad():
            correct = 0
            total = 0
            for data in testloader:
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                # Take the class with the highest score
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum()
            print('Accuracy after epoch %d: %d%%' % (epoch + 1, (100 * correct / total)))
    # Save the model parameters
    torch.save(net.state_dict(), './params.pth')
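
If you prefer to keep evaluation separate from training, here is a minimal refactoring sketch (my own, not from the original post) that reuses the same testloader and argmax logic:

def evaluate(model, loader):
    # Return test-set accuracy in percent for an already-trained model
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            _, predicted = torch.max(model(images), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100.0 * correct / total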

5. Run training first; the trained parameters are saved to params.pth.

if __name__ == "__main__":
    train()

6. After training finishes, comment out the training call, load the saved parameters, and run a prediction on your own image.

# Load the trained network parameters
net = LeNet().to(device)
net.load_state_dict(torch.load('./params.pth'))

if __name__ == "__main__":
    # train()
    img = cv2.imread('./2.png', cv2.IMREAD_GRAYSCALE)   # read the image as grayscale
    img = cv2.resize(img, (28, 28))                      # resize to 28x28, the MNIST input size
    img = torch.from_numpy(img).float()
    img = img.view(1, 1, 28, 28)                         # add batch and channel dimensions
    img = img.to(device)
    outputs = net(img)
    _, predicted = torch.max(outputs.data, 1)            # index of the highest score = predicted digit
    print(predicted.to('cpu').numpy().squeeze())
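
Note that MNIST digits are light strokes on a dark background, and transforms.ToTensor() scales training pixels to [0, 1], so depending on how your image was drawn you may need extra preprocessing. A hedged sketch of that step (my assumption, not part of the original; 2.png is the author's file name):

# Minimal sketch: match the training-time preprocessing more closely
img = cv2.imread('./2.png', cv2.IMREAD_GRAYSCALE)
img = cv2.resize(img, (28, 28))
img = 255 - img                                 # invert if the digit is dark on a white background (as in Paint)
img = torch.from_numpy(img).float() / 255.0     # scale to [0, 1], like transforms.ToTensor()
img = img.view(1, 1, 28, 28).to(device)
print(torch.max(net(img), 1)[1].item())         # predicted digit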

The test image was drawn with the Windows Paint application, as shown below:

The output is as follows:

The complete code is as follows:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import cv2

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Network definition
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(                 # input size: (1, 28, 28)
            nn.Conv2d(1, 6, 5, 1, 2),               # padding=2 keeps the spatial size unchanged
            nn.ReLU(),                              # output size: (6, 28, 28)
            nn.MaxPool2d(kernel_size=2, stride=2),  # output size: (6, 14, 14)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),                              # output size: (16, 10, 10)
            nn.MaxPool2d(2, 2)                      # output size: (16, 5, 5)
        )
        self.fc1 = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.ReLU()
        )
        self.fc3 = nn.Linear(84, 10)

    # Forward pass; x is the input batch
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # nn.Linear expects 2-D input (batch, features), so flatten the feature maps
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

# Optional command-line arguments (Linux-style CLI); requires `import argparse`
# parser = argparse.ArgumentParser()
# parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints')
# parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)")
# opt = parser.parse_args()

# Hyperparameters
EPOCH = 10          # number of passes over the training set
BATCH_SIZE = 256    # mini-batch size
LR = 0.001          # learning rate

# Data preprocessing: convert PIL images to tensors scaled to [0, 1]
transform = transforms.ToTensor()

# Training set (set download=True on the first run if ./data/ is empty)
trainset = tv.datasets.MNIST(
    root='./data/',
    train=True,
    download=False,
    transform=transform)

# Training data loader
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Test set
testset = tv.datasets.MNIST(
    root='./data/',
    train=False,
    download=False,
    transform=transform)

# Test data loader
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# Network, loss function, and optimizer (SGD with momentum)
net = LeNet().to(device)
net.load_state_dict(torch.load('./params.pth'))  # comment this out until params.pth has been created by train()
criterion = nn.CrossEntropyLoss()   # cross-entropy loss, the usual choice for multi-class classification
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)

# Train and save the model parameters
def train():
    for epoch in range(EPOCH):
        sum_loss = 0.0
        # Iterate over the training batches
        for i, data in enumerate(trainloader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            # Reset the gradients
            optimizer.zero_grad()
            # Forward + backward + parameter update
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Print the average loss every 100 batches
            sum_loss += loss.item()
            if i % 100 == 99:
                print('[%d, %d] loss: %.03f'
                      % (epoch + 1, i + 1, sum_loss / 100))
                sum_loss = 0.0
        # Measure test-set accuracy after every epoch
        with torch.no_grad():
            correct = 0
            total = 0
            for data in testloader:
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                # Take the class with the highest score
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum()
            print('Accuracy after epoch %d: %d%%' % (epoch + 1, (100 * correct / total)))
    # Save the model parameters
    torch.save(net.state_dict(), './params.pth')

if __name__ == "__main__":
    # train()
    img = cv2.imread('./2.png', cv2.IMREAD_GRAYSCALE)
    img = cv2.resize(img, (28, 28))
    img = torch.from_numpy(img).float()
    img = img.view(1, 1, 28, 28)
    img = img.to(device)
    outputs = net(img)
    _, predicted = torch.max(outputs.data, 1)
    print(predicted.to('cpu').numpy().squeeze())

 
