当前位置:   article > 正文

Ptorch深度学习-------断点续训_深度学习resume load

深度学习resume load

1、参数设置一个标志:

        resume:是否进行续训

        initepoch:断点位置的epoch
2、checkpoint载入:

  1. resume = True
  2. initepoch = 0
  3. if resume:
  4. if os.path.isfile(os.path.join(args.save_model_path, 'latest_dice_loss.pth')) or os.path.join(
  5. args.save_model_path, 'latest_dice_loss.pth'):
  6. print("Resume from checkpoint...")
  7. checkpoint = torch.load(os.path.join(args.save_model_path, 'latest_dice_loss.pth'))
  8. model.module.load_state_dict(checkpoint['model_state_dict'])
  9. optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  10. initepoch = checkpoint['epoch']
  11. print("====>loaded checkpoint (epoch{})".format(checkpoint['epoch']))
  12. else:
  13. print("====>no checkpoint found.")
  14. initepoch = 0

3、保存checkpoint包含的参数

  1. checkpoint = {"model_state_dict": model.module.state_dict(),
  2. "optimizer_state_dict": optimizer.state_dict(),
  3. "epoch": epoch}
  4. checkpoint_path = os.path.join(args.save_model_path, 'latest_dice_loss.pth')
  5. torch.save(checkpoint, checkpoint_path)

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
  

闽ICP备14008679号