赞
踩
在PyTorch中,torch.save()和torch.load()默认使用Python的pickle模块(二进制序列化和反序列化)进行操作,因此你也可以将多个张量保存为Python对象的一部分,如元组、列表和字典(实际上,凡是能pickle的数据结构,包括自建的数据结构,均可save和load),如:
#保存一个字典:
d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
torch.save(d, 'tensor_dict.pt')
torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}
#还可以保存一个列表:
numbers = torch.arange(1, 10)
evens = numbers[1::2]#[1::2]表示从第2个元素(‘1’)开始,按照步长为2的方式取数
torch.save([numbers, evens], 'tensors.pt')
loaded_numbers, loaded_evens = torch.load('tensors.pt')
loaded_evens *= 2
loaded_numbers
tensor([ 1, 4, 3, 8, 5, 12, 7, 16, 9])
class TheModelClass(nn.Module): def __init__(self): super(TheModelClass, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # Initialize model model = TheModelClass() # Initialize optimizer optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
输出相应的state_dict字典对象:
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
1.2 保存模型的参数state_dict()(推荐保存方式)
save:
torch.save(model.state_dict(), PATH)
load:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
注意: PyTorch1.6版本改变了torch.save()的保存格式,使用新zip文件格式,但torch.load仍然可以加载旧格式的文件。如果想要torch.save仍保存旧格式,只需要torch.save函数中,传递kwarg参数`_use_new_zipfile_serialization=False,即可。
需要特别强调的是,在运行推理之前,必须调用model.eval()将dropout和batch normalization层设置为评估模式,如果不这样做,将产生不一致的推理结果。所有的torch.load模型和模型都要这样执行。
注意: load_state_dict()函数接受一个字典对象,而不是保存对象的路径。这意味着在将保存的state_dict传递给load_state_dict()函数之前,必须反序列化(如使用torch.load(PATH))。例如,不能使用model.load_state_dict(PATH)进行加载。
注意: 如果只想保留性能最好的模型(根据获得的验证损失),必须序列化best_model_state或使用best_model_state = deepcopy(model.state_dict()),否则最好的best_model_state将在后续的训练迭代中不断更新。因此,模型的最终状态将是过拟合模型的状态。
1.3 保留整个模型
save:
torch.save(model, PATH)
load:
model = torch.load(PATH)
model.eval()
1.4 使用TorchScript格式保存和加载模型
TorchScript格式能使模型在Python中运行,也可以在c++等高性能环境中运行。
Export(导出为TorchScipt,然后保存):
model_scripted = torch.jit.script(model) # Export to TorchScript
model_scripted.save('model_scripted.pt') # Save
load:
model = torch.jit.load('model_scripted.pt')
model.eval()
1.6 保存和加载一个用于推理和/或恢复训练的一般检查点(CheckPoint)
这种保存方式很重要,比如在模型训练中,忽然通知要停电,那么这种保存方式就显得比较重要。
save:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
load:
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
model.train()##重新训练
当保存模型的checkpoint时,必须保存的不仅是模型的state_dict, 保存优化器的state_dict也很重要,因为它包含缓冲区和参数,在模型训练时更新。当然,你也可能想要保存其他项目包括你停止训练的时间、最新记录的训练损失、外部torch.nn.Embedding层等。因此,这样的checkpoint保存方式通常比单独的模型大2~3倍。
1.6 在一个文件夹中保存多个模型(在指定的checkpoint处)
save:
torch.save({
'modelA_state_dict': modelA.state_dict(),
'modelB_state_dict': modelB.state_dict(),
'optimizerA_state_dict': optimizerA.state_dict(),
'optimizerB_state_dict': optimizerB.state_dict(),
...
}, PATH)
load:
modelA = TheModelAClass(*args, **kwargs) modelB = TheModelBClass(*args, **kwargs) optimizerA = TheOptimizerAClass(*args, **kwargs) optimizerB = TheOptimizerBClass(*args, **kwargs) checkpoint = torch.load(PATH) modelA.load_state_dict(checkpoint['modelA_state_dict']) modelB.load_state_dict(checkpoint['modelB_state_dict']) optimizerA.load_state_dict(checkpoint['optimizerA_state_dict']) optimizerB.load_state_dict(checkpoint['optimizerB_state_dict']) modelA.eval() modelB.eval() # - or - modelA.train() modelB.train()
注:像GAN模型就需要这样保存,一个是生成器,一个是判别器。
1.7 使用来自不同模型的参数预热模型(必须是同构模型才可以)
save:
torch.save(modelA.state_dict(), PATH)
load:
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
在迁移学习中,这种保存方式很实用。在迁移学习或训练新的复杂模型时,部分加载模型或加载部分模型是常见的场景。利用已训练的参数,即使只有少量可用,也将有助于预热训练过程,并有望帮助您的模型以比从头开始训练更快的速度收敛。如果想从一层加载参数到另一层,但有些键不匹配,只需更改正在加载的state_dict中的参数键的名称,以匹配正在加载的模型中的键。
1.8 跨设备保存和加载模型—在GPU上保存,在CPU上加载
save:
torch.save(model.state_dict(), PATH)
load:
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
当在使用GPU训练的CPU上加载模型时,将torch.device(’ CPU ')传递给torch.load()函数中的map_location参数。在这种情况下,使用map_location参数将张量的底层存储动态重映射到CPU设备。
1.9 跨设备保存和加载模型—在GPU上保存,在GPU上加载
save:
torch.save(model.state_dict(), PATH)
load:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
当在GPU上训练并保存的GPU上加载模型时,只需使用model.to(torch.device(’ CUDA '))将初始化的模型转换为CUDA优化的模型。此外,请确保在所有模型输入上使用。to(torch.device(‘cuda’))函数来为模型准备数据。请注意,调用my_tensor.to(device)会在GPU上返回一个新的my_tensor副本。它不会覆盖my_tensor。因此请记住手动重写张量:
my_tensor = my_tensor.to(torch.device(‘cuda’))。
1.10 跨设备保存和加载模型—在CPU上保存,在GPU上加载
save:
torch.save(model.state_dict(), PATH)
load:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)
当在经过训练并保存在CPU上的GPU上加载模型时,将torch.load()函数中的map_location参数设置为cuda:device_id。这会将模型加载到给定的GPU设备。接下来,确保调用model.to(torch.device(‘cuda’))将模型的参数张量转换为cuda张量。最后,请确保在所有模型输入上使用。to(torch.device(‘cuda’))函数来为cuda优化模型准备数据。请注意,调用my_tensor.to(device)会在GPU上返回一个新的my_tensor副本。它不会覆盖my_tensor。因此,请记住手动重写张量:
my_tensor = my_tensor.to(torch.device(‘cuda’))。
1.11 保存并行计算的模型
save:
torch.save(model.module.state_dict(), PATH)
load:
利用常规方法:加载到目标设备
torch.nn.DataParallel是一个模型包装器,可以并行地利用GPU计算。要保存DataParallel模型,请保存model.module.state_dict()。这样,你可以灵活地以任何方式加载模型到目标设备上。
附:该文章参考Pytorch的官方教程
https://pytorch.org/tutorials/beginner/saving_loading_models.html
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。