
Fixing RuntimeError: Error(s) in loading state_dict for Tacotron2: Missing key(s) in state_dict


Error message

RuntimeError: Error(s) in loading state_dict for Tacotron2:
	Missing key(s) in state_dict: "embedding.weight", "encoder.Conv_layer1.0.conv.weight", "encoder.Conv_layer1.0.conv.bias", "encoder.Conv_layer1.1.weight", "encoder.Conv_layer1.1.bias", "encoder.Conv_layer1.1.running_mean", "encoder.Conv_layer1.1.running_var", "encoder.Conv_layer2.0.conv.weight", "encoder.Conv_layer2.0.conv.bias", "encoder.Conv_layer2.1.weight", "encoder.Conv_layer2.1.bias", "encoder.Conv_layer2.1.running_mean", "encoder.Conv_layer2.1.running_var", "encoder.Conv_layer3.0.conv.weight", "encoder.Conv_layer3.0.conv.bias", "encoder.Conv_layer3.1.weight", "encoder.Conv_layer3.1.bias", "encoder.Conv_layer3.1.running_mean", "encoder.Conv_layer3.1.running_var", "encoder.lstm.weight_ih_l0", "encoder.lstm.weight_hh_l0", "encoder.lstm.bias_ih_l0", "encoder.lstm.bias_hh_l0", "encoder.lstm.weight_ih_l0_reverse", "encoder.lstm.weight_hh_l0_reverse", "encoder.lstm.bias_ih_l0_reverse", "encoder.lstm.bias_hh_l0_reverse", "decoder.prenet.prenet_layer1.0.linear_layer.weight", "decoder.prenet.prenet_layer2.0.linear_layer.weight", "decoder.attention_rnn.weight_ih", "decoder.attention_rnn.weight_hh", "decoder.attention_rnn.bias_ih", "decoder.attention_rnn.bias_hh", "decoder.attention_layer.query_layer.linear_layer.weight", "decoder.attention_layer.memory_layer.linear_layer.weight", "decoder.attention_layer.v.linear_layer.weight", "decoder.attention_layer.location_layer.location_conv.conv.weight", "decoder.attention_layer.location_layer.location_dense.linear_layer.weight", "decoder.decoder_rnn.weight_ih", "decoder.decoder_rnn.weight_hh", "decoder.decoder_rnn.bias_ih", "decoder.decoder_rnn.bias_hh", "decoder.linear_projection.linear_layer.weight", "decoder.linear_projection.linear_layer.bias", "decoder.gate_layer.linear_layer.weight", "decoder.gate_layer.linear_layer.bias", "postnet.postnet_layer_1.0.conv.weight", "postnet.postnet_layer_1.0.conv.bias", "postnet.postnet_layer_1.1.weight", "postnet.postnet_layer_1.1.bias", "postnet.postnet_layer_1.1.running_mean", 
"postnet.postnet_layer_1.1.running_var", "postnet.postnet_layer_2.0.conv.weight", "postnet.postnet_layer_2.0.conv.bias", "postnet.postnet_layer_2.1.weight", "postnet.postnet_layer_2.1.bias", "postnet.postnet_layer_2.1.running_mean", "postnet.postnet_layer_2.1.running_var", "postnet.postnet_layer_3.0.conv.weight", "postnet.postnet_layer_3.0.conv.bias", "postnet.postnet_layer_3.1.weight", "postnet.postnet_layer_3.1.bias", "postnet.postnet_layer_3.1.running_mean", "postnet.postnet_layer_3.1.running_var", "postnet.postnet_layer_4.0.conv.weight", "postnet.postnet_layer_4.0.conv.bias", "postnet.postnet_layer_4.1.weight", "postnet.postnet_layer_4.1.bias", "postnet.postnet_layer_4.1.running_mean", "postnet.postnet_layer_4.1.running_var", "postnet.postnet_layer_5.0.conv.weight", "postnet.postnet_layer_5.0.conv.bias", "postnet.postnet_layer_5.1.weight", "postnet.postnet_layer_5.1.bias", "postnet.postnet_layer_5.1.running_mean", "postnet.postnet_layer_5.1.running_var". 
	Unexpected key(s) in state_dict: "model", "opt". 

Cause

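When the model was saved, the model parameters and the optimizer parameters were saved together in one file, as a dictionary whose values are themselves dictionaries: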

# Save the model weights and the optimizer state together
torch.save({'model': m_model.state_dict(),
            'opt': m_optimizer.state_dict()},
           'save2/model_final.pick')
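Storing both state dicts in one file is a common pattern for resuming training, because it keeps the optimizer state (for example SGD momentum buffers) alongside the weights. The minimal sketch below round-trips such a checkpoint through an in-memory buffer; the small nn.Linear and SGD optimizer are assumptions standing in for Tacotron2 and m_optimizer, and only the 'model' / 'opt' keys are taken from the save code above.

```python
import io

import torch
import torch.nn as nn

# Stand-ins for Tacotron2 and m_optimizer (assumption: any nn.Module / optimizer behaves the same way)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer holds some state (momentum buffers) worth saving
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# Save both state dicts under the same 'model' / 'opt' keys as the article
buffer = io.BytesIO()
torch.save({'model': model.state_dict(), 'opt': optimizer.state_dict()}, buffer)
buffer.seek(0)

# Resume: rebuild the objects, then restore each one from its own inner dictionary
checkpoint = torch.load(buffer, map_location='cpu')
restored_model = nn.Linear(4, 2)
restored_opt = torch.optim.SGD(restored_model.parameters(), lr=0.1, momentum=0.9)
restored_model.load_state_dict(checkpoint['model'])
restored_opt.load_state_dict(checkpoint['opt'])

print(sorted(checkpoint.keys()))  # ['model', 'opt']
```

The key point is that each load_state_dict call receives one inner dictionary, never the outer one.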

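When loading the model for inference, however, only the model parameters m_model.state_dict() are needed, i.e. the inner dictionary stored under the 'model' key of the outer dictionary. Passing the whole loaded file straight to load_state_dict, as below, fails: PyTorch treats the outer keys "model" and "opt" as parameter names, reports them as unexpected, and reports every real parameter of Tacotron2 as missing.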

# Instantiate the model
m_model = Tacotron2(para)
m_model.load_state_dict(torch.load('save2/124500/model_final.pick'))
m_model.eval()
m_model.to(device)
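The failure can be reproduced on a toy model. In this sketch an nn.Linear stands in for Tacotron2 (an assumption for illustration); printing the checkpoint's top-level keys before loading is a quick way to spot the nesting, and passing the outer dictionary to load_state_dict produces the same "Missing key(s)" plus "Unexpected key(s): "model", "opt"" pattern as the error message above.

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for Tacotron2 (assumption)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Write a checkpoint in the same nested format as the article
buffer = io.BytesIO()
torch.save({'model': model.state_dict(), 'opt': optimizer.state_dict()}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer, map_location='cpu')

# Inspecting the top-level keys first reveals the nesting
print(list(checkpoint.keys()))  # ['model', 'opt']

# Passing the outer dictionary directly reproduces the error
try:
    model.load_state_dict(checkpoint)
    error_message = ''
except RuntimeError as err:
    error_message = str(err)

print(error_message)
```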

Solution

Change the code

# Instantiate the model
m_model = Tacotron2(para)
m_model.load_state_dict(torch.load('save2/124500/model_final.pick'))
m_model.eval()
m_model.to(device)

to

# Instantiate the model, then load only the inner 'model' dictionary from the checkpoint
m_model = Tacotron2(para)
m_dict = torch.load('save2/124500/model_final.pick')
m_model.load_state_dict(m_dict['model'])
m_model.eval()
m_model.to(device)
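If the same inference script has to accept both wrapped checkpoints like this one and files that contain a bare state_dict, a small helper can unwrap the 'model' entry only when it is present. This is a sketch under stated assumptions: load_model_weights is a hypothetical name, the 'model' key follows the save code above, nn.Linear again stands in for Tacotron2, and map_location lets a checkpoint saved on a GPU load on a CPU-only machine.

```python
import io

import torch
import torch.nn as nn


def load_model_weights(model, path_or_buffer, key='model', device='cpu'):
    """Hypothetical helper: load weights from a bare state_dict or a {'model': ..., 'opt': ...} checkpoint."""
    # map_location remaps tensors saved on another device (e.g. CUDA) onto `device`
    checkpoint = torch.load(path_or_buffer, map_location=device)
    # Unwrap the nested format; a bare state_dict is used as-is
    if isinstance(checkpoint, dict) and key in checkpoint:
        state_dict = checkpoint[key]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


# Usage with both formats (nn.Linear stands in for Tacotron2)
source = nn.Linear(4, 2)
optimizer = torch.optim.SGD(source.parameters(), lr=0.1)

wrapped = io.BytesIO()
torch.save({'model': source.state_dict(), 'opt': optimizer.state_dict()}, wrapped)
wrapped.seek(0)

bare = io.BytesIO()
torch.save(source.state_dict(), bare)
bare.seek(0)

model_a = load_model_weights(nn.Linear(4, 2), wrapped)
model_b = load_model_weights(nn.Linear(4, 2), bare)
```

Alternatively, if a file is only ever needed for inference, saving just m_model.state_dict() makes the original one-line load_state_dict(torch.load(...)) call work unchanged.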
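Disclaimer: This article was contributed by a user and does not represent the views of the wpsshop blog; copyright belongs to the original author. If you find infringing content, please contact us. When reposting, please cite the source: https://www.wpsshop.cn/w/盐析白兔/article/detail/270251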