当前位置:   article > 正文

Attention + PyTorch 时间序列数据预测：PyTorch 中的 LSTM + Self-Attention

PyTorch 中的 LSTM + Self-Attention
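将原本用于 NLP 的 Encoder-Decoder 结构修改后用于时间序列数据预测；实验发现，添加注意力机制后预测效果有所提升。（以下代码假定已执行 `import torch`、`import torch.nn as nn`、`import torch.nn.functional as F`，并已定义常量 `INPUT_SIZE` 与 `HIDDEN_SIZE`；输入数据采用 batch-first 形状 `(batch, seq_len, features)`。）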
class Encoder(nn.Module):
    """Single-layer, batch-first LSTM encoder.

    Args:
        input_size: Number of features per time step. When omitted, the
            module-level ``INPUT_SIZE`` is used (original behaviour).
        hidden_size: LSTM hidden width. When omitted, the module-level
            ``HIDDEN_SIZE`` is used (original behaviour).
    """

    def __init__(self, input_size=None, hidden_size=None):
        super().__init__()
        # Fall back to the globals lazily so that existing ``Encoder()`` calls
        # behave exactly as before, while explicit sizes remove that dependency.
        if input_size is None:
            input_size = INPUT_SIZE
        if hidden_size is None:
            hidden_size = HIDDEN_SIZE
        self.rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )

    def forward(self, x):
        """Encode a batch of sequences.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``; the LSTM is
                built with ``batch_first=True``.

        Returns:
            Tuple ``(r_out, hidden, cell)``: ``r_out`` holds the per-step
            outputs ``(batch, seq_len, hidden_size)``; ``hidden`` and ``cell``
            are the final states ``(1, batch, hidden_size)`` (single layer).
        """
        r_out, (hidden, cell) = self.rnn(x)
        return r_out, hidden, cell

class Decoder(nn.Module):
    """Single-layer, batch-first LSTM decoder with a linear read-out to one value.

    Args:
        input_size: Number of features per input step. When omitted, the
            module-level ``INPUT_SIZE`` is used (original behaviour).
        hidden_size: LSTM hidden width. When omitted, the module-level
            ``HIDDEN_SIZE`` is used (original behaviour).
    """

    def __init__(self, input_size=None, hidden_size=None):
        super().__init__()
        # Resolve the globals lazily so existing ``Decoder()`` calls are unchanged.
        if input_size is None:
            input_size = INPUT_SIZE
        if hidden_size is None:
            hidden_size = HIDDEN_SIZE
        self.rnn = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        self.out = nn.Linear(hidden_size, 1)

    def forward(self, x, hidden, cell):
        """Run the LSTM on ``x`` from the given state and project to one value.

        Args:
            x: Batch-first input ``(batch, steps, input_size)``.
            hidden: Initial hidden state ``(1, batch, hidden_size)``.
            cell: Initial cell state ``(1, batch, hidden_size)``.

        Returns:
            Tuple ``(prediction, hidden, cell)`` with the updated LSTM states.
            ``prediction`` is ``(batch, steps, 1)``, except that when
            ``batch == 1`` the batch axis is dropped (see note below).
        """
        output, (hidden, cell) = self.rnn(x, (hidden, cell))
        # NOTE(review): ``squeeze(0)`` looks like a leftover from a seq-first
        # layout; with batch_first=True it only removes the batch axis when
        # batch == 1. Kept so the returned shape stays what callers expect.
        prediction = self.out(output.squeeze(0))
        return prediction, hidden, cell


class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that seeds the decoder with an attention context.

    Args:
        encoder: Module whose forward takes ``src`` and returns
            ``(r_out, hidden, cell)``.
        decoder: Module called as ``decoder(x, hidden, cell)`` that returns
            ``(prediction, hidden, cell)``.
        device: Device on which the output buffer is allocated.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def attention_net(self, lstm_output, final_state):
        """Dot-product attention of the encoder outputs against a final state.

        Args:
            lstm_output: Encoder outputs ``(batch, seq_len, hidden_size)``.
            final_state: Final hidden state; it is reshaped to
                ``(batch, hidden_size, 1)``, so it must contain exactly
                ``batch * hidden_size`` elements (e.g. a single-layer
                ``(1, batch, hidden_size)`` LSTM state).

        Returns:
            Tuple ``(context, weights)``: ``context`` is a
            ``(batch, hidden_size)`` tensor (attention-weighted sum of the
            encoder outputs); ``weights`` is a ``(batch, seq_len)`` NumPy array
            of the softmax attention weights.
        """
        # Derive the width from the data instead of a global constant.
        hidden_size = lstm_output.size(2)
        query = final_state.reshape(-1, hidden_size, 1)      # (batch, hidden, 1)
        scores = torch.bmm(lstm_output, query).squeeze(2)   # (batch, seq_len)
        weights = F.softmax(scores, dim=1)
        context = torch.bmm(
            lstm_output.transpose(1, 2), weights.unsqueeze(2)
        ).squeeze(2)                                         # (batch, hidden)
        # detach().cpu() so the conversion also works for CUDA tensors;
        # ``.numpy()`` on a GPU tensor raises.
        return context, weights.detach().cpu().numpy()

    def forward(self, src):
        """Encode ``src``, apply attention, then decode step by step.

        Args:
            src: Batch-first input ``(batch, seq_len, features)``.

        Returns:
            Float64 tensor ``(batch, seq_len, 1)`` on ``self.device``.
            Position ``k`` holds the decoder prediction made from input step
            ``k``, for ``k < seq_len - 1``.
        """
        # With batch-first input, dim 0 is the batch and dim 1 the sequence.
        batch_size, seq_len = src.shape[0], src.shape[1]
        # NOTE(review): float64 buffer; presumably the model runs in double
        # precision (e.g. via ``.double()``). Confirm against the training code.
        outputs = torch.zeros(
            batch_size, seq_len, 1, device=self.device, dtype=torch.double
        )

        r_out, hidden, cell = self.encoder(src)
        # The attention context (reshaped to (1, batch, hidden)) replaces the
        # encoder's final hidden state as the decoder's initial hidden state.
        context, _ = self.attention_net(r_out, hidden)
        hidden = context.unsqueeze(0)

        # NOTE(review): this runs seq_len - 1 steps, so outputs[:, -1, :] is
        # never written and stays zero. Confirm callers expect that.
        for t in range(1, seq_len):
            step_input = src[:, t - 1, :].unsqueeze(1)  # (batch, 1, features)
            prediction, hidden, cell = self.decoder(step_input, hidden, cell)
            outputs[:, t - 1, :] = prediction.squeeze(1)
        return outputs

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/煮酒与君饮/article/detail/777293
推荐阅读
相关标签
  

闽ICP备14008679号