赞
踩
一、只保存网络中的参数
保存:
torch.save(model.state_dict(), save_fp)
加载的时候需要先初始化一个模型,然后把文件中的参数恢复。
- train_weights = torch.load(model_fp)
- model = Model()
- model.load_state_dict(model_weights)
这里load得到的是变量类型为OrderedDict(),也就是网络中的参数集合。
二、保存网络结构和参数
保存:
torch.save(model, save_fp)
加载:
model = torch.load(model_fp)
这里load的到的一个对象,类型是<class Model>
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。