当前位置:   article > 正文

pytorch 断点训练,从指定epoch恢复训练_pytorch训练中断后怎么恢复

pytorch训练中断后怎么恢复

1、保存模型

保存整个模型

  1. torch.save(net, path)

保存权重

  1. state_dict = net.state_dict()
  2. torch.save(state_dict , path)

2、模型训练过程保存

  1. checkpoint = {
  2. "net": model.state_dict(),
  3. 'optimizer':optimizer.state_dict(),
  4. "epoch": epoch
  5. }

3、指定epoch恢复

  1. path_checkpoint = "./models/checkpoint/ckpt_best_1.pth" # 断点路径
  2. checkpoint = torch.load(path_checkpoint) # 加载断点
  3. model.load_state_dict(checkpoint['net']) # 加载模型可学习参数
  4. optimizer.load_state_dict(checkpoint['optimizer']) # 加载优化器参数
  5. start_epoch = checkpoint['epoch'] # 设置开始的epoch

4、完整流程

  1. start_epoch = -1
  2. if RESUME:
  3. path_checkpoint = "./models/checkpoint/ckpt_best_1.pth" # 断点路径
  4. checkpoint = torch.load(path_chec
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/998546
推荐阅读
相关标签
  

闽ICP备14008679号