赞
踩
Pytorch会把模型相关信息保存为一个字典结构的数据,以用于继续训练或者推理。
常见的模型保存方式(PyTorch的模型一般以.pt或者.pth文件格式保存。)
# 保存方式1,模型结构+参数
torch.save(model,"xxx.pth") #"xxx.pth"可为PATH
# 保存方式2,模型参数(官方推荐)
torch.save(model.state_dict(),"xxx.pth")
#把model的状态保存成**字典**格式,只保存网络模型参数,对比较大的模型占用空间小
除了模型参数之外,torch还可以保存其他训练相关参数,例如学习率、优化器信息等如:
torch.save({'model':model.state_dict(),
'optimizer':optimizer.state_dict(),
'epoch':epoch_num
"global_step": step
},'xxx.pth') #'xxx.pth'也可以是PATH
也可以将字典单拎出来分开
state = {'model':model.state_dict(),
'optimizer':optimizer.state_dict(),
'epoch':epoch_num
"global_step": step
}
torch.save(state, 'xxx.pth')
#或者
state_path = "./xx/xxx.pth"
torch.save(state, state_path)
基本的加载方式如下:
# 加载单个数据字典
model.load_state_dict(torch.load("xxx.pth")) #"xxx.pth"可为PATH
加载用于推理的常规Checkpoint常包含其他训练相关参数,例如学习率、优化器信息等如:
#若保存方式为:
torch.save({'model':model.state_dict(),
'optimizer':optimizer.state_dict(),
'epoch':epoch_num
"global_step": step
}, checkpoint_path) #checkpoint_path可为文件路径如:xx/xxx.pth
##加载方式可以为:
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
global_step = checkpoint["global_step"]
此外,load提供了很多重载的功能,其可以把在GPU上训练的权重加载到CPU上跑:
torch.load('tensors.pt')
# 强制所有GPU张量加载到CPU中
torch.load('tensors.pt', map_location=lambda storage, loc: storage) #或者model.load_state_dict(torch.load('model.pth', map_location='cpu'))
# 把所有的张量加载到GPU 1中
torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
# 把张量从GPU 1 移动到 GPU 0
torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。