当前位置:   article > 正文

手写数字识别实战_手写数字识别实战项目csdn

手写数字识别实战项目csdn
  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. import torchvision.transforms as transforms
  5. from torchvision.datasets import MNIST
  6. from torch.utils.data import DataLoader
  7. # 定义神经网络模型
  8. class Net(nn.Module):
  9. def __init__(self):
  10. super(Net, self).__init__()
  11. self.fc1 = nn.Linear(784, 256)
  12. self.fc2 = nn.Linear(256, 128)
  13. self.fc3 = nn.Linear(128, 10)
  14. def forward(self, x):
  15. x = x.view(x.size(0), -1)
  16. x = torch.relu(self.fc1(x))
  17. x = torch.relu(self.fc2(x))
  18. x = self.fc3(x)
  19. return x
  20. # 设置一些超参数
  21. learning_rate = 0.001
  22. batch_size = 64
  23. num_epochs = 10
  24. # 加载 MNIST 数据集
  25. train_dataset = MNIST(root='.', train=True, transform=transforms.ToTensor(), download=True)
  26. test_dataset = MNIST(root='.', train=False, transform=transforms.ToTensor())
  27. train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
  28. test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
  29. # 创建模型和优化器
  30. model = Net()
  31. optimizer = optim.Adam(model.parameters(), lr=learning_rate)
  32. criterion = nn.CrossEntropyLoss()
  33. # 训练模型
  34. total_step = len(train_loader)
  35. for epoch in range(num_epochs):
  36. for i, (images, labels) in enumerate(train_loader):
  37. # 前向传播
  38. outputs = model(images)
  39. loss = criterion(outputs, labels)
  40. # 反向传播和优化
  41. optimizer.zero_grad()
  42. loss.backward()
  43. optimizer.step()
  44. if (i + 1) % 100 == 0:
  45. print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, total_step,
  46. loss.item()))
  47. # 测试模型
  48. model.eval()
  49. with torch.no_grad():
  50. correct = 0
  51. total = 0
  52. for images, labels in test_loader:
  53. outputs = model(images)
  54. _, predicted = torch.max(outputs.data, 1)
  55. total += labels.size(0)
  56. correct += (predicted == labels).sum().item()
  57. print('准确率: {:.2f}%'.format(100 * correct / total))

实验结果

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/751845
推荐阅读
相关标签
  

闽ICP备14008679号