当前位置:   article > 正文

【小白也可做】pytorch实现普通CNN的MNIST手写数字分类,含t-SNE聚类图、混淆矩阵图绘制,内含通用代码,可根据自己的项目需要进行修改_cnn pychram 项目结构

cnn pychram 项目结构

一、项目简介

本项目是基于pytorch使用两层CNN网络实现手写数字的分类识别,并且绘制了损失和准确率曲线、训练前后的t-SNE聚类图、混淆矩阵图。

二、数据集

常用的手写数字数据集MNIST,这个大家自行百度就有很多说明这个数据集的文章啦,里面的图片大概是长这样子的。

三、实验环境

平台:Window 11

语言:python3.9

编译器:Pycharm

框架:Pytorch:1.13.1

四、实验内容及部分代码展示

1、Model.py 模型构建

该项目使用的网络包含2维卷积、池化层、全连接层,通过ReLU激活函数进行非线性变换

2、train.py 用于分类的训练通用模板

3、Config.py 参数定义

config类中定义了项目所有需要的参数,可以在里面修改训练参数。

4、mnist_class_cnn_run.py 运行文件

该py文件实现整体训练流程并做绘图操作。依次实现加载数据、数据格式转化、划分训练集测试集、形成数据更迭器、载入模型、定义损失、定义优化器、开始训练、损失可视化、显示预测结果。

5、test_pth.py 模型训练后的测试文件

采用模型训练完成后的pth对数据进行预测,可以展示模型预测效果,前面对数据的处理过程类似mnist_class_cnn_run.py所示。

6、draw_loss_acc.py 模型训练后的loss绘图

将训练后产生并收集的损失loss.csv和准确率acc.csv展示出来,也就是损失和准确率变化曲线。

7、tsne_plot.py 模型训练后用于绘制t-SNE聚类图

绘制了训练前和训练后样本的t-SNE聚类图

8、matrix_plot.py 绘制模型训练后的混淆矩阵

五、实验效果分析

1、loss损失图

该损失是训练了50个epoch的损失图

2、acc准确率图

3、test_pth.py的预测效果展示 

其中红框是预测有误的,有误的概率比较小

4、matrix_plot.py的混淆矩阵效果展示 

第一个图是训练了第一个epoch后的混淆矩阵,第二个图是训练了50个epoch的混淆矩阵,横坐标是预测值,纵坐标是真实值,中间的数字指的是样本数,比如第一个数字976指的是有真实值为0的976个样本预测出是0,同一行后面有4个真实值是0的样本预测的是6和8。相比于第1个和第50个epoch的混淆矩阵中的数据,对角线处其对应正确预测的样本数越来越大,说明模型训练有效果。

5、matrix_plot.py的混淆矩阵效果展示 

第一个是未训练时的样本聚类图,第二个是训练后的样本聚类图,模型训练后,各数字的分布明显分隔开了,说明模型对数字识别分类有效果。

六、资源与总结

若有朋友需要可运行的源码和数据集,可以guan注【科研小条】公众号,回复【手写数字分类】,即可获得。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/235337
推荐阅读
相关标签
  

闽ICP备14008679号