RuntimeError: Error(s) in loading state_dict for Tacotron2:
Missing key(s) in state_dict: "embedding.weight", "encoder.Conv_layer1.0.conv.weight", "encoder.Conv_layer1.0.conv.bias", "encoder.Conv_layer1.1.weight", "encoder.Conv_layer1.1.bias", "encoder.Conv_layer1.1.running_mean", "encoder.Conv_layer1.1.running_var", "encoder.Conv_layer2.0.conv.weight", "encoder.Conv_layer2.0.conv.bias", "encoder.Conv_layer2.1.weight", "encoder.Conv_layer2.1.bias", "encoder.Conv_layer2.1.running_mean", "encoder.Conv_layer2.1.running_var", "encoder.Conv_layer3.0.conv.weight", "encoder.Conv_layer3.0.conv.bias", "encoder.Conv_layer3.1.weight", "encoder.Conv_layer3.1.bias", "encoder.Conv_layer3.1.running_mean", "encoder.Conv_layer3.1.running_var", "encoder.lstm.weight_ih_l0", "encoder.lstm.weight_hh_l0", "encoder.lstm.bias_ih_l0", "encoder.lstm.bias_hh_l0", "encoder.lstm.weight_ih_l0_reverse", "encoder.lstm.weight_hh_l0_reverse", "encoder.lstm.bias_ih_l0_reverse", "encoder.lstm.bias_hh_l0_reverse", "decoder.prenet.prenet_layer1.0.linear_layer.weight", "decoder.prenet.prenet_layer2.0.linear_layer.weight", "decoder.attention_rnn.weight_ih", "decoder.attention_rnn.weight_hh", "decoder.attention_rnn.bias_ih", "decoder.attention_rnn.bias_hh", "decoder.attention_layer.query_layer.linear_layer.weight", "decoder.attention_layer.memory_layer.linear_layer.weight", "decoder.attention_layer.v.linear_layer.weight", "decoder.attention_layer.location_layer.location_conv.conv.weight", "decoder.attention_layer.location_layer.location_dense.linear_layer.weight", "decoder.decoder_rnn.weight_ih", "decoder.decoder_rnn.weight_hh", "decoder.decoder_rnn.bias_ih", "decoder.decoder_rnn.bias_hh", "decoder.linear_projection.linear_layer.weight", "decoder.linear_projection.linear_layer.bias", "decoder.gate_layer.linear_layer.weight", "decoder.gate_layer.linear_layer.bias", "postnet.postnet_layer_1.0.conv.weight", "postnet.postnet_layer_1.0.conv.bias", "postnet.postnet_layer_1.1.weight", "postnet.postnet_layer_1.1.bias", "postnet.postnet_layer_1.1.running_mean", 
"postnet.postnet_layer_1.1.running_var", "postnet.postnet_layer_2.0.conv.weight", "postnet.postnet_layer_2.0.conv.bias", "postnet.postnet_layer_2.1.weight", "postnet.postnet_layer_2.1.bias", "postnet.postnet_layer_2.1.running_mean", "postnet.postnet_layer_2.1.running_var", "postnet.postnet_layer_3.0.conv.weight", "postnet.postnet_layer_3.0.conv.bias", "postnet.postnet_layer_3.1.weight", "postnet.postnet_layer_3.1.bias", "postnet.postnet_layer_3.1.running_mean", "postnet.postnet_layer_3.1.running_var", "postnet.postnet_layer_4.0.conv.weight", "postnet.postnet_layer_4.0.conv.bias", "postnet.postnet_layer_4.1.weight", "postnet.postnet_layer_4.1.bias", "postnet.postnet_layer_4.1.running_mean", "postnet.postnet_layer_4.1.running_var", "postnet.postnet_layer_5.0.conv.weight", "postnet.postnet_layer_5.0.conv.bias", "postnet.postnet_layer_5.1.weight", "postnet.postnet_layer_5.1.bias", "postnet.postnet_layer_5.1.running_mean", "postnet.postnet_layer_5.1.running_var".
Unexpected key(s) in state_dict: "model", "opt".
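When the model was saved, the model parameters and the optimizer parameters were saved together, so the checkpoint file holds a dictionary of dictionaries: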
# Save the model (model and optimizer state in one checkpoint dict)
torch.save({'model': m_model.state_dict(),
            'opt': m_optimizer.state_dict()},
           'save2/model_final.pick')
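For inference, however, only the model parameters are needed. These are the inner dictionary stored under 'model' (the original m_model.state_dict()). The loading code below passes the whole loaded checkpoint straight to load_state_dict, which raises the error above. The outer dictionary contains only the keys "model" and "opt", so every parameter the model expects is reported as missing and those two keys are reported as unexpected.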
# Instantiate the model
m_model = Tacotron2(para)
# Fails: torch.load returns the outer {'model': ..., 'opt': ...} dict
m_model.load_state_dict(torch.load('save2/124500/model_final.pick'))
m_model.eval()
m_model.to(device)
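Both lists appear because, with the default strict=True, load_state_dict checks the keys the model expects against the keys of the dict it is given. The sketch below mimics that set comparison in plain Python. compare_keys is a hypothetical helper, not PyTorch's implementation. The two key names stand in for the full Tacotron2 list, and the "..." strings stand in for tensors.

```python
from typing import Iterable, List, Mapping, Tuple


def compare_keys(expected: Iterable[str],
                 provided: Mapping[str, object]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys, mimicking what a strict
    load_state_dict reports. Illustration only, not PyTorch's code."""
    expected_set = set(expected)
    provided_set = set(provided)  # iterating a mapping yields its keys
    missing = sorted(expected_set - provided_set)     # model needs these, dict lacks them
    unexpected = sorted(provided_set - expected_set)  # dict has these, model does not
    return missing, unexpected


# Hypothetical, heavily shortened stand-ins for Tacotron2's parameter names.
model_keys = ["embedding.weight", "encoder.lstm.weight_ih_l0"]

# Shape of what torch.load returns for the checkpoint saved above:
# an outer dict whose values are the real state dicts.
checkpoint = {
    "model": {"embedding.weight": "...", "encoder.lstm.weight_ih_l0": "..."},
    "opt": {"state": {}, "param_groups": []},
}

print(compare_keys(model_keys, checkpoint))
# (['embedding.weight', 'encoder.lstm.weight_ih_l0'], ['model', 'opt'])
print(compare_keys(model_keys, checkpoint["model"]))
# ([], [])
```

Passing the outer dict reproduces the pattern in the error message: every parameter is missing, and only "model" and "opt" are unexpected.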
The fix is to take the 'model' entry out of the loaded checkpoint before calling load_state_dict. Change the loading code to:
# Instantiate the model
m_model = Tacotron2(para)
# Load the whole checkpoint, then pass only the model's state_dict
m_dict = torch.load('save2/124500/model_final.pick')
m_model.load_state_dict(m_dict['model'])
m_model.eval()
m_model.to(device)
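You may need to load checkpoints saved in either layout: a bare state_dict or the wrapped {'model': ..., 'opt': ...} dict. A small helper can accept both. This is a minimal sketch. unwrap_state_dict is a hypothetical name, and the helper assumes the wrapped layout stores the model under the key 'model', as in the saving code at the top.

```python
from collections.abc import Mapping


def unwrap_state_dict(checkpoint: Mapping, key: str = "model") -> Mapping:
    """Return the model's parameter dict from a loaded checkpoint (hypothetical helper).

    Assumed layouts:
      * wrapped, e.g. {'model': state_dict, 'opt': optimizer_state} -> checkpoint[key]
      * bare state_dict (parameter name -> tensor)                  -> returned unchanged
    """
    inner = checkpoint.get(key)
    # Only unwrap when the value under `key` is itself a dict-like state_dict;
    # a bare state_dict whose parameter happens to be named `key` maps to a tensor.
    if isinstance(inner, Mapping):
        return inner
    return checkpoint
```

With PyTorch this would be called as `m_model.load_state_dict(unwrap_state_dict(torch.load(path, map_location=device)))`. To resume training rather than run inference, restore the optimizer state saved under 'opt' as well, with `m_optimizer.load_state_dict(checkpoint['opt'])`.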
Copyright © 2003-2013 www.wpsshop.cn. All rights reserved.