当前位置:   article > 正文

GRU解决预测分类问题(多变量预测多步)_gru分类问题

gru分类问题

解决问题的背景:现有五个属性列,前四个属性列作为特征输入,第五个属性列作为标签值,第五个属性列的意义是类别;先需要通过前50步的数据特征预测后10步的类别(即:51-60步)。

1.直接多输出的方式:直接多输出的方式就是在神经网络的最后加上几个(对应的是需要预测步长是几步,这里是10)一样的全连接神经网络,在这一层之后进行对每个全连接神经网络输出的值的拼接得到一个10步长的结果,用于后面计算损失进行训练。

简单的网络结构如下图:

模型网络的代码如下:

  1. # GRU
  2. class GRURNN(torch.nn.Module):
  3. def __init__(self, input_size, hidden_size, num_layers):
  4. super().__init__()
  5. self.input_size = input_size
  6. self.hidden_size = hidden_size
  7. self.num_layers = num_layers
  8. self.gru = torch.nn.GRU(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
  9. self.fc1 = torch.nn.Linear(self.hidden_size, 4)
  10. self.fc2 = torch.nn.Linear(self.hidden_size, 4)
  11. self.fc3 = torch.nn.Linear(self.hidden_size, 4)
  12. self.fc4 = torch.nn.Linear(self.hidden_size, 4)
  13. self.fc5 = torch.nn.Linear(self.hidden_size, 4)
  14. self.fc6 = torch.nn.Linear(self.hidden_size, 4)
  15. self.fc7 = torch.nn.Linear(self.hidden_size, 4)
  16. self.fc8 = torch.nn.Linear(self.hidden_size, 4)
  17. self.fc9 = torch.nn.Linear(self.hidden_size, 4)
  18. self.fc10 = torch.nn.Linear(self.hidden_size, 4)
  19. self.softmax = torch.nn.Softmax(dim=1)
  20. def forward(self, input_seq):
  21. batch_size = input_seq.shape[0]
  22. h_0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
  23. output, _ = self.gru(input_seq,h_0)
  24. pred1 = self.fc1(output)
  25. pred2 = self.fc2(output)
  26. pred3 = self.fc3(output)
  27. pred4 = self.fc4(output)
  28. pred5 = self.fc5(output)
  29. pred6 = self.fc6(output)
  30. pred7 = self.fc7(output)
  31. pred8 = self.fc8(output)
  32. pred9 = self.fc9(output)
  33. pred10 = self.fc10(output)
  34. pred1, pred2, pred3, pred4, pred5, pred6, pred7, pred8, pred9, pred10 = pred1[:, -1, :], pred2[:, -1, :], pred3[:, -1, :], pred4[:, -1, :], pred5[:, -1, :], pred6[:, -1, :], pred7[:, -1, :], pred8[:, -1, :], pred9[:, -1, :], pred10[:, -1, :]
  35. pred1, pred2, pred3, pred4, pred5, pred6, pred7, pred8, pred9, pred10 = self.softmax(pred1), self.softmax(pred2), self.softmax(pred3), self.softmax(pred4), self.softmax(pred5), self.softmax(pred6), self.softmax(pred7), self.softmax(pred8), self.softmax(pred9), self.softmax(pred10)
  36. pred = torch.stack([pred1, pred2, pred3, pred4, pred5, pred6, pred7, pred8, pred9, pred10], dim=1)
  37. return pred

2.滚动数据集输出的方式:滚动数据集的方式就是单步预测的一个整合的版本,具体就是先用前50步预测第51步然后用2-51步作为50步的值进行下一次的输入预测第52步,以此类推;这里后面预测完加入到输入数据中的新值可以就是刚刚预测出来的新值,也可以是数据标签集值的对应到这一步的值。滚动预测的效果会比直接多输出的方式的效果好,但是时间是较长的,对于需要一个较好性能模型的需求来说,时间久一点不是什么问题。

简单的预测步骤如下图(简单表示:5步预测3步):

 模型网络的代码如下:

  1. # GRU
  2. class GRURNN(torch.nn.Module):
  3. def __init__(self, input_size, hidden_size, num_layers):
  4. super().__init__()
  5. self.input_size = input_size
  6. self.hidden_size = hidden_size
  7. self.num_layers = num_layers
  8. self.gru = torch.nn.GRU(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
  9. self.mlp = torch.nn.Sequential(
  10. torch.nn.Linear(self.hidden_size, 32),
  11. torch.nn.LeakyReLU(),
  12. torch.nn.Linear(32, 16),
  13. torch.nn.LeakyReLU(),
  14. torch.nn.Linear(16, 4)
  15. )
  16. self.softmax = torch.nn.Softmax(dim=1)
  17. def forward(self, input_seq):
  18. batch_size = input_seq.shape[0]
  19. h_0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device)
  20. output, _ = self.gru(input_seq,h_0)
  21. output = output[:, -1, :]
  22. pred = self.mlp(output)
  23. pred = self.softmax(pred)
  24. return pred
  25. # 直接单步滚动,预测未来多步的预测
  26. class GRURNN_PRO_MORE(torch.nn.Module):
  27. def __init__(self,gru,device):
  28. super(GRURNN_PRO_MORE, self).__init__()
  29. self.gru = gru
  30. self.device = device
  31. def forward(self, src, trg):
  32. batch_size = src.shape[0]
  33. src_len = src.shape[1]
  34. trg_len = trg.shape[1]
  35. output_size = 4
  36. outputs = torch.zeros(batch_size, trg_len, output_size).to(self.device)
  37. for i in range(trg_len):
  38. src = src.float()
  39. output = self.gru(src)
  40. outputs[:, i, :] = output
  41. trg_input = trg[:, i, :4].reshape([batch_size, 1, output_size])
  42. src = torch.cat((src[:, 1:, :], trg_input),dim=1)
  43. return outputs

数据处理:对于数据的处理用到了常用的一些库,像pandas,numpy等。

作者处于学习阶段,如有错误,欢迎批评指正。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/101501?site
推荐阅读
相关标签
  

闽ICP备14008679号