赞
踩
在深度学习领域,经常需要使用其他人已训练好的模型进行改进或微调,这个时候我们会加载已有的预训练模型文件的参数,如果网络结构不变,希望使用新数据微调部分网络参数。这时我们则需要冻结部分参数,禁止其更新。
(1)通过遍历网络结构,设置梯度更新requires_grad = False。
# 冻结network1的全部参数和network2的部分参数
for name, parameter in network1.named_parameters():
parameter.requires_grad = False
for name, parameter in network2.named_parameters():
if 'key' in name:
parameter.requires_grad = False
(2)优化器中过滤filter
冻结的参数
optimizer_network2 = torch.optim.Adam(filter(lambda p: p.requires_grad, network2.parameters()), lr=0.005, betas=(0.5, 0.999))
结合加载模型部分参数的情况,优化器需要按如下设置:
optimizer_network2 = torch.optim.Adam([{'params': filter(lambda p: p.requires_grad, network2.parameters()), 'initial_lr': 0.0002}], lr=0.005, betas=(0.5, 0.999))
[1] csdn - 使用PyTorch加载模型部分参数方法
[2] 知乎 - Pytorch自由载入部分模型参数并冻结
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。