当前位置:   article > 正文

重磅分享:使用PyTorch实现MNIST手写体识别代码_pytorch手写体识别案例代码

pytorch手写体识别案例代码

@本文来源于公众号:csdn2299,喜欢可以关注公众号 程序员学府
今天小编就为大家分享一篇使用PyTorch实现MNIST手写体识别代码,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

实验环境

win10 + anaconda + jupyter notebook

Pytorch1.1.0

Python3.7

gpu环境(可选)

MNIST数据集介绍

MNIST 包括6万张28x28的训练样本,1万张测试样本,可以说是CV里的“Hello Word”。本文使用的CNN网络将MNIST数据的识别率提高到了99%。下面我们就开始进行实战。

导入包

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
torch.__version__
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

定义超参数

BATCH_SIZE=512
EPOCHS=20
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  • 1
  • 2
  • 3

数据集

我们直接使用PyTorch中自带的dataset,并使用DataLoader对训练数据和测试数据分别进行读取。如果下载过数据集这里download可选择False

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, 
            transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=BATCH_SIZE, shuffle=True)
 
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=BATCH_SIZE, shuffle=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

定义网络

该网络包括两个卷积层和两个线性层,最后输出10个维度,即代表0-9十个数字。

class ConvNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1=nn.Conv2d(1,10,5) # input:(1,28,28) output:(10,24,24) 
    self.conv2=nn.Conv2d(10,20,3) # input:(10,12,12) output:(20,10,10)
    self.fc1 = nn.Linear(20*10*10,500)
    self.fc2 =</
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/214897
推荐阅读
相关标签
  

闽ICP备14008679号