当前位置:   article > 正文

pytroch 如何保存与导入预训练模型_预训练模型添加自定义词表后如何保存

预训练模型添加自定义词表后如何保存

1 如何保存预训练模型

1.1 以字典形式 保存模型参数

torch.save可以保存我们的模型的部分参数 如下图。

//n_hidden,n_layers为超参.net.state_dict()为模型参数
class model(nn.Module):
       def __init__(self, **kwargs):
           def __init__(self, dataset, embedding):
           self.lstm = nn.LSTM(len(self.chars), n_hidden, n_layers, dropout=drop_prob, batch_first=True)
           self.fc = nn.Linear(n_hidden, len(self.chars))
           ... ... 
       def forward():
           ... ...
model_name = 'rnn_x_epoch.net'
checkpoint = {'n_hidden': net.n_hidden,
              'n_layers': net.n_layers,
              'state_dict': net.state_dict()}
 with open(model_name, 'wb') as f:
    torch.save(checkpoint, f);
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

1.2 直接保存整个模型

torch.save也可以保存我们的整个模型 如下图。

torch.save(model, './path')
  • 1

具体选择哪一种依照自己的需要,如果只取模型中的一部分,第一种感觉方便一些。如果希望以后直接加载现成的模型。第二种可能方便一些。

2 如何加载自己训练的预训练模型

2.1 加载以字典形式保存的模型参数

model_name = 'rnn_x_epoch.net'
model=torch.load(model_name)
print(type(model))
print('____')
for i in model:
    print(i)
 print('____')
   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

输出结果如下

//n_hidden,n_layers为超参.net.state_dict()为模型参数
<class 'dict'>
____
n_hidden
n_layers
state_dict
____
lstm.weight_ih_l0
lstm.weight_hh_l0
lstm.bias_ih_l0
lstm.bias_hh_l0
lstm.weight_ih_l1
lstm.weight_hh_l1
lstm.bias_ih_l1
lstm.bias_hh_l1
fc.weight
fc.bias
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

可以看到 按照第一种方式,是以字典方式将文件存储进了文件,那么我们怎么将这个里面训练好的网络加载进新的模型呢?

2.1.1 创建新的模型对象
//按照需要 创建一个你希望的新模型
class Net_1(nn.Module):
       def __init__(self, **kwargs):
           def __init__(self, dataset, embedding):
           self.lstm = nn.LSTM(len(self.chars), n_hidden, n_layers, dropout=drop_prob, batch_first=True)
           self.fc1 = nn.Linear(n_hidden, len(self.chars))
           ... ... 
       def forward():
           ... ... 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
2.1.1 创建新的模型对象
// model1_dict是一个存着net_1所有参数的字典,里面的数据是随机初始化的
 //model是我们的预训练模型,model['state_dict']存储着我们需要的参数
net_1=Net_1(*kwargs,**kwargs)
model1_dict = net_1.state_dict()
new_state_dict = {k:v for k,v in model['state_dict'].items() if k in model1_dict}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

为什么需要这个if k in model1_dict语句呢?因为我们有时候只需要部分加载 而不是一股脑全部放上去,那么这个语句是怎么实现这个功能的呢?
我们来看看model1_dict的结构

for i in model1_dict:
  print(i)
  • 1
  • 2
lstm.weight_ih_l0
lstm.weight_hh_l0
lstm.bias_ih_l0
lstm.bias_hh_l0
lstm.weight_ih_l1
lstm.weight_hh_l1
lstm.bias_ih_l1
lstm.bias_hh_l1
fc1.weight
fc1.bias
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

在经过new_state_dict = {k:v for k,v in model[‘state_dict’].items() if k in model1_dict}语句以后 新的模型继承了lstm层的参数,但是没有继承线性层的参数,因为原模型线性层的名字为fc ,而新模型的线性层的参数为fc1。

这就需要我们在创建新的模型对象的时候,将希望保存的层,与原层有相同的名字,而不希望的保存的层,有不同的名字。
现在new_state_dict 中包含我们需要更新的所有参数

model1_dict.update(new_state_dict)	#更新参数
net_1.load_state_dict(model1_dict) #加载参数
  • 1
  • 2
声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号