当前位置:   article > 正文

【Python】基于卷积神经网络的手写数字识别-PyTorch实现_利用神经网络实现手写数字识别实验报告

利用神经网络实现手写数字识别实验报告

介绍

        文章内容来自我的《深度学习》课程作业实验报告。

        手写数字识别属于图像分类问题,通过对图像进行识别,将图像分为0-9的十个类别。

解决方案

2.1 LeNet模型

        LeNet-5是一个经典的深度卷积神经网络,由Yann LeCun在1998年提出,旨在解决手写数字识别问题,被认为是卷积神经网络的开创性工作之一[2]。该网络是第一个被广泛应用于数字图像识别的神经网络之一,也是深度学习领域的里程碑之一。

        LeNet-5模型结构如下:

2.2 损失函数

        使用交叉熵损失作为损失函数,计算公式如下:

2.3 评价指标

       使用精确度作为评价指标,计算公式如下:

其中:TP为真正例,TN为真负例,FP为假正例,FN为假负例。

实验结果和分析

3.1数据集与工具

Visual Studio Code

​ PyTorch 1.13

​ MNIST 数据集

3.2训练过程及代码

        获取数据集,这里Normalize()转换使用的值0.1307和0.3081是MNIST数据集的全局平均值和标准偏差。

  1. # 获取数据集
  2. data_path = '.\mnistdata'
  3. data_tf = torchvision.transforms.Compose(
  4. [
  5. torchvision.transforms.ToTensor(),
  6. torchvision.transforms.Normalize([0.1307],[0.3081])
  7. ]
  8. )
  9. train_data = mnist.MNIST(data_path,train=True,transform=data_tf,download=True)
  10. test_data = mnist.MNIST(data_path,train=False,transform=data_tf,download=True)
  11. train_loader = data.DataLoader(train_data,batch_size=batch_size,shuffle=True,pin_memory=False)
  12. test_loader = data.DataLoader(test_data,batch_size=batch_size)

        定义LeNet-5模型网络结构

  1. class LeNet5(torch.nn.Module):
  2. def __init__(self):
  3. super(LeNet5,self).__init__()
  4. self.features = torch.nn.Sequential(
  5. torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2, bias=True),
  6. torch.nn.MaxPool2d(kernel_size=2),
  7. torch.nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5,stride=1,padding=0,bias=True),
  8. torch.nn.MaxPool2d(kernel_size=2),
  9. torch.nn.Flatten()
  10. )
  11. self.classification = torch.nn.Sequential(
  12. torch.nn.Linear(16*5*5, 120),
  13. torch.nn.ReLU(inplace=True),
  14. torch.nn.Linear(120, 84),
  15. torch.nn.Linear(84, 10),
  16. )
  17. def forward(self, input):
  18. x=self.features(input)
  19. output=self.classification(x)
  20. return output

        定义优化器与损失函数

  1. lr=0.005 #学习率
  2. momentum=0.8
  3. device=torch.device("cuda" if torch.cuda.is_available() else "cpu" )
  4. model=LeNet5().to(device)
  5. print(model)
  6. optimizer=torch.optim.SGD(model.parameters(),lr=lr,momentum=momentum)
  7. closs=torch.nn.CrossEntropyLoss()

         定义两个函数进行模型训练与测试。设置训练的epoch=9,batch=128,因为数据量比较大,iteration数量比较多,所以每50个iteration进行一次采样,将交叉熵误差和准确度进行打印输出,并进行最终的可视化分析制图。每一个epoch后,使用测试集数据进行测试,并输出交叉熵误差和准确度结果。

  1. # 网络训练
  2. def train(model,device,train_loader,optimizer,epoch,losses,accuracies):
  3. for idx,(t_data,t_target) in enumerate(train_loader):
  4. t_data,t_target=t_data.to(device),t_target.to(device)
  5. pred=model(t_data)#batch_size*10
  6. predictions = torch.argmax(pred, dim = 1)
  7. accuracy = torch.sum(predictions == t_target)/t_target.shape[0]
  8. loss=closs(pred,t_target)
  9. #SGD
  10. optimizer.zero_grad()#将上一步的梯度清0
  11. loss.backward()#重新计算梯度
  12. optimizer.step()#更新参数
  13. if idx%50==0:
  14. print("epoch:{},iteration:{},loss:{},accuracy:{}".format(epoch,idx,loss.item(),accuracy))
  15. losses.append(loss.item()) #每50批数据采样一次loss,记录下来,用来画图可视化分析。
  16. accuracies.append(accuracy.item())
  17. # 网络测试
  18. def test(model,valid_loader,criterion,valid_loss,valid_acc):
  19. model.eval()
  20. correct=0#预测对了几个。
  21. loss=0
  22. with torch.no_grad():
  23. for idx,(t_data,t_target) in enumerate(test_loader):
  24. t_data,t_target=t_data.to(device),t_target.to(device)
  25. pred=model(t_data)#batch_size*10
  26. pred_class=pred.argmax(dim=1)#batch_size*10->batch_size*1
  27. correct+=pred_class.eq(t_target.view_as(pred_class)).sum().item()
  28. loss+=closs(pred,t_target).item()
  29. acc_mean=correct/len(test_data)
  30. loss_mean=loss/len(test_loader)
  31. valid_loss.append(loss_mean)
  32. valid_acc.append(acc_mean)
  33. print("loss:{},accuracy:{},".format(loss_mean,acc_mean))
  34. model.train()
'
运行

3.3 结果与分析

        迭代过程中,交叉熵损失与精确度变化情况如上图所示。训练过程设置了epoch=9。在图中可以看出,在第5个epoch后,模型逐渐趋于稳定,并且没有出现明显的过拟合或欠拟合现象。迭代完成后,交叉熵损失为0.04,准确率为98.68%。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/神奇cpp/article/detail/936851
推荐阅读
相关标签
  

闽ICP备14008679号