当前位置:   article > 正文

pytorch torch.nn.LSTM

torch.nn.lstm

应用

>>> rnn = nn.LSTM(10, 20, 2)
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))
  • 1
  • 2
  • 3
  • 4
  • 5

概念

RNN可以看成是一个普通的网络(比如CNN)在时间线上做了多次复制,时间线上的每个网络都会向后续的网络传递信息。
在这里插入图片描述
传统的RNN神经网络无法记忆较长的序列,无法解决长期依赖“long-term dependencies”。
LSTM(Long Short Term Memory networks)可以解决此问题

下图是RNN的网络图, t t t时刻会将前一个神经网络的输出 h t − 1 h_{t-1} ht1和本时刻的 X t X_t Xt作为输入,经过 t a n h tanh tanh激活作为输出
在这里插入图片描述

LSTM也有类似的结构,不过在每个时刻的层会有不相同。

在这里插入图片描述
在这里插入图片描述

LSTM关键在于cell state,就是水平穿图片顶部的黑线。LSTM中每个时刻的神经网络都可以向cell state中添加或删除信息。
Gates来控制信息的通过。Gates通过sigmoid网络层获得然后和cell state进行点乘。
sigmoid输出是0~1,表示有多少信息可以通过。0代表没有信息可以通过,1代表所有信息可以通过。
LSTM总共有三个Gates
在这里插入图片描述
第一步是“forget gate layer” 遗忘门 f t f_t ft,主要用来决定cell state要丢掉哪些信息。
它的输入是 h t − 1 h_{t-1} ht1 x t x_t xt,输出是0~1.然后和 C t − 1 C_{t-1} Ct1进行点乘。
0代表完全保持信息,1代表完全取消信息。

在这里插入图片描述

第二步是“input gate layer” i t i_t it主要决定cell state要添加哪些信息。包含两个部分:
第一部分,sigmoid 层称为“input gate layer”决定 C t C_t Ct哪些信息需要添加到cell state中。
第二部分,tanh层生成了候选的值 C t C_t Ct
sigmoid和tanh层进行点乘。

在这里插入图片描述
然后我们来更新cell state。既将 C t − 1 C_{t-1} Ct1更新为 C t C_t Ct
首先使用 f t ∗ C t − 1 f_t*C_{t-1} ftCt1来忘记之前的信息,再使用 i t ∗ C t i_t*C_t itCt记住当前的信息。

在这里插入图片描述
第三步,我们来决定输出。我们的输出是基于cell state,但是需要进行过滤。通过sigmoind层决定cell state中哪些信息是要输出的,既sigmoid层与cell state的tanh层进行点乘。

在这里插入图片描述

API

CLASS torch.nn.LSTM(*args, **kwargs)
  • 1

f t = σ ( W i f x t + b i f + W h f h t − 1 + b h f ) f_t=\sigma(W_{if}x_t+b_{if}+W_{hf}h_{t-1}+b_{hf}) ft=σ(Wifxt+bif+Whfht1+bhf) # 遗忘门
i t = σ ( W i i x t + b i i + W h i h t − 1 + b h i ) i_t=\sigma(W_{ii}x_t+b_{ii}+W_{hi}h_{t-1}+b_{hi}) it=σ(Wiixt+bii+Whiht1+bhi) # 输入门
i t = σ ( W i o x t + b i o + W h o h t − 1 + b h o ) i_t=\sigma(W_{io}x_t+b_{io}+W_{ho}h_{t-1}+b_{ho}) it=σ(Wioxt+bio+Whoht1+bho) # 输出门
g t = t a n h ( W i g x t + b i g + W h g h t − 1 + b h g ) g_t=tanh(W_{ig}x_t+b_{ig}+W_{hg}h_{t-1}+b_{hg}) gt=tanh(Wigxt+big+Whght1+bhg)
c t = f t ⊙ c t − 1 + i t ⊙ g t c_t=f_t \odot c_{t-1} + i_t \odot g_t ct=ftct1+itgt
h t = o t ⊙ t a n h ( c t ) h_t = o_t \odot tanh(c_t) ht=ottanh(ct)

参数描述
input_sizeThe number of expected features in the input x
hidden_sizeThe number of features in the hidden state h
num_layersNumber of recurrent layers. E.g., setting num_layers=2 would mean stacking two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of the first LSTM and computing the final results. Default: 1
biasIf False, then the layer does not use bias weights b_ih and b_hh. Default: True
batch_firstIf True, then the input and output tensors are provided as (batch, seq, feature). Default: False
dropoutIf non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout. Default: 0
bidirectionalIf True, becomes a bidirectional LSTM. Default: False

输入

参数描述
input shape(seq_len,batch,input_size)
h_0 shape(num_layers*num_directions,batch,hidden_size)
c_0 shape (num_layers * num_directions, batch, hidden_size)

If (h_0, c_0) is not provided, both h_0 and c_0 default to zero.

输出

参数描述
output of shape (seq_len, batch, num_directions * hidden_size)
h_n of shape (num_layers * num_directions, batch, hidden_size)
c_n of shape (num_layers * num_directions, batch, hidden_size)

参考:
http://colah.github.io/posts/2015-08-Understanding-LSTMs/
https://www.cnblogs.com/mfryf/p/7904017.html

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/716943
推荐阅读
相关标签
  

闽ICP备14008679号