赞
踩
@本文来源于公众号:csdn2299,喜欢可以关注公众号 程序员学府
今天小编就为大家分享一篇使用PyTorch实现MNIST手写体识别代码,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
win10 + anaconda + jupyter notebook
Pytorch1.1.0
Python3.7
gpu环境(可选)
MNIST 包括6万张28x28的训练样本,1万张测试样本,可以说是CV里的“Hello Word”。本文使用的CNN网络将MNIST数据的识别率提高到了99%。下面我们就开始进行实战。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
torch.__version__
BATCH_SIZE=512
EPOCHS=20
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
我们直接使用PyTorch中自带的dataset,并使用DataLoader对训练数据和测试数据分别进行读取。如果下载过数据集这里download可选择False
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=BATCH_SIZE, shuffle=True)
该网络包括两个卷积层和两个线性层,最后输出10个维度,即代表0-9十个数字。
class ConvNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1=nn.Conv2d(1,10,5) # input:(1,28,28) output:(10,24,24)
self.conv2=nn.Conv2d(10,20,3) # input:(10,12,12) output:(20,10,10)
self.fc1 = nn.Linear(20*10*10,500)
self.fc2 =</
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。