当前位置:   article > 正文

手写数字识别_MNIST数据集_手写数字识别数据集

手写数字识别数据集

前言

MNIST数据集由250个不同的人手写而成,总共有7000张手写数据集。其中训练集有6000张,测试集有1000张。每张图片大小为28x28,或者说是由28x28个像素组成。这章打算用一个简单的模型进行手写字符识别。

MNIST

下载MNIST数据集的方式有很多,可以去MNIST官网下载,也可以用函数api下载
官网下载网页为:http://yann.lecun.com/exdb/mnist/,复制链接打开之后可以在网页中看到以下信息,下图圈起来的就是数据集。

在这里插入图片描述
本文采用的是通过pytorch的函数下载

from torchvision import datasets
# 下载训练集,测试集
traindataset = datasets.MNIST(root="./data/",train=True,download=False)
testdataset = datasets.MNIST(root="./data/",train=False,download=False)
  • 1
  • 2
  • 3
  • 4

接下来将数据集保存为图片。

# 查看手写数据集Mnist,保存图片集
import torchvision
from torchvision import datasets
import cv2
from tqdm import tqdm
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed 

def download_save_img(img_message,index,path,train=True):
    # 这里用的是opencv保存图片,img_message是一个tuple,
    #其中tuple[0]是图片类别,tuple[1]是PIL格式的图片,用opencv保存的需要转为numpy格式
    img = np.array(img_message[0])
    img_class = img_message[1]
    cv2.imwrite(path+str(img_class)+"_"+str("train" if train else "test")+str(index)+".jpg",img)

results = []
traindataset = datasets.MNIST(root="./data/",train=True,download=False)
# 打印训练集数量
print(len(traindataset))
# 多线程下保存图片
with ThreadPoolExecutor(max_workers=None) as t:
    for index, img_message in enumerate(traindataset): # "./data/MNIST_ori_img/"
        results.append(t.submit(download_save_img, img_message, index, "./data/MNIST_ori_img/", train=True))
    for result in tqdm(as_completed(results),total=len(results),desc = "train"):
        pass
results = []
testdataset = datasets.MNIST(root="./data/",train=False,download=False)
print(len(testdataset))
with ThreadPoolExecutor(max_workers=None) as t:
    for index, img_message in enumerate(testdataset): # "./data/MNIST_ori_img/"
        results.append(t.submit(download_save_img, img_message, index, "./data/MNIST_ori_img/", train=False))
    for result in tqdm(as_completed(results),total=len(results),desc = "test"):
        pass
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33

可以在文件所在目录下/data/ 查看手写数字原图,图片为黑白手写图集,具体可看下图
在这里插入图片描述

训练

构建一个由2个激活层和全连接层组成的模型
激活函数可以引入非线性因素,为什么要引入非线性因素呢,主要是因为我们所要解决的问题(识别手写字符)是一个非线性问题,引入非线性因素可以更有效地解决非线性问题。
全连接层主要作用是分类,将学到的“分布式特征表示”映射到样本标记空间

import torch
import torch.nn as nn
import numpy as np

class Cnn(nn.Module):
    def __init__(self, class_num):
        super(Cnn, self).__init__()
        self.flatten = nn.Flatten()
        self.relu = nn.ReLU()
        self.relu1 = nn.ReLU()
        self.linear_out = nn.Linear(784,class_num)

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(x)
        x = self.relu1(x)
        x = self.linear_out(x)
        return x
# 手写字符从0到9 总共有10个类别 实例化模型
cnn = Cnn(class_num = 10)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import torch 
from torchvision import datasets
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 设置batchsize
bs = 8
# 设置多进程加载数据
nw = 4
# 设置训练迭代  
epoches = 10
# 加载数据集
train_set = datasets.MNIST(root="./data",train=True,transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]),download=True)
train_loader = DataLoader(dataset=train_set,batch_size=bs, shuffle=True, num_workers=nw)
val_set = datasets.MNIST(root="./data",train=False,transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]),download=True)
val_loader = DataLoader(dataset=val_set,batch_size=bs, shuffle=True, num_workers=nw)


# 训练 使用Adam作为优化器
optimizer = torch.optim.Adam(cnn.parameters(), lr=0.1)
# 损失函数使用交叉熵损失函数
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(epoches):
    train_loss = 0
    train_correct = 0
    val_loss = 0
    val_correct = 0
    for inputs, labels in tqdm(train_loader):
        cnn.to(device).train()
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = cnn(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs,labels)
        loss.backward()
        optimizer.step()
        ## loss计算
        train_loss += loss.item() * inputs.size(0)
        train_correct += torch.sum(preds == labels.data)
    for inputs, labels in tqdm(val_loader):
        cnn.to(device).eval()
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            outputs = cnn(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs,labels)
        ## loss计算
        val_loss += loss.item() * inputs.size(0)
        val_correct += torch.sum(preds == labels.data)
    train_losses = train_loss / len(train_loader.dataset)
    train_acc = float(train_correct) / len(train_loader.dataset)
    valid_losses = val_loss / len(val_loader.dataset)
    valid_acc = float(val_correct) / len(val_loader.dataset)
    print("epoch: {},  train_loss is: {}, train_acc is: {}, val_loss: {}, val_acc: {}".format(epoch,train_losses,
                                                                                                                  train_acc,
                                                                                                                  valid_losses,
                                                                                                                  valid_acc))

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70

在这里插入图片描述

结果

不采用卷积神经网络的情况下,准确率在89%左右

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/568898
推荐阅读
相关标签
  

闽ICP备14008679号