赞
踩
PyTorch模型加载的时候,有预训练模型,通过使用预训练模型可以给模型使用带来很多的便捷,对于模型的使用以下给出了一些总结,如有错误恳请指正。
一、直接加载预训练模型进行训练
1、加载保存的整个模型
torch.save(model,'model.pkl')
...
model = torch.load('model.pkl')
2、加载保存的模型参数
torch.save(model.state_dict(),'model_state_dict.pkl')
...
model.load_state_dict(torch.load('model_state_dict.pkl'))
关于模型的保存和加载,可以详细参照我的这篇文章:HUST小菜鸡:Pytorch搭建简单神经网络(三)——快速搭建、保存与提取zhuanlan.zhihu.com
通过对模型参数的保存的解析,我们可以深入的了解
load_dict = torch.load('models/cifar10_statedict.pkl')
print(load_dict.keys())
print(type(load_dict))
输出的结果如下所示:
odict_keys(['conv1.0.weight', 'conv1.0.bias', 'conv2.0.weight', 'conv2.0.bias', 'conv3.0.weight', 'conv3.0.bias', 'conv4.0.weight', 'conv4.0.bias', 'conv5.0.weight', 'conv5.0.bias', 'conv6.0.weight', 'conv6.0.bias', 'classifier.1.weight', 'classifier.1.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.5.weight', 'classifier.5.bias'])
可以看出保存的state_dict其实是一个collections.OrderedDict的Object,和普通的dict不同的是,该类别是有着严格的顺序,而dict中的元素是没有严格的顺序。
但是有一个问题值得深入考量——两个网络的结构是一样的,但是结构的命名是不一样的,那么对于这种模型的加载,如果不一样的话会出现报错,该如何解决
参照以上结果的输出,state_dict中key就是网络结构的名称,所以当网络结构一样的时候,只需要修改索引key,就可以解决以上的问题,至于如何修改可以参照如下方式:https://stackoverflow.com/questions/12150872/change-key-in-ordereddict-without-losing-orderstackoverflow.com
二、加载部分预训练模型
我们经常对现有的经典网络进行如下操作,我们不修改网络的主体部分,我们只修改网络的输出,或者在最后加上一些网络层来达到我们想要的输出结果,虽然很难保证网络模型和某些公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。
model = cifar10_cnn.CIFAR10_Nettest()
pretrained_dict = torch.load('models/cifar10_statedict.pkl')
model_dict = model.state_dict()
print('随机初始化权重第一层:',model_dict['conv1.0.weight'])
# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
print('预训练权重第一层:',pretrained_dict['conv1.0.weight'])
# 更新现有的model_dict
model_dict.update(pretrained_dict) #利用预训练模型的参数,更新模型
model.load_state_dict(model_dict)
print('更新后权重第一层:',model_dict['conv1.0.weight'])
输出的部分结果如下所示,为了直观显示我只截取了中间的某一部分
随机初始化权重第一层: tensor([[[[ 0.0142, 0.1039, 0.1260],
[ 0.1805, -0.0533, 0.0007],
[-0.1032, -0.1039, -0.0633]],
[[ 0.0714, -0.0053, 0.0059],
[-0.0528, 0.0438, -0.1108],
[ 0.0544, 0.0157, 0.1265]],
预训练权重第一层: tensor([[[[ 8.0685e-02, -3.8643e-02, 3.4450e-02],
[-2.3942e-01, -1.5474e-01, 1.3142e-01],
[-9.4602e-02, 6.4120e-02, -9.4336e-02]],
[[ 9.7318e-02, 1.0526e-01, 2.3400e-03],
[-5.8471e-02, -8.8146e-02, -1.6053e-01],
[-1.0788e-01, -5.9083e-02, -9.0651e-02]],
更新后权重第一层: tensor([[[[ 8.0685e-02, -3.8643e-02, 3.4450e-02],
[-2.3942e-01,
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。