当前位置:   article > 正文

【pytorch】网络模型的保存与读取_torch.save第一种保存方法读取时候要网络

torch.save第一种保存方法读取时候要网络

Pytorch会把模型相关信息保存为一个字典结构的数据,以用于继续训练或者推理。

1. 模型保存

常见的模型保存方式(PyTorch的模型一般以.pt或者.pth文件格式保存。)

# 保存方式1,模型结构+参数
torch.save(model,"xxx.pth") #"xxx.pth"可为PATH

# 保存方式2,模型参数(官方推荐)
torch.save(model.state_dict(),"xxx.pth") 
#把model的状态保存成**字典**格式,只保存网络模型参数,对比较大的模型占用空间小
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

除了模型参数之外,torch还可以保存其他训练相关参数,例如学习率、优化器信息等如:

torch.save({'model':model.state_dict(),
			 'optimizer':optimizer.state_dict(),
			 'epoch':epoch_num
			 "global_step": step
			 },'xxx.pth') #'xxx.pth'也可以是PATH
  • 1
  • 2
  • 3
  • 4
  • 5

也可以将字典单拎出来分开

state = {'model':model.state_dict(),
		'optimizer':optimizer.state_dict(),
		'epoch':epoch_num
		"global_step": step
			 }
torch.save(state, 'xxx.pth')
#或者
state_path = "./xx/xxx.pth"
torch.save(state, state_path)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

2.模型加载

基本的加载方式如下:

# 加载单个数据字典
model.load_state_dict(torch.load("xxx.pth")) #"xxx.pth"可为PATH
  • 1
  • 2

加载用于推理的常规Checkpoint常包含其他训练相关参数,例如学习率、优化器信息等如:

#若保存方式为:
torch.save({'model':model.state_dict(),
			 'optimizer':optimizer.state_dict(),
			 'epoch':epoch_num
			 "global_step": step
			 }, checkpoint_path) #checkpoint_path可为文件路径如:xx/xxx.pth
##加载方式可以为:
checkpoint = torch.load(checkpoint_path)

model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
global_step = checkpoint["global_step"]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

此外,load提供了很多重载的功能,其可以把在GPU上训练的权重加载到CPU上跑:

torch.load('tensors.pt')
# 强制所有GPU张量加载到CPU中
torch.load('tensors.pt', map_location=lambda storage, loc: storage)  #或者model.load_state_dict(torch.load('model.pth', map_location='cpu'))
# 把所有的张量加载到GPU 1中
torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
# 把张量从GPU 1 移动到 GPU 0
torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

参考:pytorch模型的保存和加载、checkpoint

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/103383
推荐阅读
相关标签
  

闽ICP备14008679号