赞
踩
网络的 forward 除输入 x 外，还需要传入 LSTM 的隐藏状态参数 (h, c)：
class Q_net(nn.Module):
    """Recurrent Q-network: Linear -> ReLU -> single-layer LSTM -> Linear.

    The LSTM is built with ``batch_first=True`` and carries an explicit
    hidden state ``(h, c)`` that callers must pass to ``forward``.
    """

    def __init__(self, state_space=None, action_space=None):
        """Build the network.

        Args:
            state_space: size of each input feature vector (required).
            action_space: number of outputs produced by ``Linear2`` (required).

        Raises:
            AssertionError: if either size is None.
        """
        super(Q_net, self).__init__()
        # space size check
        assert state_space is not None, "None state_space input: state_space should be selected."
        assert action_space is not None, "None action_space input: action_space should be selected."
        self.hidden_space = 64
        self.state_space = state_space
        self.action_space = action_space
        self.Linear1 = nn.Linear(self.state_space, self.hidden_space)
        self.lstm = nn.LSTM(self.hidden_space, self.hidden_space, batch_first=True)
        self.Linear2 = nn.Linear(self.hidden_space, self.action_space)

    def forward(self, x, h, c):
        """Run one pass and return ``(output, new_h, new_c)``.

        Args:
            x: input features. With ``batch_first=True`` and 3-D ``h``/``c``
                this is ``(batch, seq_len, state_space)``.
            h, c: LSTM hidden and cell state, e.g. from ``init_hidden_state``
                (shape ``(1, batch, hidden_space)``).

        Returns:
            The ``Linear2`` output (last dim ``action_space``) and the updated
            LSTM hidden and cell states.
        """
        x = F.relu(self.Linear1(x))
        x, (new_h, new_c) = self.lstm(x, (h, c))
        x = self.Linear2(x)
        return x, new_h, new_c

    def init_hidden_state(self, batch_size, training=None):
        """Return zero-initialised ``(h, c)`` for the LSTM.

        Args:
            batch_size: batch size, used only when ``training is True``.
            training: must be given explicitly. ``True`` gives shape
                ``(1, batch_size, hidden_space)``; any other value gives
                ``(1, 1, hidden_space)`` and ``batch_size`` is ignored.

        Returns:
            Tuple ``(h, c)`` of zero tensors on the same device and with the
            same dtype as the model's parameters.

        Raises:
            AssertionError: if ``training`` is None.
        """
        assert training is not None, "training step parameter should be determined"
        # Only an explicit True selects the batched shape (original semantics).
        batch = batch_size if training is True else 1
        # Match the parameters' device/dtype: plain torch.zeros is always
        # float32 on CPU and would not match a model moved with .to()/.double().
        ref = next(self.parameters())
        shape = (1, batch, self.hidden_space)
        h = torch.zeros(shape, device=ref.device, dtype=ref.dtype)
        c = torch.zeros(shape, device=ref.device, dtype=ref.dtype)
        return h, c


# Usage: create the hidden-state parameters h and c that forward() needs.
# h, c = q_net.init_hidden_state(batch_size=batch_size, training=True)
将 .pth 模型转换为 .onnx 时报错：TypeError: forward() missing 2 required positional arguments: 'h' and 'c'。报错代码如下：
# Failing export attempt that triggers the error described above.
# NOTE(review): forward() is defined as forward(x, h, c), but only dummy_input
# is passed to torch.onnx.export below, so h and c are missing.
# NOTE(review): dummy_input is 2-D here; with the 3-D hidden states returned by
# init_hidden_state the LSTM needs a 3-D input.
dummy_input = torch.randn(64,2)
checkpoing = torch.load('./DRQN_POMDP_Random_SEED_1.pth', 'cpu')
model1.load_state_dict(checkpoing)
# NOTE(review): the first argument below is the loaded checkpoint (the object
# given to load_state_dict above), not the module model1; torch.onnx.export
# expects the nn.Module. Presumably the reported TypeError came from exporting
# model1 with only dummy_input — confirm.
torch.onnx.export(checkpoing, dummy_input, "model_best.onnx", export_params=True, verbose=True) # save the model in .onnx format
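最终的解决办法：把 h、c 和输入一起传给模型。forward(x, h, c) 有三个位置参数，因此 torch.onnx.export 的 args 必须是包含全部三个输入的元组 (dummy_input, h, c)；只传 dummy_input 就会出现上面的 TypeError。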
import torch

# Load the trained parameters first so the layer sizes can be read from them.
checkpoint = torch.load('./DRQN_POMDP_Random_SEED_1.pth', map_location='cpu')

# Q_net stores Linear1 (state_space -> hidden) and Linear2 (hidden -> action_space).
# nn.Linear weights are (out_features, in_features), so both sizes can be
# recovered from the checkpoint instead of being hard-coded.
state_space = checkpoint['Linear1.weight'].shape[1]
action_space = checkpoint['Linear2.weight'].shape[0]

model1 = Q_net(state_space=state_space, action_space=action_space)
model1.load_state_dict(checkpoint)  # copy the trained parameters into the model
model1.eval()  # export the inference-mode graph

# forward() takes (x, h, c). With training=False, init_hidden_state ignores
# batch_size and returns h, c of shape (1, 1, hidden_space).
h, c = model1.init_hidden_state(batch_size=1, training=False)

# Why must the input be 3-D? h and c are 3-D (num_layers, batch, hidden), so
# nn.LSTM treats the call as batched and requires a 3-D input. With
# batch_first=True that is (batch, seq_len, state_space). A 2-D input would be
# read as unbatched and would need 2-D h/c. Batch must be 1 to match h and c.
seq_len = 64
dummy_input = torch.randn(1, seq_len, state_space)

# forward() has three positional inputs, so pass all of them as a tuple.
torch.onnx.export(model1, (dummy_input, h, c), "model_best.onnx", export_params=True, verbose=True)  # save the model in .onnx format
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。