当前位置:   article > 正文

在pt文件转换为onnx过程中:TypeError: forward() missing 2 required positional argument

typeerror: forward() missing 2 required positional arguments: 'i' and 'd

网络结构中含有参数(h,c)

class Q_net(nn.Module):
    def __init__(self, state_space=None,
                 action_space=None):
        super(Q_net, self).__init__()

        # space size check
        assert state_space is not None, "None state_space input: state_space should be selected."
        assert action_space is not None, "None action_space input: action_space should be selected."

        self.hidden_space = 64
        self.state_space = state_space
        self.action_space = action_space

        self.Linear1 = nn.Linear(self.state_space, self.hidden_space)
        self.lstm    = nn.LSTM(self.hidden_space,self.hidden_space, batch_first=True)
        self.Linear2 = nn.Linear(self.hidden_space, self.action_space)

    def forward(self, x, h, c):
        x = F.relu(self.Linear1(x))
        x, (new_h, new_c) = self.lstm(x,(h,c))
        x = self.Linear2(x)
        return x, new_h, new_c
    def init_hidden_state(self, batch_size, training=None):

        assert training is not None, "training step parameter should be dtermined"

        if training is True:
            return torch.zeros([1, batch_size, self.hidden_space]), torch.zeros([1, batch_size, self.hidden_space])
        else:
            return torch.zeros([1, 1, self.hidden_space]), torch.zeros([1, 1, self.hidden_space])

//参数h、c
 h, c = q_net.init_hidden_state(batch_size=batch_size, training=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33

.pth转.onnx报错: TypeError: forward() missing 2 required positional argument, 报错代码如下:

dummy_input = torch.randn(64,2) 
checkpoing = torch.load('./DRQN_POMDP_Random_SEED_1.pth', 'cpu')
model1.load_state_dict(checkpoing)
torch.onnx.export(checkpoing, dummy_input, "model_best.onnx", export_params=True, verbose=True)  # 将模型保存成.onnx格
  • 1
  • 2
  • 3
  • 4

最后的解决办法:将h、c代入模型中

dummy_input = torch.randn(1,64,2)  # 要求输入3维的矩阵, why?
model1 = Q
h, c = model1.init_hidden_state(batch_size=batch_size, training=False)
checkpoing = torch.load('./DRQN_POMDP_Random_SEED_1.pth', 'cpu')  # 导入模型参数
model1.load_state_dict(checkpoing)  # 将模型参数赋予自定义的模型
torch.onnx.export(model1, (dummy_input,h,c), "model_best.onnx", export_params=True, verbose=True)  # 将模型保存成.onnx格
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/正经夜光杯/article/detail/850123
推荐阅读
相关标签
  

闽ICP备14008679号