赞
踩
有些模型训练完在log中打印了每个epoch的loss和精度,但是没有一个直观的曲线图,也找不到最大acc。
train.log类似于:
Epoch: 0, Test_loss: 5.9283, Test_acc: 0.1417
Epoch: 1, Test_loss: 5.8111, Test_acc: 0.1139
Epoch: 2, Test_loss: 5.8015, Test_acc: 0.1389
Epoch: 3, Test_loss: 5.7686, Test_acc: 0.1417
Epoch: 4, Test_loss: 5.8408, Test_acc: 0.1667
Epoch: 5, Test_loss: 5.7008, Test_acc: 0.1417
因此写了一个函数用来返回acc的最大值,并且绘制出 epoch-loss epoch-acc曲线图。
def log2draw(path, num, isloss, acc_index, loss_index=None, rule=None):
"""
说明:rule = r'\d+?\.\d+' 匹配 1.21 0.123 1.0 0.0 这种
如果有loss,isloss=True loss_index要写,否则默认=None
import re
import matplotlib.pyplot as plt
:param path: log或txt文件的路径
:param num: 每行匹配后的列表应该有多少个元素/数字
:param isloss: 匹配到的是否有Loss
:param acc_index: acc在匹配后列表中的索引
:param loss_index: loss在匹配后列表中的索引
:param rule: 正则匹配
:return: 绘制两个子图 epoch-loss epoch-acc
"""
plt.rcParams['font.family'] = 'FangSong' # 设置字体为仿宋
plt.rcParams['font.size'] = 10 # 设置字体的大小为10
plt.rcParams['axes.unicode_minus'] = False # 显示正、负的问题
if rule:
rule = rule
else:
rule = r'\d+?\.\d+'
pattern = re.compile(rule)
with open(path) as f:
lines = f.readlines()
print(lines[:2])
loss = []
acc = []
epoch = []
e = 0
max_acc = [0, 0]
for line in lines:
find = pattern.findall(line)
if find and len(find) == num:
if isloss:
loss.append(float(find[loss_index]))
acc.append(float(find[acc_index]))
if float(find[acc_index]) > max_acc[1]:
max_acc[0] = e
max_acc[1] = float(find[acc_index])
epoch.append(e)
e += 1
print('The best acc is epoch:{} acc:{}'.format(max_acc[0], max_acc[1]))
if isloss:
# 参考 https://blog.csdn.net/weixin_48468999/article/details/117537303
# 第一个子图
ax1 = plt.subplot(121)
ax1.plot(epoch, acc)
ax1.set_xlabel('Epoch') # 为x轴添加标签
ax1.set_ylabel('AP') # 为y轴添加标签
# ax1.legend(loc='upper left') # 设置图表图例在左上角
ax1.grid(True) # 绘制网格
# 第二个子图
ax1 = plt.subplot(122)
ax1.plot(epoch, loss)
ax1.set_xlabel('Epoch') # 为x轴添加标签
ax1.set_ylabel('Loss') # 为y轴添加标签
ax1.grid(True) # 绘制网格
plt.tight_layout() # 自动调整各子图间距
else:
plt.xlabel('Epoch')
plt.ylabel('AP')
plt.grid()
plt.plot(epoch, acc)
plt.show()
if __name__ == '__main__':
import re
import matplotlib.pyplot as plt
path = 'data/train.log'
log2draw(path, 2, True, 1, 0)
关于isloss , loss_index, acc_index这三个参数,以上面的Log文件进行说明。
isloss: 代表输出的信息中,是否有Loss值,True表示有。
loss_index,acc_index : 分别代表 loss 和 acc的索引值,如图:
find为每一行匹配后得到的列表,可以看到,列表中的两个元素分别代表 loss和acc的值。所以loss_index = 0 , acc_index = 1。
还有一个num参数:该参数代表find中,应该有几个元素,这是为了判断log文件中是否有无效行,例如:
很显然这里的过滤方式并不 严谨,还需要进一步改进。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。