赞
踩
本项目是基于pytorch使用两层CNN网络实现手写数字的分类识别,并且绘制了损失和准确率曲线、训练前后的t-SNE聚类图、混淆矩阵图。
常用的手写数字数据集MNIST,这个大家自行百度就有很多说明这个数据集的文章啦,里面的图片大概是长这样子的。
平台:Window 11
语言:python3.9
编译器:Pycharm
框架:Pytorch:1.13.1
1、Model.py 模型构建
该项目使用的网络包含2维卷积、池化层、全连接层,通过ReLU激活函数进行非线性变换
2、train.py 用于分类的训练通用模板
3、Config.py 参数定义
config类中定义了项目所有需要的参数,可以在里面修改训练参数。
4、mnist_class_cnn_run.py 运行文件
该py文件实现整体训练流程并做绘图操作。依次实现加载数据、数据格式转化、划分训练集测试集、形成数据更迭器、载入模型、定义损失、定义优化器、开始训练、损失可视化、显示预测结果。
5、test_pth.py 模型训练后的测试文件
采用模型训练完成后的pth对数据进行预测,可以展示模型预测效果,前面对数据的处理过程类似mnist_class_cnn_run.py所示。
6、draw_loss_acc.py 模型训练后的loss绘图
将训练后产生并收集的损失loss.csv和准确率acc.csv展示出来,也就是损失和准确率变化曲线。
7、tsne_plot.py 模型训练后用于绘制t-SNE聚类图
绘制了训练前和训练后样本的t-SNE聚类图
8、matrix_plot.py 绘制模型训练后的混淆矩阵
该损失是训练了50个epoch的损失图
其中红框是预测有误的,有误的概率比较小
第一个图是训练了第一个epoch后的混淆矩阵,第二个图是训练了50个epoch的混淆矩阵,横坐标是预测值,纵坐标是真实值,中间的数字指的是样本数,比如第一个数字976指的是有真实值为0的976个样本预测出是0,同一行后面有4个真实值是0的样本预测的是6和8。相比于第1个和第50个epoch的混淆矩阵中的数据,对角线处其对应正确预测的样本数越来越大,说明模型训练有效果。
第一个是未训练时的样本聚类图,第二个是训练后的样本聚类图,模型训练后,各数字的分布明显分隔开了,说明模型对数字识别分类有效果。
若有朋友需要可运行的源码和数据集,可以guan注【科研小条】公众号,回复【手写数字分类】,即可获得。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。