赞
踩
matplotlib库是一个非常强大的绘图工具,其可以完成python中多种二维或三维图表的绘制,而在深度学习中,我们也经常会看到它的身影。今天给大家分享一下如何利用matplotlib库中的pyplot模块绘制神经网络模型训练与测试后的准确率、损失值曲线图。
1、pyplot.plot() 函数基本参数解析
1.1 使用方式
pyplot.plot() 函数在导入相关包后可简写为plt.plot(),其完整使用方式如下所示:
plt.plot(x, y, format_string, **kwargs)
1.2 参数介绍
(1)x,y:x轴与y轴数据值,列表或数组形式
(2)format_string:可选参数,控制曲线的格式字符串,包括以下内容:
(2.1)ls:折线图的线条风格;
折线图线条风格包括颜色和线条样式的改变,包括但不限于以下内容:
线条样式:
1、‘-’:实线样式 2、‘--’:短横线样式 3、‘-.’:点划线样式 4、‘:’:虚线样式 5、‘.’:点标记 6、‘o’:圆标记 7、‘V’:倒三角标记 8、‘^’:正三角标记 9、‘<’:左三角标记 10、‘>’:右三角表示 11、‘1’:下箭头标记 12、‘2’:上箭头标记 13、‘3’:左箭头标记 14、‘4’:右箭头标记
颜色样式:
1、‘b’ :蓝色 2、‘c’: 青绿色 3、‘g’: 绿色 4、‘k’ :黑色 5、‘m’:洋红色 6、‘r’: 红色 7、‘w’:白色 8、‘y’: 黄色
(2.2)lw:折线图的线条宽度;
(2.3)label:标记图形内容的标签文本;
(2.4)alpha:透明度
(3)**kwargs:第二组或更多(x,y,format_string),用于绘制多条曲线
2、使用plt.plot() 函数绘制神经网络模型准确率和损失值曲线图
先放代码:
- import numpy as np
- from matplotlib import pyplot as plt
-
- # 定义两个空列表,用于存储损失值loss和准确率acc
- Loss_list = []
- Acc_list = []
- # 将每次训练后的准确率或损失值存入列表中
- Loss_list.append(f'{train_loss / (len(train_dataset)):.2f}')
- Acc_list.append(f'{100 * train_acc / (len(train_dataset)):.2f}')
-
- # 对存入列表中的数据进行强制类型转换(转换的类型大家自选,这里是float)
- for i in range(0, len(Loss_list)):
- Loss_list[i] = float(Loss_list[i])
- Acc_list[i] = float(Acc_list[i])
-
- # x轴数据的取值范围(训练多少次就填多少)
- x1 = range(0, 30)
- x2 = range(0, 30)
- y1 = Acc_list
- y2 = Loss_list
-
- # 用于显示中文字符
- # plt.rcParams['font.sans-serif'] = ['SimHei']
- # plt.rcParams['axes.unicode_minus'] = False
-
- # acc图像
- plt.subplot(2, 1, 1)
- plt.plot(x1, y1, 'o-')
- plt.title('model accuracy')
- plt.ylabel('accuracy unit:%')
- my_yTicks1 = np.arange(60, 100, 10)
- plt.yticks(my_yTicks1)
- # loss图像
- plt.subplot(2, 1, 2)
- plt.plot(x2, y2, '.-')
- plt.xlabel('model loss')
- plt.ylabel('loss')
- my_yTicks2 = np.arange(0.02, 0.2, 0.02)
- plt.yticks(my_yTicks2)
-
- plt.savefig("accuracy_loss.jpg")
- plt.show()
代码中的注释对应着每一步绘图的步骤,因此这里只对上述绘图函数中的详细操作进行下解释
(1)plt.subplot(2,1,1) 与 plt.subplot(2,1,2):前面两个数字意为将整个图像窗口分为2行1列, 最后一个数字是对图像的编号(这里绘制了两个曲线图,其中acc曲线图编号为1,loss曲线图编号为2)。如果不进行编号,那么acc与loss的两条数据曲线将会处于一个曲线图上。
(2)my_yTicks1 = np.arange(60, 100, 10) 这句代码是取10个区间在[60, 100]的值,且每个值的取值间隔为10, 即取得的值为60,70......。plt.yticks(my_yTicks1)则是将my_yTicks1中的数据作为曲线图的y轴刻度值。
(3)matplotlib库绘制图表时不能含有中文,如果大家绘制的图表有中文需求,就需要加入下面两句代码
- plt.rcParams['font.sans-serif'] = ['SimHei']
- plt.rcParams['axes.unicode_minus'] = False
注意点:plt.savefig("accuracy_loss.jpg") 这句代码必须位于 plt.show() 的前面,前一句代码用于保存绘制的曲线图,后一句代码则是展示曲线图,若两句代码顺序错乱,则图表的保存不会达到理想效果。这是因为plt.show()在展示绘制的曲线图时,进程也会被暂停,关闭展示后会创建一个新的空白图片,这时候再使用 plt.savefig() 保存的则为新生成的空白图片。
这里对上次的水稻病害识别模型训练了10轮,绘制了相关acc与loss曲线图,如下所示:
可以看到其实模型还未完全收敛,还可以训练几轮哈(但是博主懒得跑了),水稻病害识别的代码与文章链接:(深度学习)基于残差卷积——resnet的水稻病害识别-CSDN博客
参考文章:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。