当前位置:   article > 正文

使用PyTorch冻结模型参数的方法_pytorch冻结一部分参数

pytorch冻结一部分参数

前言

深度学习领域,经常需要使用其他人已训练好的模型进行改进或微调,这个时候我们会加载已有的预训练模型文件的参数,如果网络结构不变,希望使用新数据微调部分网络参数。这时我们则需要冻结部分参数,禁止其更新。

在这里插入图片描述

方法

(1)通过遍历网络结构,设置梯度更新requires_grad = False。

 # 冻结network1的全部参数和network2的部分参数
 for name, parameter in network1.named_parameters():
     parameter.requires_grad = False

 for name, parameter in network2.named_parameters():
     if 'key' in name:
         parameter.requires_grad = False
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

(2)优化器中过滤filter冻结的参数

optimizer_network2 = torch.optim.Adam(filter(lambda p: p.requires_grad, network2.parameters()), lr=0.005, betas=(0.5, 0.999))
  • 1

其他

结合加载模型部分参数的情况,优化器需要按如下设置:

   optimizer_network2 = torch.optim.Adam([{'params': filter(lambda p: p.requires_grad, network2.parameters()), 'initial_lr': 0.0002}], lr=0.005, betas=(0.5, 0.999))
  • 1

在这里插入图片描述

参考资料

[1] csdn - 使用PyTorch加载模型部分参数方法
[2] 知乎 - Pytorch自由载入部分模型参数并冻结

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号