赞
踩
这篇博客主要是对使用PyTorch保存和加载训练模型参数的一个学习记录。第1-5小节是比较常规的模型参数保存操作,第6小是用已经训练好的模型参数来初始化新的模型,包括从一层加载到另一层,某些参数名不匹配的情况,也给出了实验代码和结果,完整实验项目见github。如果对您有所帮助,欢迎关注点赞~
觉得文章有收获,欢迎关注公众号鼓励一下作者呀~
在学习的过程中,也搜集了一些量化、技术的视频及书籍资源,欢迎大家关注公众号【亚里随笔】获取
在PyTorch中,torch.nn.Module的可学习参数(i.e. weights and biases),保存在模型的parameters中,它可以通过model.parameters()进行访问。state_dict是一个从参数名称映射到参数Tensor的字典对象。注意,只有具有可学习参数的层(卷积层、线性层等)和已经注册的缓冲区(bachnorm’s running _mean)才有state_dict中的条目。优化器(optim)也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。由于state_dic对象是Python字典,因此可以轻松地保存、更新、更改和还原它们,从而为PyTorch模型和优化增加了很多模块性。
从训练分类器教程中使用的简单模型看一下state_dict。
# Define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
从输出结果可以看出,每一层的模型参数的名称格式是:层名.参数;如果有它的一层是由另一个类定义的话,那么就把层名往后扩展:层名.层名…参数。下面对上述代码的模型进行重新整理,验证一下。
TheModelClass(
(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
模型重新整理的代码与结果:
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc_the_model_class = FC()
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = self.fc_the_model_class(x)
return x
class FC(nn.Module):
def __init__(self):
super(FC, self).__init__()
self.fc = nn.Sequential(
nn.Linear(16 * 5 * 5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, 10)
)
def forward(self, x):
return self.fc(x)
# Initialize model
model = TheModelClass()
print(model)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
参数名称输出。整个模型的网络结构还是一样的,但是将全连接层重新使用FC类来定义了。从输出的网络结构可以看出,TheModelClass类中定义时,使用的类的结构层次,在网络的结构中会体现。与(conv2)并列的(fc_the_model_class)是在TheModelClass类定义时用的变量名。后接的FC是fc_the_model_class使用的类名,后面的是这个类中定义的层。输出模型时,就是按一种深度优先的方法遍历了整个模型。对于更深层次层的参数,类名是不会出现在参数名中的,然后将参数名按深度组织:fc_the_model_class.fc.0.weight,也就是在打印过程中,:后面的名称会忽略。
TheModelClass(
(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc_the_model_class): FC(
(fc): Sequential(
(0): Linear(in_features=400, out_features=120, bias=True)
(1): ReLU()
(2): Linear(in_features=120, out_features=84, bias=True)
(3): ReLU()
(4): Linear(in_features=84, out_features=10, bias=True)
)
)
)
Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc_the_model_class.fc.0.weight torch.Size([120, 400])
fc_the_model_class.fc.0.bias torch.Size([120])
fc_the_model_class.fc.2.weight torch.Size([84, 120])
fc_the_model_class.fc.2.bias torch.Size([84])
fc_the_model_class.fc.4.weight torch.Size([10, 84])
fc_the_model_class.fc.4.bias torch.Size([10])
保存使用:
torch.save(model.state_dict(), PATH)
加载使用:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
保存模型用于推理时,仅需要保存训练后的模型的参数,使用torch.save()函数直接保存模型的state_dict,通常文件的后缀名为.pt或.pth。请记住,在运行推理之前,必须先调用model.eval(),将dropout层和batch normalization层设为关闭状态。否则将会产生不一致的推断结果。
需要注意的是,load_state_dict()函数使用的是字典对象,而不是保存对象的路径,所以需要先进行torch.load(PATH)
保存使用:
torch.save(model, PATH)
加载使用:
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
保存使用:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
加载使用:
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - or -
model.train()
在保存checkpoint用于继续训练时,保存优化器的state_dict是必要的,因为它包含着随着模型训练而更新的缓冲区和参数,可能也需要保存一些其他的项目,包括epoch和loss。常见的PyTorch约定是使用.tar文件扩展名保存这些检查点。
保存使用 本质上还是保存的是一个字典对象,PyTorch约定使用.tar保存这些检查点。 :
torch.save({
'modelA_state_dict': modelA.state_dict(),
'modelB_state_dict': modelB.state_dict(),
'optimizerA_state_dict': optimizerA.state_dict(),
'optimizerB_state_dict': optimizerB.state_dict(),
...
}, PATH)
加载使用 加载还是加载的是字典对象,然后取字典对象。 :
modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
保存使用:
torch.save(modelA.state_dict(), PATH)
加载使用:对于不同的模型,设置strict=False是必要的。
modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)
在迁移学习或者训练新的复杂模型时,部分加载模型或加载部分模型是常见的方案。利用训练过的参数,即使是只有一小部分可以使用,也会对warmstart训练过程有所帮助,而且有望比从头开始训练模型更快地收敛。所谓的warmstart,我理解的就是在参数初始化时,将待训练模型的参数使用已经训练好的模型的部分参数进行初始化,然后接着训练,这种参数初始化方案会大大提高收敛的速度。
无论是从缺少某些键的部分state_dict加载,还是要加载比待加载模型更多的键的state_dic,都可以在lod_state_dict()中将strict参数设置为False,这样可以忽略不匹配的键。
如果要将参数从一层加载到另一层,但是某些键不匹配,只需要加载的state_dict中参数键的名称,来匹配到要加载到的模型中的键。 实验代码如下。
targetModel = TheModelClass()
cifar_net = torch.load('./cifar_net.pth')
for item in cifar_net:
print('cifar_net \t', item, '\t')
targetModel.load_state_dict(cifar_net, strict=False)
for item in targetModel.state_dict():
print('targetModel \t', item, '\t')
print('cifar_net \t', cifar_net["fc3.bias"], '\t', cifar_net["fc3.bias"].data)
print('targetModel \t', targetModel.state_dict()["fc_the_model_class.fc.4.bias"], '\t', targetModel.state_dict()["fc_the_model_class.fc.4.bias"].data)
# 更新层的名称
cifar_net["fc_the_model_class.fc.0.weight"] = cifar_net.pop("fc1.weight")
cifar_net["fc_the_model_class.fc.0.bias"] = cifar_net.pop("fc1.bias")
cifar_net["fc_the_model_class.fc.2.weight"] = cifar_net.pop("fc2.weight")
cifar_net["fc_the_model_class.fc.2.bias"] = cifar_net.pop("fc2.bias")
cifar_net["fc_the_model_class.fc.4.weight"] = cifar_net.pop("fc3.weight")
cifar_net["fc_the_model_class.fc.4.bias"] = cifar_net.pop("fc3.bias")
targetModel.load_state_dict(cifar_net, strict=False)
print('cifar_net \t', cifar_net["fc_the_model_class.fc.4.bias"], '\t', cifar_net["fc_the_model_class.fc.4.bias"].data)
print('targetModel \t', targetModel.state_dict()["fc_the_model_class.fc.4.bias"], '\t', targetModel.state_dict()["fc_the_model_class.fc.4.bias"].data)
输出结果,可以看出fc_the_model_class.fc.4.bias的参数由随机初始化,变成从cifar_net模型中初始化。
cifar_net tensor([-3.1628e-01, -9.6105e-01, -7.1377e-04, 4.6838e-01, 1.1072e+00,
-2.2960e-01, 1.9044e-01, -5.1352e-02, 1.8365e-01, -3.4669e-01],
device='cuda:0') tensor([-3.1628e-01, -9.6105e-01, -7.1377e-04, 4.6838e-01, 1.1072e+00,
-2.2960e-01, 1.9044e-01, -5.1352e-02, 1.8365e-01, -3.4669e-01],
device='cuda:0')
targetModel tensor([-0.0878, -0.1059, -0.0949, 0.0353, 0.0164, -0.1002, -0.0126, -0.1012,
-0.0115, -0.1006]) tensor([-0.0878, -0.1059, -0.0949, 0.0353, 0.0164, -0.1002, -0.0126, -0.1012,
-0.0115, -0.1006])
cifar_net tensor([-3.1628e-01, -9.6105e-01, -7.1377e-04, 4.6838e-01, 1.1072e+00,
-2.2960e-01, 1.9044e-01, -5.1352e-02, 1.8365e-01, -3.4669e-01],
device='cuda:0') tensor([-3.1628e-01, -9.6105e-01, -7.1377e-04, 4.6838e-01, 1.1072e+00,
-2.2960e-01, 1.9044e-01, -5.1352e-02, 1.8365e-01, -3.4669e-01],
device='cuda:0')
targetModel tensor([-3.1628e-01, -9.6105e-01, -7.1377e-04, 4.6838e-01, 1.1072e+00,
-2.2960e-01, 1.9044e-01, -5.1352e-02, 1.8365e-01, -3.4669e-01]) tensor([-3.1628e-01, -9.6105e-01, -7.1377e-04, 4.6838e-01, 1.1072e+00,
-2.2960e-01, 1.9044e-01, -5.1352e-02, 1.8365e-01, -3.4669e-01])
dict={'a':1, 'b':2}
dict["c"] = dict.pop("a")
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。