赞
踩
torch.save()使用python的pickle模块把目标保存到磁盘,可以用来保存模型、张量、字典等,文件后缀名一般用pth或pt或pkl。torch.load()使用python的pickle模块实现从磁盘加载。可以用此来直接保存或加载完整模型:
torch.save(model, 'PATH.pth')
model = torch.load('PATH.pth')
注意:pytorch1.6以后保存的模型使用zip压缩,所以保存的模型无法被1.6以前的版本加载,如果要跨版本使用,需要做以下修改
torch.save(model, 'PATH.pth', _use_new_zipfile_serialization=False)
模型的框架已经在程序代码中了,因此训练好的模型只需要保存模型的参数即可供推理使用。model.state_dict()以字典的形式保存模型的参数,字典的键是参数名,值是参数值的张量。得到状态字典后还需用torch.save()固化到磁盘。
除模型外,优化器optimizer也可以保存和加载状态字典。
torch.save(model.state_dict(), 'PATH.pth')
model.load_state_dict(torch.load('PATH.pth'))
注意在多卡GPU训练时,保存和加载模型需要在model后加上module,即
torch.save(model.module.state_dict(), 'PATH.pth')
model.module.load_state_dict(torch.load('PATH.pth'))
如果是训练中途保存用于继续训练,就不仅要保存权重参数,还要保存当前epoch,优化器的状态,当前的损失值等,可以统一打包到一个字典中保存为checkpoint,此时文件后缀名一般用tar。
#保存:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
##加载:
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。