赞
踩
torch.save可以保存我们的模型的部分参数 如下图。
//n_hidden,n_layers为超参.net.state_dict()为模型参数
class model(nn.Module):
def __init__(self, **kwargs):
def __init__(self, dataset, embedding):
self.lstm = nn.LSTM(len(self.chars), n_hidden, n_layers, dropout=drop_prob, batch_first=True)
self.fc = nn.Linear(n_hidden, len(self.chars))
... ...
def forward():
... ...
model_name = 'rnn_x_epoch.net'
checkpoint = {'n_hidden': net.n_hidden,
'n_layers': net.n_layers,
'state_dict': net.state_dict()}
with open(model_name, 'wb') as f:
torch.save(checkpoint, f);
torch.save也可以保存我们的整个模型 如下图。
torch.save(model, './path')
具体选择哪一种依照自己的需要,如果只取模型中的一部分,第一种感觉方便一些。如果希望以后直接加载现成的模型。第二种可能方便一些。
model_name = 'rnn_x_epoch.net'
model=torch.load(model_name)
print(type(model))
print('____')
for i in model:
print(i)
print('____')
输出结果如下
//n_hidden,n_layers为超参.net.state_dict()为模型参数 <class 'dict'> ____ n_hidden n_layers state_dict ____ lstm.weight_ih_l0 lstm.weight_hh_l0 lstm.bias_ih_l0 lstm.bias_hh_l0 lstm.weight_ih_l1 lstm.weight_hh_l1 lstm.bias_ih_l1 lstm.bias_hh_l1 fc.weight fc.bias
可以看到 按照第一种方式,是以字典方式将文件存储进了文件,那么我们怎么将这个里面训练好的网络加载进新的模型呢?
//按照需要 创建一个你希望的新模型
class Net_1(nn.Module):
def __init__(self, **kwargs):
def __init__(self, dataset, embedding):
self.lstm = nn.LSTM(len(self.chars), n_hidden, n_layers, dropout=drop_prob, batch_first=True)
self.fc1 = nn.Linear(n_hidden, len(self.chars))
... ...
def forward():
... ...
// model1_dict是一个存着net_1所有参数的字典,里面的数据是随机初始化的
//model是我们的预训练模型,model['state_dict']存储着我们需要的参数
net_1=Net_1(*kwargs,**kwargs)
model1_dict = net_1.state_dict()
new_state_dict = {k:v for k,v in model['state_dict'].items() if k in model1_dict}
为什么需要这个if k in model1_dict语句呢?因为我们有时候只需要部分加载 而不是一股脑全部放上去,那么这个语句是怎么实现这个功能的呢?
我们来看看model1_dict的结构
for i in model1_dict:
print(i)
lstm.weight_ih_l0
lstm.weight_hh_l0
lstm.bias_ih_l0
lstm.bias_hh_l0
lstm.weight_ih_l1
lstm.weight_hh_l1
lstm.bias_ih_l1
lstm.bias_hh_l1
fc1.weight
fc1.bias
在经过new_state_dict = {k:v for k,v in model[‘state_dict’].items() if k in model1_dict}语句以后 新的模型继承了lstm层的参数,但是没有继承线性层的参数,因为原模型线性层的名字为fc ,而新模型的线性层的参数为fc1。
这就需要我们在创建新的模型对象的时候,将希望保存的层,与原层有相同的名字,而不希望的保存的层,有不同的名字。
现在new_state_dict 中包含我们需要更新的所有参数
model1_dict.update(new_state_dict) #更新参数
net_1.load_state_dict(model1_dict) #加载参数
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。