赞
踩
首先我们知道不论是保存模型还是参数都需要用到torch.save()
。
对于torch.save()
有两种保存方式:
model.state_dict()
;Eg. 假设我有一个训练好的模型名叫model,如何来保存参数以及结构?
import torch
# 保存模型步骤
torch.save(model, 'net.pth') # 保存整个神经网络的模型结构以及参数
torch.save(model, 'net.pkl') # 同上
torch.save(model.state_dict(), 'net_params.pth') # 只保存模型参数
torch.save(model.state_dict(), 'net_params.pkl') # 同上
# 加载模型步骤
model = torch.load('net.pth') # 加载整个神经网络的模型结构以及参数
model = torch.load('net.pkl') # 同上
model.load_state_dict(torch.load('net_params.pth')) # 仅加载参数
model.load_state_dict(torch.load('net_params.pkl')) # 同上
上面例子也可以看出若使用torch.save()
来进行模型参数的保存,那保存文件的后缀其实没有任何影响,.pkl 文件和 .pth 文件一模一样。
实际上,这两种格式的文件还是有区别的。
首先介绍 .pkl 文件,它若直接打开会显示一堆序列化的东西,以二进制形式存储的。如果去 read 这些文件,需要用'rb'
而不是'r'
模式。
import pickle as pkl
file = os.path.join('annot',model.pkl) # 打开pkl文件
with open(file, 'rb') as anno_file:
result = pkl.load(anno_file)
或者:
import pickle as pkl
file = os.path.join('annot',model.pkl) # 打开pkl文件
anno_file = open(file, 'rb')
result = pkl.load(anno_file)
import torch
filename = r'E:\anaconda\model.pth' # 字符串前面加r,表示的意思是禁止字符串转义
model = torch.load(filename)
print(model)
但其实不管pkl文件还是pth文件,都是以二进制形式存储的,没有本质上的区别,你用pickle这个库去加载pkl文件或pth文件,效果都是一样的。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。