Adapting the Encoder-Decoder architecture from NLP to time-series forecasting; experiments found that adding an attention mechanism improves the prediction quality. The decoder is fed the observed value at step t-1 to predict step t, and the attention context computed over all encoder outputs replaces the encoder's final hidden state as the decoder's initial hidden state.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.rnn = nn.LSTM(
            input_size=INPUT_SIZE,
            hidden_size=HIDDEN_SIZE,
            num_layers=1,
            batch_first=True,
        )

    def forward(self, x):
        # r_out: [batch, seq_len, HIDDEN_SIZE]; hidden, cell: [1, batch, HIDDEN_SIZE]
        r_out, (hidden, cell) = self.rnn(x)
        return r_out, hidden, cell


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.rnn = nn.LSTM(
            input_size=INPUT_SIZE,
            hidden_size=HIDDEN_SIZE,
            num_layers=1,
            batch_first=True,
        )
        self.out = nn.Linear(HIDDEN_SIZE, 1)

    def forward(self, x, hidden, cell):
        # x: [batch, 1, INPUT_SIZE] -- one time step per call
        output, (hidden, cell) = self.rnn(x, (hidden, cell))
        prediction = self.out(output)  # [batch, 1, 1]
        return prediction, hidden, cell


class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def attention_net(self, lstm_output, final_state):
        # final_state: [1, batch, HIDDEN_SIZE] -> hidden: [batch, HIDDEN_SIZE, 1]
        hidden = final_state.view(-1, HIDDEN_SIZE, 1)
        # Dot-product score between each encoder step and the final hidden state
        attn_weights = torch.bmm(lstm_output, hidden).squeeze(2)  # [batch, seq_len]
        soft_attn_weights = F.softmax(attn_weights, 1)
        # Weighted sum of encoder outputs -> context: [batch, HIDDEN_SIZE]
        context = torch.bmm(lstm_output.transpose(1, 2),
                            soft_attn_weights.unsqueeze(2)).squeeze(2)
        return context, soft_attn_weights.detach().cpu().numpy()

    def forward(self, src):
        # batch_first=True, so src: [batch, seq_len, INPUT_SIZE]
        batch_size = src.shape[0]
        src_len = src.shape[1]
        outputs = torch.zeros(batch_size, src_len, 1).to(self.device).double()

        r_out, hidden, cell = self.encoder(src)
        # Use the attention context over all encoder outputs as the decoder's
        # initial hidden state instead of the raw final hidden state
        attn_output, attention = self.attention_net(r_out, hidden)
        hidden = attn_output.view(1, -1, HIDDEN_SIZE)

        # Feed the observed value at step t-1 to predict step t
        for t in range(1, src_len):
            dec_input = src[:, t - 1, :].unsqueeze(1)  # [batch, 1, INPUT_SIZE]
            output, hidden, cell = self.decoder(dec_input, hidden, cell)
            outputs[:, t - 1, :] = output.squeeze(1)
        return outputs
```
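The attention used here is plain dot-product (Luong-style) attention over the encoder outputs. As a sketch of what `attention_net` computes, with $h_t$ the encoder output at step $t$ and $h_N$ the encoder's final hidden state:

$$
e_t = h_t^\top h_N, \qquad
\alpha_t = \frac{\exp(e_t)}{\sum_k \exp(e_k)}, \qquad
c = \sum_t \alpha_t h_t
$$

The context vector $c$ is then reshaped to `[1, batch, HIDDEN_SIZE]` and handed to the decoder as its initial hidden state.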
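As a usage sketch, the snippet below wires the three modules together and runs a few training steps on random data. The hyperparameter values (`INPUT_SIZE`, `HIDDEN_SIZE`), the tensor shapes, and the optimizer settings are all illustrative assumptions, not values from the original experiment.

```python
# Illustrative hyperparameters -- assumed values, not from the original post
INPUT_SIZE = 1    # univariate series: one feature per time step
HIDDEN_SIZE = 64
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Seq2Seq(Encoder(), Decoder(), DEVICE).double().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Dummy data: 32 series of 24 steps each (assumed shapes for illustration)
src = torch.randn(32, 24, INPUT_SIZE, dtype=torch.float64, device=DEVICE)
# Target is the input shifted one step left, matching the t-1 -> t scheme;
# its last position stays zero, as does the last slot of the model's outputs
target = torch.zeros(32, 24, 1, dtype=torch.float64, device=DEVICE)
target[:, :-1, :] = src[:, 1:, :]

for epoch in range(10):
    optimizer.zero_grad()
    loss = criterion(model(src), target)
    loss.backward()
    optimizer.step()
    print(f"epoch {epoch}: loss = {loss.item():.4f}")
```

Feeding the ground-truth value at step t-1 during training is a teacher-forcing scheme; at inference time the decoder's own prediction from the previous step would be fed back instead.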