赞
踩
跑了个网络,每个epoch的损失等信息会打印在日志里,但是损失变化通过数字来看不够直观。
那个日志的格式是这样的:
- import numpy as np
- from os import path
- import os
- import matplotlib.pyplot as plt
- from matplotlib import cm,ticker
- if __name__ == '__main__':
- path="tt/training_log_2021_3_31_09_50_38.txt"
- out = open(path, encoding='utf-8')
- lines = out.readlines()
-
- #提取trainLoss和validationLoss
- trainLoss=[]
- validationLoss=[]
- for line in lines:
- if "train loss" in line:
- val=np.float(line.split("loss : ")[-1][:-1]) #[:-1]是去除末尾'\n'
- trainLoss.append(val)
- if "validation loss" in line:
- val=np.float(line.split("loss: ")[-1][:-1])
- validationLoss.append(val)
-
- epochNum=len(trainLoss)
- for i in range(epochNum):
- print("epoch{}: train loss:{} val loss:{}".format(i,trainLoss[i],validationLoss[i]))
-
-
- #绘图
- fig=plt.figure()
- xs=np.arange(epochNum)
- plt.yticks(np.arange(-1,0,0.1))
- plt.plot(xs, trainLoss, color='coral', label="train loss")
- plt.plot(xs, validationLoss, color='g', label="val loss")
- plt.legend()
- plt.show()
- #plt.savefig("loss.png")
-
-
-
得到图示:
OVER
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。