赞
踩
state_dict是Python的字典对象,可用于保存模型参数、超参数以及优化器(torch.optim)的状态信息。需要注意的是,只有具有可学习参数的层(如卷积层、线性层等)才有state_dict。
下面就拿官方教程中的一个小示例来说明state_dict的使用:
- import torch
- import torch.nn as nn
- import torch.optim as optim
-
- # 定义模型
- class TheModelClass(nn.Module):
- def __init__(self):
- super(TheModelClass, self).__init__()
- self.conv1 = nn.Conv2d(3, 6, 5)
- self.pool = nn.MaxPool2d(2, 2)
- self.conv2 = nn.Conv2d(6, 16, 5)
- self.fc1 = nn.Linear(16 * 5 * 5, 120)
- self.fc2 = nn.Linear(120, 84)
- self.fc3 = nn.Linear(84, 10)
-
- def forward(self, x):
- x = self.pool(F.relu(self.conv1(x)))
- x = self.pool(F.relu(self.conv2(x)))
- x = x.view(-1, 16 * 5 * 5)
- x = F.relu(self.fc1(x))
- x = F.relu(self.fc2(x))
- x = self.fc3(x)
- return x
-
- # 初始化模型
- model = TheModelClass()
-
- # 初始化优化器
- optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
-
- # 打印模型的状态字典
- print("Model's state_dict:")
- for param_tensor in model.state_dict():
- print(param_tensor, "\t", model.state_dict()[param_tensor].size())
-
- # 打印优化器的状态字典
- print("Optimizer's state_dict:")
- for var_name in optimizer.state_dict():
- print(var_name, "\t", optimizer.state_dict()[var_name])
让我们来运行一下以上代码:
从以上代码及运行结果可知,state_dict将模型的每一层映射到一个参数张量。在Python中,可以对state_dict进行保存、加载、更新、修改等操作。
下面我们就来看一下PyTorch如何通过state_dict来保存和加载模型。
可以通过torch.save()来保存模型的state_dict,即只保存学习到的模型参数,并通过load_state_dict()来加载并恢复模型参数。PyTorch中最常见的模型保存扩展名为'.pt'或'.pth'。
下面我们就将上个例子中构造的简单模型TheModelClass的参数保存在state_dict,然后通过load_state_dict()来加载模型参数。
- ......
-
- # 将模型保存到当前路径,名称为test_state_dict.pth
- PATH = './test_state_dict.pth'
- torch.save(model.state_dict(), PATH)
-
- model = TheModelClass() # 首先通过代码获取模型结构
- model.load_state_dict(torch.load(PATH)) # 然后加载模型的state_dict
- model.eval()
注意:load_state_dict()函数只接受字典对象,不可直接传入模型路径,所以需要先使用torch.load()反序列化已保存的state_dict。
另外,在使用模型做推理之前,需要调用model.eval()函数将dropout和batch normalization层设置为评估模式,否则会导致模型推理结果不一致。
当然,除了保存state_dict,PyTorch还支持保存和加载整个模型。
保存和加载整个模型的代码如下:
- # 保存完整模型
- torch.save(model, PATH)
-
- # 加载完整模型
- model = torch.load(PATH)
- model.eval()
这种方式虽然代码看起来较state_dict方式要简洁,但是灵活性会差一些。因为torch.save()函数使用Python的pickle
模块进行序列化,但pickle无法保存模型本身,而是保存包含类的文件路径,该文件会在模型加载时使用。所以当在其他项目对模型进行重构之后,就可能会出现意想不到的错误。
除了以上两种保存模型的方式,PyTorch还支持以checkpoint方式保存模型训练的中间结果,以实现模型的继续训练或者推理。这种方式下,保存的内容不仅包含模型的state_dict,还会保存优化器的state_dict,以及其他参数如loss、epoch等。
保存:
- torch.save({
- 'epoch': epoch,
- 'model_state_dict': model.state_dict(),
- 'optimizer_state_dict': optimizer.state_dict(),
- 'loss': loss,
- ...
- }, PATH)
加载:
- model = TheModelClass()
- optimizer = TheOptimizerClass()
-
- checkpoint = torch.load(PATH)
- model.load_state_dict(checkpoint['model_state_dict'])
- optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
- epoch = checkpoint['epoch']
- loss = checkpoint['loss']
-
- model.eval()
-
- model.train()
checkpoint在PyTorch中常保存为.tar的文件扩展名。
注:以上checkpoint保存和加载的代码未经本人测试。
我们在工程中,常常用到迁移学习,利用训练好的模型在新的数据集上进行迁移训练,可达到使用少量数据进行快速训练的目的。
在迁移学习中,我们常常需要对预训练模型进行部分加载的需要,这个时候我们就要用到热启动模式,可通过在load_state_dict()函数中将strict参数设置为False来忽略非匹配键的参数。
- # 保存模型state_dict
- torch.save(modelA.state_dict(), PATH)
-
- # 热加载模型
- modelB = TheModelBClass()
- modelB.load_state_dict(torch.load(PATH), strict=False)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。