赞
踩
pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)
(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)
优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)
备注:
1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"
torch.save(model.state_dict(), PATH)
2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用
model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.
-------------------------------------------------------------------------------------------------------------------------------
模态字典(state_dict)的保存(model是一个网络结构类的对象
)
1.1)仅保存学习到的参数,用以下命令
torch.save(model.state_dict(), PATH)
1.2)加载model.state_dict,用以下命令
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名
-----------
2.1)保存整个model的状态,用以下命令
torch.save(model,PATH)
2.2)加载整个model的状态,用以下命令:
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
--------------------------------------------------------------------------------------------------------------------------------------
state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项
----------------------------------------------------------------------------------------------------------------------
如何仅加载某一层的训练的到的参数(某一层的state)
If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
--------------------------------------------------------------------------------------------
加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)
-
for param
in list(model.pretrained.parameters()):
-
param.requires_grad =
False
注意: requires_grad的操作对象是tensor.
疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False
回答:经测试,不可以.model.conv1 没有requires_grad属性.
---------------------------------------------------------------------------------------------
全部测试代码:
-
#-*-coding:utf-8-*-
-
import torch
-
import torch.nn
as nn
-
import torch.nn.functional
as F
-
import torch.optim
as optim
-
-
-
-
# define model
-
class TheModelClass(nn.Module):
-
def __init__(self):
-
super(TheModelClass,self).__init__()
-
self.conv1 = nn.Conv2d(
3,
6,
5)
-
self.pool = nn.MaxPool2d(
2,
2)
-
self.conv2 = nn.Conv2d(
6,
16,
5)
-
self.fc1 = nn.Linear(
16*
5*
5,
120)
-
self.fc2 = nn.Linear(
120,
84)
-
self.fc3 = nn.Linear(
84,
10)
-
-
def forward(self,x):
-
x = self.pool(F.relu(self.conv1(x)))
-
x = self.pool(F.relu(self.conv2(x)))
-
x = x.view(
-1,
16*
5*
5)
-
x = F.relu(self.fc1(x))
-
x = F.relu(self.fc2(x))
-
x = self.fc3(x)
-
return x
-
-
# initial model
-
model = TheModelClass()
-
-
#initialize the optimizer
-
optimizer = optim.SGD(model.parameters(),lr=
0.001,momentum=
0.9)
-
-
# print the model's state_dict
-
print(
"model's state_dict:")
-
for param_tensor
in model.state_dict():
-
print(param_tensor,
'\t',model.state_dict()[param_tensor].size())
-
-
print(
"\noptimizer's state_dict")
-
for var_name
in optimizer.state_dict():
-
print(var_name,
'\t',optimizer.state_dict()[var_name])
-
-
print(
"\nprint particular param")
-
print(
'\n',model.conv1.weight.size())
-
print(
'\n',model.conv1.weight)
-
-
print(
"------------------------------------")
-
torch.save(model.state_dict(),
'./model_state_dict.pt')
-
# model_2 = TheModelClass()
-
# model_2.load_state_dict(torch.load('./model_state_dict'))
-
# model.eval()
-
# print('\n',model_2.conv1.weight)
-
# print((model_2.conv1.weight == model.conv1.weight).size())
-
## 仅仅加载某一层的参数
-
conv1_weight_state = torch.load(
'./model_state_dict.pt')[
'conv1.weight']
-
print(conv1_weight_state==model.conv1.weight)
-
-
model_2 = TheModelClass()
-
model_2.load_state_dict(torch.load(
'./model_state_dict.pt'))
-
model_2.conv1.requires_grad=
False
-
print(model_2.conv1.requires_grad)
-
print(model_2.conv1.bias.requires_grad)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。