当前位置:   article > 正文

简单介绍长短期记忆网络 - LSTM

简单介绍长短期记忆网络 - LSTM

一、引言

1.1 什么是LSTM

首先看看百科的解释。
长短期记忆(英语:Long Short-Term Memory,LSTM)是一种时间循环神经网络(RNN),论文首次发表于1997年。由于独特的设计结构,LSTM适合于处理和预测时间序列中间隔和延迟非常长的重要事件。1

为了更好地理解长短期记忆网络 - LSTM(下文简称LSTM),可以先了解循环神经网络-RNN(下文简称RNN)的相关知识,这里有一些相关的文章。LSTM只是RNN的一个变种,LSTM是为了解决RNN中的梯度消失的问题而提出的。

二、循环神经网络RNN

2.1 为什么需要RNN

人的思想是有记忆延续性。比如当你在阅读这篇文章,你会根据你曾经对每个字的理解来理解这篇文章的字,而不是每次都要思考一个字在这篇文章的语境下到底如何理解(从一个字或词的多种解释来选择一个符合当下语境的解释)。

举个例子:要识别这么一个句子:
The cat, which already ate cakes, () full.2

假设对其中的单词从左到右一个一个地处理,前面已经cat的识别结果是一个单数名词,到后边()里的内容,到底是填were 还是 was,那么就需要根据前边cat的识别结果进行判断。这就是RNN需要做的。

使用神经网络来预测句子中下一个字的解释。传统的神经网络在模型训练好了以后,在输入层给定一个x,通过网络之后就能在输出层得到特定的y。利用这个模型可以通过训练拟合任意函数,但是只能单独的取处理一个个的输入,前一个输出和后一个输出是完全没有关系的

神经网络的结构如下:
Alt
但是,在理解一句话的意思的时候,一个字的意思是跟前面的字相关联的,即前面的输出和后面的输出是有关系的。所以仅仅利用这样的模型是不够的的,为了解决这个问题,有人提出了RNN。
RNN模型构造:
传统RNN模型

RNN神经网络示意图:
RNN模型
蓝色部分的是隐藏层,RNN利用隐藏层将信息向后传递。
我们来看看RNN隐藏层里发生了什么,将上图按时间线展开3

隐藏层

符号意义
X一个向量,输入层的值
S一个向量,隐藏层的值
O一个向量,输出层的值
U输入层到隐藏层的权重矩阵
V隐藏层到输出层的权重矩阵
W隐藏层上一次的值作为这一次输入的权重

再给出一个更具体的图,给出各层元素的对应关系
具体图
现在看上去就比较清楚了,这个网络在 t 时刻接收到输入 x t x_t xt 之后,隐藏层的值是 s t s_t st ,输出值是 o t o_t ot 。关键一点是, s t s_t st 的值不仅仅取决于 x t x_t xt ,还取决于 s t − 1 s_{t-1} st1 我们可以用下面的公式来表示RNN的计算方法:
用公式表示如下:
O t = g ( V ⋅ S t ) O_t = g(V·S_t) Ot=g(VSt)

S t = f ( U ⋅ X t + W ⋅ S t − 1 ) S_t = f(U·X_t + W ·S_{t-1}) St=f(UXt+WSt1
注意:为了简单说明问题,偏置都没有包含在公式里面。

这样,就可以做到的在一个序列中根据前面的输出来影响后面的输出。

三、长短时记忆神经网络LSTM

3.1 为什么需要LSTM

回到我们的例子:
The cat, which already ate …, () full.

这个例子与之前的例子稍微有一些不同,这里的cat 和()之间已经相隔了较长的一段距离,这时候用RNN来处理这样的长期信息就不太合适。

因为RNN在反向传播阶段有梯度消失等问题不能处理长依赖问题,这里的梯度消失是由于RNN在计算过程中使用链式法则

具体来说,RNN使用覆盖的方式来计算状态: S t = f ( S t − 1 , x t ) S_t = f(S_{t-1},x_t) St=f(St1,xt),这类似于复合函数,根据链式求导的法则,复合函数求导:设 f f f g g g x x x 的可导函数,则 ( f ∘ g ) ′ ( x ) = f ′ ( g ( x ) ) g ′ ( x ) (f \circ g)'(x) = f'(g(x))g'(x) (fg)(x)=f(g(x))g(x),这是一种连乘的方式,如果导数小于或大于1,会发生梯度下降以及梯度爆炸。梯度爆炸可以通过剪枝算法解决,但是梯度消失却没办法解决。

梯度消失可能不太好理解,可以简单理解为RNN中后边输入的数据影响越大,前面的数据的影响小,因此不能处理长期信息。后来,有学者在一篇论文Long Short-Term Memory 4 提出了LSTM,LSTM通过选择性地保留信息,有效地缓解了梯度消失以及梯度下降的问题,可以说LSTM正是为了适合学习长期依赖而产生的。

3.2 LSTM结构分析

回顾一下RNN的模型构造:

RNN模型构造
可以看到,RNN循环网络模型的链式结构非常简单,通常仅含有一个tanh层。

LSTM模型构造:
LSTM
而LSTM的链式结构中,循环单元结构不同,里边有四个神经网络层。

先来解释一下图中符号含义:
符号含义

符号含义
黄色矩形神经网络层
粉色圆结点操作,比如向量相加
箭头从一个结点的输出到另外的结点的输入
箭头合并链接
箭头分叉内容复制后副本流向不同的位置

LSTM结构(图右)和普通RNN的主要输入输出区别如下所示:
LSTM对比RNN
相比RNN只有一个传递状态 h t h^t ht , LSTM有两个传输状态,一个 c t c^t ct (cell state), 和一个 h t h^t ht (hidden state)。(RNN中的 h t h^t ht 对应LSTM中的 C t C^t Ct

3.3 LSTM背后的核心思想

LSTM的核心思想,LSTM的关键是细胞状态(cell state),即下图中上边的水平线。cell state像是一条传送带,它贯穿整条链,其中只发生一些小的线性作用。信息流过这条线而不改变是非常容易的。5 改变cell state需要三个门的相互配合。

如下图所示:
细胞状态
LSTM删除或添加信息到cell state,是由被称为门的结构控制的。LSTM中有三个门,“遗忘门” “输入门” 以及“输出门”,用来保护和更新cell的状态。
门是筛选信息的方法,由一个sigmoid网络层和一个点乘操作组成。
如下图:
门
sigmoid层作为激活函数,将输出控制在(0,1)区间内,Sigmoid的函数图形如下:
Sigmoid
可以看到,绝大多数的值都是接近0或者接近1的。利用这一个性质,0 表示不允许任何通过,1 表示允许一切通过。

3.4 LSTM的运行机制

第一步,需要决定从cell state中丢弃什么样的信息,这个由“遗忘门”的sigmoid层决定。根据输入 h t − 1 h_{t-1} ht1 x t x_t xt,得到的输出是0和1之间的数。0 代表“完全保留这个值”,1代表“完全丢弃这个值”。

回到开始的例子,原来的主语是"cat",之后遇到了一个新的主语"cats"。这时需要把之前的"cat"给忘掉,以便确定接下来是要使用"were",而不是"was"。如下图:
遗忘门
第二步,需要决定在cell state里存储什么样的信息。这一步划分为两个部分,一是称为“输入门”的sigmoid层决定哪些数据需要更新。然后,tanh层创建一个新的候选值向量 C ~ t \widetilde{C}_t C t,这些值能加入state中。第二部分,需要将这两个部分合并以实现对state的更新。

在例子中,这里对应于把新的"cats"加入到"cell state"中,以替代需要遗忘的"cat"。如下图:
input gate
在决定好需要遗忘的以及需要加入的记忆之后,就可以把旧的cell state C t − 1 C_{t-1} Ct1更新到新的cell state C t C_t Ct。 这一步中,把旧的state C t − 1 C_{t-1} Ct1 f t f_t ft 相乘,遗忘先前决定遗忘的东西,之后加上新的记忆信息 i t ∗ C ~ t i_t \ast \widetilde{C}_t itC t。这里为了体现对状态值的更新度是有限制的,可以把 i t i_t it当成一个权重。如下图:
更新
最后,需要决定输出。这个输出将会基于cell state ,这是一个过滤后的值。首先,使用“输出门”的sigmoid层决定输出cell state的哪些部分的。然后,将cell state放入tanh(将数值限制在-1到1),最后将结果与sigmoid门的输出相乘,这样就可以只输出需要的部分。如下图:
输出门

3.5 LSTM如何避免梯度下降

上边提到了RNN中的梯度下降以及梯度爆炸问题,是是因为在计算过程中使用链式法则,使用了乘积。而在LSTM中,状态是通过累加的方式来计算, S t = ∑ τ = 1 t Δ S τ S_t = \sum_{\tau =1}^t \Delta S_{\tau} St=τ=1tΔSτ。这样的计算,就不是复合函数的形式,它的导数也就不是乘积的形式,就不会发生梯度消失的情况。

四、入门例子

下面给出LSTM的一个入门实例-根据前9年的数据预测后3年的客流6,感谢原作者的代码,完整的代码见GithubYonv1943。这里简单说一下这个代码实例的结果,需要了解更加详细的代码细节可以看看原作者的原文详解。

考虑有一组某机场1949年~1960年12年共144个月的客流量数据。使用这个数据中的前9年的客流量来预测后3年的客流量,再和实际的数据进行比对,可以看出LSTM的对这类具有时序关系的拟合效果。

结果图:
结果图

  • 数据:机场1949~1960年12年共144个月的客流量数据。数据具有三个维度[客运量,年份,月份]。其中前75%(前9年)的数据作为训练集,后25%(后3年)的数据作为测试集。
  • 纵坐标:标准化处理:变量值与平均数的差除以标准差,给出数值的相对位置。横坐标为月数。
  • 图解释:竖直黑线左边是训练集(前9年)。右边(后3年)红色的是预测数值,蓝色的是实际数值。

可以看到在这个LSTM对这个数据集的拟合效果是比较好的,在这样的实际场景中,可以利用LSTM这样的工具来对客流量做一个预测,以便对客运高峰等情况做好预备方案。

五、总结

  • RNN的计算中存在多个偏导数连乘,导致梯度消失或梯度爆炸,难以处理长依赖的信息。
  • LSTM通过三个选择性地保留信息,可以选择最近的信息或者很久之前的信息。
  • LSTM更新cell state是采用了线性求和的计算,因此不会出现梯度消失问题,可以处理长期依赖的信息。

六、参考资料


  1. 长短期记忆 ↩︎

  2. 吴恩达深度学习课程 ↩︎

  3. 一文搞懂RNN(循环神经网络)基础篇 ↩︎

  4. Long Short-Term Memory ↩︎

  5. Understanding LSTM Networks ↩︎

  6. LSTM入门例子:根据前9年的数据预测后3年的客流(PyTorch实现) ↩︎

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号