当前位置:   article > 正文

LSTM结构_lstm模型的结构

lstm模型的结构

在前面讲的【Deep learning】循环神经网络RNN中,我们对RNN模型做了总结。由于RNN也有梯度消失的问题,因此很难处理长序列的数据,大牛们对RNN做了改进,得到了RNN的特例LSTM(Long Short-Term Memory),它可以避免常规RNN的梯度消失,因此在工业界得到了广泛的应用。下面我们就对LSTM模型做一个总结。

1.从RNNLSTM

img

img

其中上图是传统RNN结构框架,而下图为LSTM结构框架,对比两图可以发现LSTM比传统RNN要复杂得多。它也是一种特殊的循环体结构,拥有三个“门”结构的特殊网络结构:遗忘门、输入门和输出门。下面就将具体介绍每一种门都是怎么工作的。

2.遗忘门

遗忘门(forget gate)顾名思义,是控制是否遗忘的,在LSTM中即以一定的概率控制是否遗忘上一层的隐藏细胞状态。举个栗子,比如一段文章中先介绍了某地原来是绿水蓝天,但是后来被污染了。于是在看到被污染了之后,循环神经网络就会“忘记”了之前绿水蓝天的状态。这就是“遗忘门”工作内容。

遗忘门子结构如下图所示:
img

图中输入的有上一序列的隐藏状态h(t−1)和本序列数据x(t),通过一个激活函数,一般是sigmoid,得到遗忘门的输出f(t)。由于sigmoid的输出f(t)在[0,1]之间,因此这里的输出f^{(t)}代表了遗忘上一层隐藏细胞状态的概率。用数学表达式即为:

f ( t ) = σ ( W f h ( t − 1 ) + U f x ( t ) + b f ) f^{(t)}=\sigma(W_fh^{(t-1)}+U_fx^{(t)}+b_f) f(t)=σ(Wfh(t1)+Ufx(t)+bf)

其中Wf,Uf,bf为线性关系的系数和偏倚,和RNN中的类似。σ为sigmoid激活函数。

3.输入门

在RNN经历“遗忘门”之后,它还需要从当前的输入来补充最新的记忆,这就需要“输入门”来完成。

输入门(input gate)负责处理当前序列位置的输入,它的子结构如下图:

img

从图中可以看到输入门由两部分组成,第一部分使用了sigmoid激活函数,输出为i(t),第二部分使用了tanh激活函数,输出为a(t), 两者的结果后面会相乘再去更新细胞状态。用数学表达式即为:

i ( t ) = σ ( W i h ( t − 1 ) + U i x ( t ) + b i ) a ( t ) = t a n h ( W a h ( t − 1 ) + U a x ( t ) + b a )

i(t)=σ(Wih(t1)+Uix(t)+bi)a(t)=tanh(Wah(t1)+Uax(t)+ba)
i(t)=σ(Wih(t1)+Uix(t)+bi)a(t)=tanh(Wah(t1)+Uax(t)+ba)

其中Wi,Ui,bi,Wa,Ua,ba,为线性关系的系数和偏倚,和RNN中的类似。σ为sigmoid激活函数。

4. Cell状态更新

在研究LSTM输出门之前,我们要先看看LSTM之细胞状态。前面的遗忘门和输入门的结果都会作用于细胞状态C(t)。我们来看看从细胞状态C(t−1)如何得到C(t)。如下图所示:

img

细胞状态C(t)由两部分组成,第一部分是C(t−1)和遗忘门输出f(t)的乘积,第二部分是输入门的i(t)和a(t)的乘积,即:

C ( t ) = C ( t − 1 ) ⊙ f ( t ) + i ( t ) ⊙ a ( t ) C^{(t)}=C^{(t-1)}\odot f^{(t)}+i^{(t)}\odot a^{(t)} C(t)=C(t1)f(t)+i(t)a(t)

其中,⊙为Hadamard积,在DNN中也用到过。

5.输出门

有了新的隐藏细胞状态C(t),我们就可以来看输出门了,子结构如下:

img

从图中可以看出,隐藏状态h(t)的更新由两部分组成,第一部分是o(t), 它由上一序列的隐藏状态h(t−1)和本序列数据x(t),以及激活函数sigmoid得到,第二部分由隐藏状态C(t)和tanh激活函数组成, 即:

o ( t ) = σ ( W o h ( t − 1 ) + U o x ( t ) + b o ) h ( t ) = o ( t ) ⊙ t a n h ( C ( t ) )

o(t)=σ(Woh(t1)+Uox(t)+bo)h(t)=o(t)tanh(C(t))
o(t)h(t)=σ(Woh(t1)+Uox(t)+bo)=o(t)tanh(C(t))

6. LSTM前向传播算法

现在我们来总结下LSTM前向传播算法。LSTM模型有两个隐藏状态h(t),C(t),模型参数几乎是RNN的4倍,因为现在多了Wf,Uf,bf,Wa,Ua,ba,Wi,Ui,bi,Wo,Uo,bo这些参数。

img

7.LSTM反向传播算法

有了LSTM前向传播算法,推导反向传播算法就很容易了, 思路和RNN的反向传播算法思路一致,也是通过梯度下降法迭代更新我们所有的参数,关键点在于计算所有参数基于损失函数的偏导数。

在RNN中,为了反向传播误差,我们通过隐藏状态h(t)的梯度δ(t)一步步向前传播。在LSTM这里也类似。只不过我们这里有两个隐藏状态h(t)和C(t)。这里我们定义两个δ,即:

img

反向传播时只使用了δ(t)C,变量δ(t)h仅为帮助我们在某一层计算用,并没有参与反向传播,这里要注意。如下图所示:

img

而在最后的序列索引位置τ的δ(τ)h和 δ(τ)C为:

δ h ( τ ) = ∂ L ∂ O ( τ ) ∂ O ( τ ) ∂ h ( τ ) = V T ( y ^ ( τ ) − y ( τ ) ) δ C ( τ ) = ∂ L ∂ h ( τ ) ∂ h ( τ ) ∂ C ( τ ) = δ h ( τ ) ⊙ o ( τ ) ⊙ ( 1 − t a n h 2 ( C ( τ ) ) )

δh(τ)=LO(τ)O(τ)h(τ)=VT(y^(τ)y(τ))δC(τ)=Lh(τ)h(τ)C(τ)=δh(τ)o(τ)(1tanh2(C(τ)))
δh(τ)δC(τ)=O(τ)Lh(τ)O(τ)=VT(y^(τ)y(τ))=h(τ)LC(τ)h(τ)=δh(τ)o(τ)(1tanh2(C(τ)))

接着我们由δ(t+1)C反向推导δ(t)C。

δ(t)h的梯度由本层的输出梯度误差决定,即:

δ h ( t ) = ∂ L ∂ h ( t ) = V T ( y ^ ( t ) − y ( t ) ) \delta_h^{(t)}=\frac{\partial L}{\partial h^{(t)}}=V^T(\hat{y}^{(t)}-y^{(t)}) δh(t)=h(t)L=VT(y^(t)y(t))

而δ(t)C的反向梯度误差由前一层δ(t+1)C的梯度误差和本层的从h(t)传回来的梯度误差两部分组成,即:

δ C ( t ) = ∂ L ∂ C ( t + 1 ) ∂ C ( t + 1 ) ∂ C ( t ) + ∂ L ∂ h ( t ) ∂ h ( t ) ∂ C ( t ) = δ C ( t + 1 ) ⊙ f ( t + 1 ) + δ h ( t ) ⊙ o ( t ) ⊙ ( 1 − t a n h 2 ( C ( t ) ) ) \delta_C^{(t)}=\frac{\partial L}{\partial C^{(t+1)}}\frac{\partial C^{(t+1)}}{\partial C^{(t)}}+\frac{\partial L}{\partial h^{(t)}}\frac{\partial h^{(t)}}{\partial C^{(t)}}=\delta_C^{(t+1)}\odot f^{(t+1)}+\delta_h^{(t)}\odot o^{(t)}\odot(1-tanh^2(C^{(t)})) δC(t)=C(t+1)LC(t)C(t+1)+h(t)LC(t)h(t)=δC(t+1)f(t+1)+δh(t)o(t)(1tanh2(C(t)))

有了δ(t)h和δ(t)C, 计算这一大堆参数的梯度就很容易了,这里只给出Wf的梯度计算过程,其他的Uf,bf,Wa,Ua,ba,Wi,Ui,bi,Wo,Uo,bo,V,c的梯度大家只要照搬就可以了。

∂ L ∂ W f = ∑ t = 1 τ ∂ L ∂ C ( t ) ∂ C ( t ) ∂ f ( t ) ∂ f ( t ) ∂ W f = ∑ t = 1 τ δ C ( t ) ⊙ C ( t − 1 ) ⊙ f ( t ) ( 1 − f ( t ) ) ( h ( t − 1 ) ) T \frac{\partial L}{\partial W_f}=\sum_{t=1}^{\tau}\frac{\partial L}{\partial C^{(t)}}\frac{\partial C^{(t)}}{\partial f^{(t)}}\frac{\partial f^{(t)}}{\partial W_f}=\sum_{t=1}^{\tau}\delta_C^{(t)}\odot C^{(t-1)}\odot f^{(t)}(1-f^{(t)})(h^{(t-1)})^T WfL=t=1τC(t)Lf(t)C(t)Wff(t)=t=1τδC(t)C(t1)f(t)(1f(t))(h(t1))T

小结

f}=\sum_{t=1}{\tau}\delta_C{(t)}\odot C^{(t-1)}\odot f{(t)}(1-f{(t)})(h{(t-1)})T$

小结

LSTM虽然结构复杂,但是只要理顺了里面的各个部分和之间的关系,进而理解前向反向传播算法是不难的。当然实际应用中LSTM的难点不在前向反向传播算法,这些有算法库帮你搞定,模型结构和一大堆参数的调参才是让人头痛的问题。不过,理解LSTM模型结构仍然是高效使用的前提。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/591880
推荐阅读
相关标签
  

闽ICP备14008679号