当前位置:   article > 正文

python怎么使用预训练的模型_PyTorch使用预训练模型

pytorch 进行预训练时怎么使用自己构建的模型

PyTorch模型加载的时候,有预训练模型,通过使用预训练模型可以给模型使用带来很多的便捷,对于模型的使用以下给出了一些总结,如有错误恳请指正。

一、直接加载预训练模型进行训练

1、加载保存的整个模型

torch.save(model,'model.pkl')

...

model = torch.load('model.pkl')

2、加载保存的模型参数

torch.save(model.state_dict(),'model_state_dict.pkl')

...

model.load_state_dict(torch.load('model_state_dict.pkl'))

关于模型的保存和加载,可以详细参照我的这篇文章:HUST小菜鸡:Pytorch搭建简单神经网络(三)——快速搭建、保存与提取​zhuanlan.zhihu.comv2-5aed9b4858ee329f1de0b9d5ff33ce4a_180x120.jpg

通过对模型参数的保存的解析,我们可以深入的了解

load_dict = torch.load('models/cifar10_statedict.pkl')

print(load_dict.keys())

print(type(load_dict))

输出的结果如下所示:

odict_keys(['conv1.0.weight', 'conv1.0.bias', 'conv2.0.weight', 'conv2.0.bias', 'conv3.0.weight', 'conv3.0.bias', 'conv4.0.weight', 'conv4.0.bias', 'conv5.0.weight', 'conv5.0.bias', 'conv6.0.weight', 'conv6.0.bias', 'classifier.1.weight', 'classifier.1.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.5.weight', 'classifier.5.bias'])

可以看出保存的state_dict其实是一个collections.OrderedDict的Object,和普通的dict不同的是,该类别是有着严格的顺序,而dict中的元素是没有严格的顺序。

但是有一个问题值得深入考量——两个网络的结构是一样的,但是结构的命名是不一样的,那么对于这种模型的加载,如果不一样的话会出现报错,该如何解决

参照以上结果的输出,state_dict中key就是网络结构的名称,所以当网络结构一样的时候,只需要修改索引key,就可以解决以上的问题,至于如何修改可以参照如下方式:https://stackoverflow.com/questions/12150872/change-key-in-ordereddict-without-losing-order​stackoverflow.com

二、加载部分预训练模型

我们经常对现有的经典网络进行如下操作,我们不修改网络的主体部分,我们只修改网络的输出,或者在最后加上一些网络层来达到我们想要的输出结果,虽然很难保证网络模型和某些公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。

model = cifar10_cnn.CIFAR10_Nettest()

pretrained_dict = torch.load('models/cifar10_statedict.pkl')

model_dict = model.state_dict()

print('随机初始化权重第一层:',model_dict['conv1.0.weight'])

# 将pretrained_dict里不属于model_dict的键剔除掉

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

print('预训练权重第一层:',pretrained_dict['conv1.0.weight'])

# 更新现有的model_dict

model_dict.update(pretrained_dict) #利用预训练模型的参数,更新模型

model.load_state_dict(model_dict)

print('更新后权重第一层:',model_dict['conv1.0.weight'])

输出的部分结果如下所示,为了直观显示我只截取了中间的某一部分

随机初始化权重第一层: tensor([[[[ 0.0142, 0.1039, 0.1260],

[ 0.1805, -0.0533, 0.0007],

[-0.1032, -0.1039, -0.0633]],

[[ 0.0714, -0.0053, 0.0059],

[-0.0528, 0.0438, -0.1108],

[ 0.0544, 0.0157, 0.1265]],

预训练权重第一层: tensor([[[[ 8.0685e-02, -3.8643e-02, 3.4450e-02],

[-2.3942e-01, -1.5474e-01, 1.3142e-01],

[-9.4602e-02, 6.4120e-02, -9.4336e-02]],

[[ 9.7318e-02, 1.0526e-01, 2.3400e-03],

[-5.8471e-02, -8.8146e-02, -1.6053e-01],

[-1.0788e-01, -5.9083e-02, -9.0651e-02]],

更新后权重第一层: tensor([[[[ 8.0685e-02, -3.8643e-02, 3.4450e-02],

[-2.3942e-01,

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/124038?site
推荐阅读
相关标签
  

闽ICP备14008679号