当前位置:   article > 正文

PyTorch学习笔记:使用state_dict来保存和加载模型

state_dict

1. state_dict简介

state_dict是Python的字典对象,可用于保存模型参数、超参数以及优化器(torch.optim)的状态信息。需要注意的是,只有具有可学习参数的层(如卷积层、线性层等)才有state_dict。

下面就拿官方教程中的一个小示例来说明state_dict的使用:

  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. # 定义模型
  5. class TheModelClass(nn.Module):
  6. def __init__(self):
  7. super(TheModelClass, self).__init__()
  8. self.conv1 = nn.Conv2d(3, 6, 5)
  9. self.pool = nn.MaxPool2d(2, 2)
  10. self.conv2 = nn.Conv2d(6, 16, 5)
  11. self.fc1 = nn.Linear(16 * 5 * 5, 120)
  12. self.fc2 = nn.Linear(120, 84)
  13. self.fc3 = nn.Linear(84, 10)
  14. def forward(self, x):
  15. x = self.pool(F.relu(self.conv1(x)))
  16. x = self.pool(F.relu(self.conv2(x)))
  17. x = x.view(-1, 16 * 5 * 5)
  18. x = F.relu(self.fc1(x))
  19. x = F.relu(self.fc2(x))
  20. x = self.fc3(x)
  21. return x
  22. # 初始化模型
  23. model = TheModelClass()
  24. # 初始化优化器
  25. optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
  26. # 打印模型的状态字典
  27. print("Model's state_dict:")
  28. for param_tensor in model.state_dict():
  29. print(param_tensor, "\t", model.state_dict()[param_tensor].size())
  30. # 打印优化器的状态字典
  31. print("Optimizer's state_dict:")
  32. for var_name in optimizer.state_dict():
  33. print(var_name, "\t", optimizer.state_dict()[var_name])

让我们来运行一下以上代码:

从以上代码及运行结果可知,state_dict将模型的每一层映射到一个参数张量。在Python中,可以对state_dict进行保存、加载、更新、修改等操作。

下面我们就来看一下PyTorch如何通过state_dict来保存和加载模型。

2. 保存和加载state_dict

可以通过torch.save()来保存模型的state_dict,即只保存学习到的模型参数,并通过load_state_dict()来加载并恢复模型参数。PyTorch中最常见的模型保存扩展名为'.pt'或'.pth'。

下面我们就将上个例子中构造的简单模型TheModelClass的参数保存在state_dict,然后通过load_state_dict()来加载模型参数。

  1. ......
  2. # 将模型保存到当前路径,名称为test_state_dict.pth
  3. PATH = './test_state_dict.pth'
  4. torch.save(model.state_dict(), PATH)
  5. model = TheModelClass() # 首先通过代码获取模型结构
  6. model.load_state_dict(torch.load(PATH)) # 然后加载模型的state_dict
  7. model.eval()

注意:load_state_dict()函数只接受字典对象,不可直接传入模型路径,所以需要先使用torch.load()反序列化已保存的state_dict。

另外,在使用模型做推理之前,需要调用model.eval()函数将dropout和batch normalization层设置为评估模式,否则会导致模型推理结果不一致。 

当然,除了保存state_dict,PyTorch还支持保存和加载整个模型。

3. 保存和加载完整模型

保存和加载整个模型的代码如下:

  1. # 保存完整模型
  2. torch.save(model, PATH)
  3. # 加载完整模型
  4. model = torch.load(PATH)
  5. model.eval()

这种方式虽然代码看起来较state_dict方式要简洁,但是灵活性会差一些。因为torch.save()函数使用Python的pickle模块进行序列化,但pickle无法保存模型本身,而是保存包含类的文件路径,该文件会在模型加载时使用。所以当在其他项目对模型进行重构之后,就可能会出现意想不到的错误。

4. 保存和加载checkpoint用于继续训练或推理

除了以上两种保存模型的方式,PyTorch还支持以checkpoint方式保存模型训练的中间结果,以实现模型的继续训练或者推理。这种方式下,保存的内容不仅包含模型的state_dict,还会保存优化器的state_dict,以及其他参数如loss、epoch等。

保存:

  1. torch.save({
  2. 'epoch': epoch,
  3. 'model_state_dict': model.state_dict(),
  4. 'optimizer_state_dict': optimizer.state_dict(),
  5. 'loss': loss,
  6. ...
  7. }, PATH)

加载:

  1. model = TheModelClass()
  2. optimizer = TheOptimizerClass()
  3. checkpoint = torch.load(PATH)
  4. model.load_state_dict(checkpoint['model_state_dict'])
  5. optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  6. epoch = checkpoint['epoch']
  7. loss = checkpoint['loss']
  8. model.eval()
  9. model.train()

checkpoint在PyTorch中常保存为.tar的文件扩展名。

注:以上checkpoint保存和加载的代码未经本人测试。

5. 迁移学习下的热启动模式

我们在工程中,常常用到迁移学习,利用训练好的模型在新的数据集上进行迁移训练,可达到使用少量数据进行快速训练的目的。

在迁移学习中,我们常常需要对预训练模型进行部分加载的需要,这个时候我们就要用到热启动模式,可通过在load_state_dict()函数中将strict参数设置为False来忽略非匹配键的参数。

  1. # 保存模型state_dict
  2. torch.save(modelA.state_dict(), PATH)
  3. # 热加载模型
  4. modelB = TheModelBClass()
  5. modelB.load_state_dict(torch.load(PATH), strict=False)
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/406980
推荐阅读
相关标签
  

闽ICP备14008679号