当前位置:   article > 正文

pytorch 迁移学习加载预训练模型,并进行修改_基于模型的迁移学习加载预训练模型torch

基于模型的迁移学习加载预训练模型torch

在迁移学习中,要加载预训练模型,如果是torch内置的一些模型网上有很多的方法很简单,但是当加载自己训练完成的模型以后,如何解决最后连接层输入维度不一致的问题,看了好几个帖子都不成功,最后看了一下,既然是将参数param转换成dict,不如直接将不需要的删掉即可。
于是可以用如下方法

	pretrained_params = torch.load(path)
    net = YOUR_OWN_MODEL()
    state_dict = pretrained_params.state_dict()
    del state_dict[xxx] #需要被删掉的参数
    net.load_state_dict(state_dict, strict=False)
  • 1
  • 2
  • 3
  • 4
  • 5

简单有效!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/217609
推荐阅读
相关标签
  

闽ICP备14008679号