当前位置:   article > 正文

pytorch中保存网络模型的两种方式_python model save 存储网络结构

python model save 存储网络结构

一、只保存网络中的参数

保存:

torch.save(model.state_dict(), save_fp)

加载的时候需要先初始化一个模型,然后把文件中的参数恢复。

  1. train_weights = torch.load(model_fp)
  2. model = Model()
  3. model.load_state_dict(model_weights)

这里load得到的是变量类型为OrderedDict(),也就是网络中的参数集合。

二、保存网络结构和参数

保存:

torch.save(model, save_fp)

加载:

model = torch.load(model_fp)

这里load的到的一个对象,类型是<class Model>

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/103362
推荐阅读
相关标签
  

闽ICP备14008679号