赞
踩
torch.nn.Module 类中提供了将模型参数的作为字典映射保存和加载的方法
因此,可以在训练过程中对最优模型的参数进行保存
1、state_dict()
返回一个包含 Module 实例完整状态的字典
包括参数和缓冲区,字典的键值是参数或缓冲区的名称
2、load_state_dict(state_dict, strict=True)
从 state_dict 中复制参数和缓冲区到 Module 及其子类中
state_dict:包含参数和缓冲区的 Module 状态字典
strict:默认 True,是否严格匹配 state_dict 的键值和 Module.state_dict()的键值
假设定义了一个CNN模型:
- # Save model state
- best_model_weights = copy.deepcopy(cnn_model.state_dict())
-
- # Load model state
- cnn_model.load_state_dict(best_model_weights)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。