
Deep Learning for Beginners: Handwritten Digit Recognition on the MNIST Dataset

# 1 Load the required libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
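As an optional sanity check (not part of the original tutorial), you can confirm the installed versions before going further:

# Optional environment check (illustrative addition, not in the original post)
import torch
import torchvision

print("torch:", torch.__version__)              # installed PyTorch version
print("torchvision:", torchvision.__version__)  # installed torchvision version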
# 2 Define the hyperparameters
BATCH_SIZE = 64  # number of samples processed per batch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # train on the GPU if available, otherwise the CPU
EPOCHS = 30  # number of passes over the training set
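Another quick check, added here for illustration, prints which device was actually selected:

# Optional: confirm the selected device (illustrative addition)
print("Training on:", DEVICE)
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))  # name of the first CUDA device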
# 3 Build the transform pipeline applied to each image
pipeline = transforms.Compose([
    transforms.ToTensor(),                      # convert a PIL image to a tensor scaled to [0, 1]
    transforms.Normalize((0.1307,), (0.3081,))  # standardize with the MNIST mean and standard deviation
                                                # (this normalizes the inputs; it is not a regularization technique)
])
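The values 0.1307 and 0.3081 are the mean and standard deviation of the MNIST training images after ToTensor scaling. As a sanity check (not part of the original post), they can be recomputed roughly like this, assuming the dataset is downloaded into the same "dataset" folder:

# Sketch: recompute the per-pixel mean/std used by Normalize (assumes the "dataset" folder is used)
import torch
from torchvision import datasets, transforms

raw_train = datasets.MNIST("dataset", train=True, download=True,
                           transform=transforms.ToTensor())   # only ToTensor, no Normalize
# Stack all 60,000 images into one tensor of shape [60000, 1, 28, 28]
all_imgs = torch.stack([img for img, _ in raw_train])
print(all_imgs.mean().item(), all_imgs.std().item())  # roughly 0.1307 and 0.3081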
# 4 Download and load the datasets
from torch.utils.data import DataLoader

# Download the datasets
train_set = datasets.MNIST("dataset", train=True, download=True, transform=pipeline)
test_set = datasets.MNIST("dataset", train=False, download=True, transform=pipeline)  # use the test split, not the training split

# Wrap them in data loaders
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)  # the test set does not need shuffling
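Before building the model it can help to pull one batch and look at its shape; this small check is an addition for illustration, not part of the original code:

# Optional check: inspect one training batch (illustrative addition)
images, labels = next(iter(train_loader))
print(images.shape)                   # torch.Size([64, 1, 28, 28]) -> 64 grayscale 28x28 images
print(labels.shape)                   # torch.Size([64])            -> one digit label (0-9) per image
print(len(train_set), len(test_set))  # 60000 training images, 10000 test images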
# 5 Build the network
class Digit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)     # first conv layer: 1 input channel (grayscale), 10 output channels, 5x5 kernel
        self.conv2 = nn.Conv2d(10, 20, 3)    # second conv layer: 10 input channels, 20 output channels, 3x3 kernel
        self.fc1 = nn.Linear(20*10*10, 500)  # first fully connected layer: 20*10*10 = 2000 inputs, 500 outputs
        self.fc2 = nn.Linear(500, 10)        # second fully connected layer: 500 inputs, 10 outputs (digits 0-9)

    def forward(self, x):
        input_size = x.size(0)        # batch_size
        x = self.conv1(x)             # in: batch*1*28*28, out: batch*10*24*24 (28 - 5 + 1 = 24)
        x = F.relu(x)                 # activation, shape unchanged: batch*10*24*24
        x = F.max_pool2d(x, 2, 2)     # 2x2 max pooling: in batch*10*24*24, out batch*10*12*12
        x = self.conv2(x)             # in: batch*10*12*12, out: batch*20*10*10 (12 - 3 + 1 = 10)
        x = F.relu(x)                 # activation, shape unchanged
        x = x.view(input_size, -1)    # flatten; -1 infers the dimension: 20*10*10 = 2000
        x = self.fc1(x)               # in: batch*2000, out: batch*500
        x = F.relu(x)                 # activation, shape unchanged
        x = self.fc2(x)               # in: batch*500, out: batch*10
        output = F.log_softmax(x, dim=1)  # log-probability for each digit 0-9
        return output
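A quick forward pass with a dummy tensor (an illustrative addition) verifies the shape arithmetic in the comments above, in particular the 20*10*10 = 2000 flatten size and the 10-way output:

# Optional: verify the layer arithmetic with a dummy batch (illustrative addition)
dummy = torch.randn(2, 1, 28, 28)   # a fake batch of two 28x28 grayscale images
out = Digit()(dummy)
print(out.shape)                    # torch.Size([2, 10]) -> one log-probability per digit
print(out.exp().sum(dim=1))         # each row of probabilities sums to ~1.0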
# 6 Create the model and the optimizer
model = Digit().to(DEVICE)
optimizer = optim.Adam(model.parameters())
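Called with only the parameters, Adam uses its default learning rate of 1e-3. The lines below are an optional variation, not part of the original post, that passes the rate explicitly and counts the trainable parameters:

# Variation (assumption, not in the original): pass the learning rate explicitly (1e-3 is Adam's default)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Optional: count the trainable parameters of the network
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Trainable parameters:", num_params)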
# 7 Define the training function
def train_model(model, device, train_loader, optimizer, epoch):
    # put the model in training mode
    model.train()
    for batch_index, (img, target) in enumerate(train_loader):
        # move the batch to DEVICE
        img, target = img.to(device), target.to(device)
        # reset the gradients to zero
        optimizer.zero_grad()
        # forward pass
        output = model(img)
        # compute the multi-class loss; nll_loss pairs with the log_softmax output
        # (F.cross_entropy expects raw logits and would apply log_softmax a second time)
        loss = F.nll_loss(output, target)
        # backpropagation
        loss.backward()
        # update the parameters with the optimizer's step method
        optimizer.step()
        # print the loss every 3000 batches
        if batch_index % 3000 == 0:
            print("Train Epoch : {} \t Loss : {:.6f}".format(epoch, loss.item()))
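With 60,000 training images and a batch size of 64 there are only about 938 batches per epoch, so batch_index % 3000 == 0 holds only at batch 0 and the loss is printed once per epoch. The check below is an illustrative addition; the interval 100 is a hypothetical value if you want more frequent logging:

# Illustrative check (not in the original): how many batches does one epoch contain?
print(len(train_loader))   # about 938 with batch_size=64 (60000 / 64, rounded up)
# A hypothetical tweak inside train_model for more frequent progress output:
#     if batch_index % 100 == 0:
#         print(...)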
# 8 Define the test function
def test_model(model, device, test_loader):
    # put the model in evaluation mode
    model.eval()
    # number of correct predictions
    correct = 0.0
    # accumulated test loss
    test_loss = 0.0
    with torch.no_grad():  # no gradients and no backpropagation during testing
        for img, target in test_loader:
            # move the batch to DEVICE
            img, target = img.to(device), target.to(device)
            # forward pass
            output = model(img)
            # accumulate the test loss, summed over the batch so the average below is per sample
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            # index of the largest log-probability = predicted digit
            pred = output.max(1, keepdim=True)[1]  # max returns (values, indices)
            # pred = output.argmax(dim=1, keepdim=True) would do the same
            # count the correct predictions
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print("Test -- Average loss : {:.4f}, Accuracy : {:.3f}%\n".format(
        test_loss, 100 * correct / len(test_loader.dataset)))
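The line pred.eq(target.view_as(pred)) compares predictions and labels element-wise; the toy tensors below are made up purely to illustrate how the comparison and the .sum() count behave:

# Toy illustration (made-up tensors): how the correctness count works
pred   = torch.tensor([[3], [7], [1]])    # predicted digits, shape [3, 1]
target = torch.tensor([3, 2, 1])          # true labels,      shape [3]
matches = pred.eq(target.view_as(pred))   # view_as reshapes target to [3, 1] before comparing
print(matches)                            # tensor([[True], [False], [True]])
print(matches.sum().item())               # 2 correct predictions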
# 9 Run the training and testing functions from steps 7 and 8
for epoch in range(1, EPOCHS + 1):
    train_model(model, DEVICE, train_loader, optimizer, epoch)
    test_model(model, DEVICE, test_loader)
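After training, the weights can be saved and the model used on a single test image. This sketch is an addition; the file name mnist_digit.pt is an arbitrary placeholder, not something from the original post:

# Sketch: save the trained weights and classify one test image
# ("mnist_digit.pt" is an arbitrary file name chosen for this example)
torch.save(model.state_dict(), "mnist_digit.pt")

model.eval()
img, label = test_set[0]                            # one normalized 1x28x28 image and its true label
with torch.no_grad():
    logits = model(img.unsqueeze(0).to(DEVICE))     # add a batch dimension -> [1, 1, 28, 28]
    prediction = logits.argmax(dim=1).item()
print("predicted:", prediction, "actual:", label)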

Running the script prints the training loss and the test accuracy for each epoch.

For a step-by-step video walkthrough of the same build, see 03-2 轻松学 PyTorch 手写字体识别 MNIST (实战 - 上) on Bilibili: https://www.bilibili.com/video/BV1WT4y177SA
