当前位置:   article > 正文

超全总结!Pythorch 构建Attention-lstm时序模型 !!

序列lstm attention pytorch

核心点:使用PyTorch框架构建一个基于Attention机制的 LSTM 模型来处理时序数据。

废话不多说,首先呢,今天和大家聊一个小案例:使用PyTorch构建Attention-LSTM时序模型!!

时序数据分析在预测未来事件、检测异常、识别模式等领域中广泛应用。

因此,下面将详细介绍如何使用PyTorch框架构建一个基于Attention机制的LSTM(长短期记忆网络)模型来处理时序数据。

原理阐述

LSTM网络

咱们先聊基础,关于LSTM在此前夜讲过很多了。大家可以翻回去看看~

LSTM是一种特殊的RNN(循环神经网络),适用于处理和预测基于时间的数据。它通过三个门(输入门、遗忘门、输出门)来控制信息的流动,从而能够学习长期依赖关系。

LSTM的核心公式如下:

  • 遗忘门:决定当前时刻  遗忘多少先前状态信息。

  • 输入门:决定当前时刻  添加多少新信息。

  • 候选记忆单元:生成新的候选信息。

  • 当前记忆单元:综合遗忘信息和新输入信息。

  • 输出门:决定输出多少信息。

  • 最终输出

其中, 是sigmoid函数, 是逐元素乘积。

Attention机制

Attention机制在处理长序列时特别有用,因为它可以帮助模型关注序列中的重要部分。它通过计算不同时间步的加权和来突出不同的输入数据对输出的重要性。

Attention机制的核心公式:

  • 计算注意力权重

其中,score函数可以是点积、双线性或其他相似性测量方法。

  • 归一化权重

  • 加权和

最终,Attention机制的输出是输入时间步的加权和,能使模型更有效地关注重要的信息。

模型训练

为了演示Attention-LSTM在时序预测中的应用,我们使用一个模拟的时序数据集进行预测。

Pytorch代码实现

下面是使用PyTorch构建Attention-LSTM模型的代码示例。我们使用一个简单的正弦波数据集来说明。

  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. # 创建模拟数据
  6. def create_sin_wave(seq_len, n_samples):
  7.     x = np.linspace(050, n_samples)
  8.     data = np.sin(x)
  9.     return data
  10. # Attention机制实现
  11. class Attention(nn.Module):
  12.     def __init__(self, hidden_dim):
  13.         super(Attention, self).__init__()
  14.         self.hidden_dim = hidden_dim
  15.         self.attn = nn.Linear(hidden_dim, hidden_dim)
  16.         self.context = nn.Linear(hidden_dim, 1, bias=False)
  17.     def forward(self, hidden_states):
  18.         attn_weights = torch.tanh(self.attn(hidden_states))
  19.         attn_weights = self.context(attn_weights).squeeze(2)
  20.         attn_weights = torch.softmax(attn_weights, dim=1)
  21.         context_vector = torch.sum(attn_weights.unsqueeze(2) * hidden_states, dim=1)
  22.         return context_vector, attn_weights
  23. # LSTM模型实现
  24. class AttentionLSTM(nn.Module):
  25.     def __init__(self, input_dim, hidden_dim, output_dim, n_layers):
  26.         super(AttentionLSTM, self).__init__()
  27.         self.hidden_dim = hidden_dim
  28.         self.n_layers = n_layers
  29.         self.lstm = nn.LSTM(input_dim, hidden_dim, n_layers, batch_first=True)
  30.         self.attention = Attention(hidden_dim)
  31.         self.fc = nn.Linear(hidden_dim, output_dim)
  32.     def forward(self, x):
  33.         h_0 = torch.zeros(self.n_layers, x.size(0), self.hidden_dim).to(x.device)
  34.         c_0 = torch.zeros(self.n_layers, x.size(0), self.hidden_dim).to(x.device)
  35.         out, _ = self.lstm(x, (h_0, c_0))
  36.         context_vector, attn_weights = self.attention(out)
  37.         out = self.fc(context_vector)
  38.         return out, attn_weights
  39. # 生成数据集
  40. seq_len = 20
  41. n_samples = 1000
  42. data = create_sin_wave(seq_len, n_samples)
  43. data = torch.tensor(data, dtype=torch.float32).unsqueeze(1)
  44. # 准备训练集和测试集
  45. def create_inout_sequences(data, seq_len):
  46.     inout_seq = []
  47.     L = len(data)
  48.     for i in range(L-seq_len):
  49.         train_seq = data[i:i+seq_len]
  50.         train_label = data[i+seq_len:i+seq_len+1]
  51.         inout_seq.append((train_seq, train_label))
  52.     return inout_seq
  53. train_seq = create_inout_sequences(data, seq_len)
  54. train_X = torch.stack([s[0for s in train_seq])
  55. train_Y = torch.stack([s[1for s in train_seq])
  56. # 训练模型
  57. input_dim = 1
  58. hidden_dim = 64
  59. output_dim = 1
  60. n_layers = 2
  61. n_epochs = 100
  62. learning_rate = 0.001
  63. model = AttentionLSTM(input_dim, hidden_dim, output_dim, n_layers)
  64. criterion = nn.MSELoss()
  65. optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  66. model.train()
  67. for epoch in range(n_epochs):
  68.     optimizer.zero_grad()
  69.     output, attn_weights = model(train_X)
  70.     loss = criterion(output, train_Y)
  71.     loss.backward()
  72.     optimizer.step()
  73.     if (epoch+1) % 10 == 0:
  74.         print(f'Epoch [{epoch+1}/{n_epochs}], Loss: {loss.item():.4f}')
  75. # 可视化结果
  76. model.eval()
  77. with torch.no_grad():
  78.     pred, attn_weights = model(train_X)
  79. # 绘制实际值与预测值
  80. plt.figure(figsize=(147))
  81. plt.plot(data.numpy(), label='True Data')
  82. plt.plot(range(seq_len, seq_len + len(pred)), pred.numpy(), label='Predicted Data')
  83. plt.xlabel('Time step')
  84. plt.ylabel('Value')
  85. plt.title('Attention-LSTM: True vs Predicted')
  86. plt.legend()
  87. plt.show()

509a43fb490e51f1ba2ce9d58443cb14.jpeg

  1. # 绘制注意力权重
  2. attn_weights = attn_weights.numpy()
  3. plt.figure(figsize=(147))
  4. plt.imshow(attn_weights.T, aspect='auto', cmap='viridis')
  5. plt.colorbar()
  6. plt.xlabel('Time step')
  7. plt.ylabel('Attention Weights')
  8. plt.title('Attention Weights Distribution')
  9. plt.show()

6a9f8c2c1ac8e1d8c110455e3443d185.png

代码说明

  1. 数据生成:我们创建了一个正弦波数据集,用于模拟时序数据。

  2. Attention机制实现:定义了一个Attention类,用于计算注意力权重和上下文向量。

  3. LSTM模型实现:定义了一个AttentionLSTM类,该类包含LSTM层、Attention层和全连接层。

  4. 训练模型:通过梯度下降训练模型。

  5. 可视化结果:绘制预测值和实际值的对比图,以及注意力权重的分布图。

Attention机制能够增强LSTM模型的性能,使其在处理长时间依赖关系时更加有效。通过引入Attention机制,模型能够自动关注时序数据中重要的时间步,从而提高预测的准确性。

这种Attention-LSTM模型在金融预测、气象分析、医疗诊断等领域都有广泛的应用潜力。大家在未来的研究可以探索不同类型的注意力机制以及其在不同应用场景中的效果。

最后

喜欢本文的朋友可以收藏、点赞、转发起来!

 
 
 
 

28950e96bb0e10b1baf6b2c1f329d77b.jpeg

 
 
 
 
  1. 往期精彩回顾
  2. 适合初学者入门人工智能的路线及资料下载(图文+视频)机器学习入门系列下载机器学习及深度学习笔记等资料打印《统计学习方法》的代码复现专辑
  • 交流群

欢迎加入机器学习爱好者微信群一起和同行交流,目前有机器学习交流群、博士群、博士申报交流、CV、NLP等微信群,请扫描下面的微信号加群,备注:”昵称-学校/公司-研究方向“,例如:”张小明-浙大-CV“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~(也可以加入机器学习交流qq群772479961)

22413538ccfc7d9479a5b4e092b29838.png

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/煮酒与君饮/article/detail/978295
推荐阅读
相关标签
  

闽ICP备14008679号