当前位置:   article > 正文

PyTorch载入预训练权重方法和冻结权重方法_xception的预训练权重pytorch

xception的预训练权重pytorch

载入预训练权重

1. 直接载入预训练权重

简单粗暴法:

pretrain_weights_path = "./resnet50.pth"
net.load_state_dict(torch.load(pretrain_weights_path))
  • 1
  • 2

如果这里的pretrain_weights与我们训练的网络不同,一般指的是包含大于模型参数时,可以修改为

net.load_state_dict(torch.load(pretrain_weights_path), strict=False)
  • 1

2. 修改网络结构

常用方法1:

model_weight_path = "resnet34pre.pth"
net.load_state_dict(torch.load(model_weight_path))
# 这里假设最后一层为FC层,使用迁移学习,将分类结果修改
# net是实例化的resnet网络,in_features是网络输入结构参数,最后的5是修改的输出参数
inchannel = net.fc.in_features
net.fc = nn.Linear(inchannel, 5)
# 注意,最后去转换我们的模型设备,否则可能会报错,怀疑是修改的模型部分和原模型部分使用的设备不同
net.to(device) 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

常用方法2:

net = MobileNetV2(num_class=5)
net_weights = net.state_dict()
model_weights_path = "./mobilenet_v2pre.pth"
pre_weights = torch.load(model_weights_path)
# delete classifier weights
# 这种方法主要是遍历字典,.pth文件(权重文件)的本质就是字典的存储
# 通过改变我们载入的权重的键值对,可以和当前的网络进行配对的
# 这里举到的例子是对"classifier"结构层的键值对剔除,或者说是不载入该模块的训练权重
pre_dict = {k: v for k, v in pre_weights.items() if "classifier" not in k}
# 另一种方法会直接两种权重对比,直接两种方法对比,减少问题的存在
pre_dict = {k: v for k, v in pre_weight.items() 
			if net_weights[k].numel() == v.numel()}
# 如果修改了载入权重或载入权重的结构和当前模型的结构不完全相同,需要加strict=False,保证能够权重载入
net.load_state_dict(pre_dict, strict=False)
net.to(device)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

灵活提升:

net = resnet34(num_classes=5)
pre_weights = torch.load(model_weight_path, map_location=device)
del_key = []
for key, _ in pre_weights.items():
        if "fc" in key: # 这里可以多加一下字段,比如  or "layer4" in key:
            del_key.append(key)
# missing_keys表示net中的部分权重未出现在pre_weights中
# unexpected_keys表示pre_weights当中有一部权重不在net中
missing_keys, unexpected_keys = net.load_state_dict(del_key, strict=False)
# 执行结果
# [missing_keys]:
# fc.weight
# fc.bias
# [unexpected_keys]:
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

载入权重常见问题

  1. key-val不匹配问题
    解决方式:模型结构修改了,没有正确修改预训练权重,导致载入权重与模型不同,使用上文中的方法适当修改载入权重即可
  2. 载入预训练权重param名称和模型中的param名称不同,导致载入失败
    解决方法:修改模型中的para名称,或者修改网络中模块的名称。

冻结训练

冻结训练方法很简单,只要对requires_grad = False即可

for name, para in model.named_parameters():
    # 除最后的全连接层外,其他权重全部冻结
    if "fc" not in name:
        para.requires_grad_(False)
        # 或者 para.requires_grad = False
  • 1
  • 2
  • 3
  • 4
  • 5

还有一个小建议,将model中需要反向梯度传播的param单独list出来

pg = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=0.005)
  • 1
  • 2
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/369281
推荐阅读
相关标签
  

闽ICP备14008679号