当前位置:   article > 正文

【深度学习与神经网络】MNIST手写数字识别1

【深度学习与神经网络】MNIST手写数字识别1

简单的全连接层

导入相应库

import torch
import numpy as np
from torch import nn,optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

读入数据并转为tensor向量

# 训练集
# 转为tensor数据
train_dataset = datasets.MNIST(root='./',train=True, transform = transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./',train=False, transform = transforms.ToTensor(), download=True)
  • 1
  • 2
  • 3
  • 4

装载数据集

# 批次大小
batch_size = 64

# 装载训练集
train_loader = DataLoader(dataset = train_dataset, batch_size=batch_size, shuffle = True)
test_loader = DataLoader(dataset = test_dataset, batch_size=batch_size, shuffle = True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

定义网络结构
一层全连接网络,最后使用softmax转概率值输出

# 定义网络结构
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 10)
        self.softmax = nn.Softmax(dim =1)
        
    def forward(self, x):
        # [64,1,28,28] ——> [64, 784]
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.softmax(x)
        return x   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

定义模型
使用均方误差损失函数,梯度下降优化

# 定义模型
model = Net()
mes_loss = nn.MSELoss()
optimizer = optim.SGD(model.parameters(),0.5)

  • 1
  • 2
  • 3
  • 4
  • 5

训练并测试网络:
训练时注意最后输出(64,10)
标签是(64) ,需要将其转为one-hot编码(64,10)

def train():
    for i,data in enumerate(train_loader):
        # 获得一个批次的数据和标签
        inputs, labels = data
        # 获得模型结果 (64,10)
        out = model(inputs)
        # to one-hot 把数据标签变为独热编码
        labels = labels.reshape(-1,1)
        one_hot = torch.zeros(inputs.shape[0],10).scatter(1, labels, 1)
        # 计算loss
        loss = mes_loss(out, one_hot)
        # 梯度清0
        optimizer.zero_grad()
        # 计算梯度
        loss.backward()
        # 修改权值
        optimizer.step()
        
def test():
    correct = 0
    for i,data in enumerate(test_loader):
        # 获得一个批次的数据和标签
        inputs, labels = data
        # 获得模型结果 (64,10)
        out = model(inputs)
        # 获取最大值和最大值所在位置
        _,predicted = torch.max(out,1)
        # 预测正确数量
        correct += (predicted == labels).sum()
        
        
    print("test ac:{0}".format(correct.item()/len(test_dataset)))
        
  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

调用模型 训练10次

# 使用mse损失函数 
for epoch in range(10):
    print("epoch:",epoch)
    train()
    test()
  • 1
  • 2
  • 3
  • 4
  • 5

训练结果:
在这里插入图片描述
准确率不够

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/255181
推荐阅读
相关标签
  

闽ICP备14008679号