当前位置:   article > 正文

pytorch基础知识整理(三)模型保存与加载_model.module.state_dict()

model.module.state_dict()

1, torch.save(); troch.load()

torch.save()使用python的pickle模块把目标保存到磁盘,可以用来保存模型、张量、字典等,文件后缀名一般用pth或pt或pkl。torch.load()使用python的pickle模块实现从磁盘加载。可以用此来直接保存或加载完整模型:

torch.save(model, 'PATH.pth')
model = torch.load('PATH.pth')
  • 1
  • 2

注意:pytorch1.6以后保存的模型使用zip压缩,所以保存的模型无法被1.6以前的版本加载,如果要跨版本使用,需要做以下修改

torch.save(model, 'PATH.pth', _use_new_zipfile_serialization=False)
  • 1

2, .state_dict(); .load_state_dict()

模型的框架已经在程序代码中了,因此训练好的模型只需要保存模型的参数即可供推理使用。model.state_dict()以字典的形式保存模型的参数,字典的键是参数名,值是参数值的张量。得到状态字典后还需用torch.save()固化到磁盘。
除模型外,优化器optimizer也可以保存和加载状态字典。

torch.save(model.state_dict(), 'PATH.pth')
model.load_state_dict(torch.load('PATH.pth'))
  • 1
  • 2

注意在多卡GPU训练时,保存和加载模型需要在model后加上module,即

torch.save(model.module.state_dict(), 'PATH.pth')
model.module.load_state_dict(torch.load('PATH.pth'))
  • 1
  • 2

3, 保存checkpoint

如果是训练中途保存用于继续训练,就不仅要保存权重参数,还要保存当前epoch,优化器的状态,当前的损失值等,可以统一打包到一个字典中保存为checkpoint,此时文件后缀名一般用tar。

#保存:
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)
##加载:
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/103462
推荐阅读
相关标签
  

闽ICP备14008679号