当前位置:   article > 正文

Pytorch 保存加载模型时的坑_model.module.state_dict()和model.state_dict()的区别

model.module.state_dict()和model.state_dict()的区别

Pytorch 保存加载模型时的坑

在说Pytorch保存加载模型时的坑之前,先介绍一下pytorch对训练好的模型如何进行保存和加载。

方法1:保存模型的参数和结构信息

保存:

model=MobileNetV2(n_class=2)#加载模型
############进行训练##########
 model = torch.nn.DataParallel(model, device_ids=[int(i) for i in args.gpus.strip().split(',')])#用多gpus 训练×××关键
############进行训练##########
torch.save(model, os.path.join(args.save_path, "epoch_" + str(epoch) + ".pth.tar"))#保存模型
  • 1
  • 2
  • 3
  • 4
  • 5

恢复:

model=torch.load(args.load_path)#
  • 1

这种方法会出现一个问题:当利用pytorch 1.0.0 保存好了模型后,加载时利用pytorch1.1.0 进行load() 时回报错,所以官方推荐使用第二种方法进行加载

方法二:官方推荐的方法,只保存和恢复模型中的参数

一个完整的例子:
迁移学习加载模型(此时 checkpoint 字典只有 state_dict ):

model=MobileNetV2(n_class=2)#加载模型结构
model_dict =  model.state_dict()#获取模型参数(未加载保存的模型参数 )
if args.resume:#模型路径
    if os.path.isfile(args.resume):
        print(("=> loading checkpoint '{}'".format(args.resume)))
        checkpoint = torch.load(args.resume)#获取模型参数
         #因为我修改网络模型进行迁移学习,这一步是在checkpoint里获取没有修改的模型参数state_dict
        state_dict = {k: v for k, v in checkpoint.items() if k in model_dict.keys()}
        model_dict.update(state_dict)#更新已经保存的参数至model_dict
        model.load_state_dict(model_dict)#加载模型参数
    else:
        print(("=> no checkpoint found at '{}'".format(args.resume)))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

保存:–这里有坑

torch.save({"epoch":epoch, #一共训练的epoch
                   "model_state_dict":model.module.state_dict(), #保存模型参数×××××这里埋个坑××××
                   'epoch_acc': epoch_acc, #一共训练的epoch
                   "optimizer":optimizer.state_dict() }#优化器好像也在保存,这样可以继续加载模型进行训练
                   ,os.path.join(args.save_path,"checkpoints_epoch_" + str(epoch) + ".tar"))      
  • 1
  • 2
  • 3
  • 4
  • 5

再加载:

print("start loading cls model")
model=MobileNetV2(n_class=2)
if os.path.isfile(args.load_path):
    state_dict=torch.load(args.load_path)
    print(state_dict['epoch'])#获取保存的参数 对应key值的参数
    print(state_dict['epoch_acc'])
    params=state_dict["model_state_dict"] 
    for param_tensor in params:#打印参数信息
         print(param_tensor,"\t",params[param_tensor].size())
    model.load_state_dict(params)
    print("load cls model successfully")             
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

填坑

这段保存模型参数的代码

torch.save({"epoch":epoch, #一共训练的epoch
                   "model_state_dict":model.module.state_dict(), #保存模型参数
                   'epoch_acc': epoch_acc, #一共训练的epoch
                   "optimizer":optimizer.state_dict() }#优化器好像也在保存,这样可以继续加载模型进行训练
                   ,os.path.join(args.save_path,"checkpoints_epoch_" + str(epoch) + ".tar")) 

torch.save({"epoch":epoch, #一共训练的epoch
                   "model_state_dict":model.state_dict(), #保存模型参数
                   'epoch_acc': epoch_acc, #一共训练的epoch
                   "optimizer":optimizer.state_dict() }#优化器好像也在保存,这样可以继续加载模型进行训练
                   ,os.path.join(args.save_path,"checkpoints_epoch_" + str(epoch) + ".tar")) 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

与这段的不同在于model.module.state_dict()与model.state_dict()的区别
现在来打印一下

model=MobileNetV2(n_class=2)#加载模型结构
model_dict =  model.state_dict()#获取模型参数(未加载保存的模型参数 )
model_dict----------model.module.state_dict()---------model.state_dict()三者参数的对应的名称(这里只打印几个)
model_dict:
features.0.0.weight 	 torch.Size([32, 3, 3, 3])
features.0.1.weight 	 torch.Size([32])
features.0.1.bias 	 torch.Size([32])
features.0.1.running_mean 	 torch.Size([32])
features.0.1.running_var 	 torch.Size([32])
features.0.1.num_batches_tracked 	 torch.Size([])

model.module.state_dict():
features.0.0.weight 	 torch.Size([32, 3, 3, 3])
features.0.1.weight 	 torch.Size([32])
features.0.1.bias 	 torch.Size([32])
features.0.1.running_mean 	 torch.Size([32])
features.0.1.running_var 	 torch.Size([32])
features.0.1.num_batches_tracked 	 torch.Size([])

model.state_dict():
module.features.0.0.weight 	 torch.Size([32, 3, 3, 3])
module.features.0.1.weight 	 torch.Size([32])
module.features.0.1.bias 	 torch.Size([32])
module.features.0.1.running_mean 	 torch.Size([32])
module.features.0.1.running_var 	 torch.Size([32])
module.features.0.1.num_batches_tracked 	 torch.Size([])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

用多gpus进行训练后直接用model.state_dict()进行保存的模型,每个层参数的名称前面会加上module,这时候再用单卡 gpu model_dict加载model.state_dict()参数时会出现名称不匹配的情况。
因此保存模型时注意使用model.module.state_dict():

总结

1.多gpus训练 用model.state_dict() 保存前面会加上网络参数名称前会加上 module
2.单gpus加载模型,需要去掉网络参数名称前加上的module
两种方法:
(1) 用model.module.state_dict()保存
(2) 去掉网络参数名称前会加上的module再加载模型
3.推荐多gpus训练使用model.module.state_dict()保存,然后单gpu加载,
此时如果还需要多gpu训练可以在加载模型参数后使用torch.nn.DataParallel进行训练

还有另外的思路可参考 @[参考这里][参考这里]多gpu训练(https://blog.csdn.net/CV_YOU/article/details/86670188)(https://blog.csdn.net/qq_32998593/article/details/89343507)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/103447
推荐阅读
相关标签
  

闽ICP备14008679号