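1. Define two flags as parameters:
resume: whether to resume training from a checkpoint
initepoch: the epoch to resume from, taken from the checkpoint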
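In practice `resume` is usually exposed as a command-line switch, while `initepoch` is computed from the checkpoint in step 2 rather than passed in. The sketch below assumes `argparse`; the `--resume` flag name and the `./checkpoints` default are hypothetical, and only `save_model_path` comes from the code in this post.

```python
import argparse


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Training with optional resume")
    # Hypothetical flag: store_true makes it False unless --resume is given
    parser.add_argument("--resume", action="store_true",
                        help="resume training from the latest checkpoint")
    # Directory holding latest_dice_loss.pth (default path is an assumption)
    parser.add_argument("--save_model_path", default="./checkpoints",
                        help="directory where checkpoints are saved")
    return parser.parse_args(argv)
```

For example, `python train.py --resume --save_model_path ./checkpoints` sets `args.resume` to `True`.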
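2. Load the checkpoint:
```python
import os
import torch

# args, model and optimizer come from the training script; model is assumed to be
# wrapped in nn.DataParallel/DDP, hence model.module
resume = True
initepoch = 0
checkpoint_path = os.path.join(args.save_model_path, 'latest_dice_loss.pth')
if resume:
    # Only resume if the checkpoint file actually exists
    if os.path.isfile(checkpoint_path):
        print("Resume from checkpoint...")
        # Load onto CPU; load_state_dict copies values to the devices already in use
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.module.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Assuming the checkpoint is saved after an epoch completes, start with the next one
        initepoch = checkpoint['epoch'] + 1
        print("====> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
    else:
        print("====> no checkpoint found.")
        initepoch = 0
```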
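The load code goes through `model.module` because the model is wrapped (for example in `nn.DataParallel`), and the wrapper adds a `module.` prefix to every state-dict key. If the same script may also run on an unwrapped model, a small helper avoids an `AttributeError`. `unwrap_model` below is a hypothetical name, a sketch rather than part of PyTorch.

```python
import torch.nn as nn


def unwrap_model(model):
    # DataParallel and DistributedDataParallel keep the real network in .module;
    # any other nn.Module is returned unchanged
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    return model
```

With it, the load line becomes `unwrap_model(model).load_state_dict(checkpoint['model_state_dict'])`, which works whether or not the model is wrapped.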
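3. Save a checkpoint containing the model state, the optimizer state and the current epoch:
```python
import os
import torch

# Typically run at the end of each epoch; model.module strips the DataParallel/DDP
# wrapper so the saved keys have no "module." prefix
checkpoint = {"model_state_dict": model.module.state_dict(),
              "optimizer_state_dict": optimizer.state_dict(),
              "epoch": epoch}
checkpoint_path = os.path.join(args.save_model_path, 'latest_dice_loss.pth')
torch.save(checkpoint, checkpoint_path)
```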
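To check the whole save/resume flow end to end, here is a minimal sketch on a toy problem. `save_checkpoint`, `load_checkpoint` and `train` are hypothetical helpers, the unwrapped `nn.Linear` model and random data are placeholders, and checkpoints are assumed to be written after each epoch. SGD with momentum is used so that the optimizer carries state that must be restored too.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, epoch):
    # Same keys as the checkpoint dict above; no .module since the toy model is unwrapped
    torch.save({"model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch}, path)


def load_checkpoint(path, model, optimizer):
    # Returns the epoch to start from: 0 without a checkpoint, else saved epoch + 1
    if not os.path.isfile(path):
        return 0
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"] + 1


def train(path, total_epochs, stop_after=None):
    torch.manual_seed(0)  # same initial weights and data on every call
    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    start = load_checkpoint(path, model, optimizer)
    for epoch in range(start, total_epochs):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        save_checkpoint(path, model, optimizer, epoch)
        if stop_after is not None and epoch == stop_after:
            break  # simulate an interruption
    return start, model


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        path = os.path.join(d, "latest_dice_loss.pth")
        train(path, total_epochs=4, stop_after=1)  # interrupted after epoch 1
        start, _ = train(path, total_epochs=4)     # resumes from the checkpoint
        print("resumed from epoch", start)         # prints: resumed from epoch 2
```

An interrupted run that resumes this way should end with the same weights as an uninterrupted one. If the optimizer state were not restored, the momentum buffers would restart from zero and the two runs would diverge.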