当前位置:   article > 正文

BP算法和RNN_RNN/LSTM BPTT详细推导以及梯度消失问题分析

适用于单层rnn的bp算法

b36be7ef4a3b052465c5e0ce32b5996d.png

最近面试被问到了LSTM为什么能够解决long-range dependency的问题,回答这个问题实际上需要把BPTT公式写出来,在这篇博文中我们进行了部分推导

习翔宇:RNN Part 3-RNN中的BPTT算法和梯度消失问题​zhuanlan.zhihu.com
9c0f87636325f6691aff20a8184e722e.png

但是不够系统化,本篇博文将完全对RNN的BPTT以及LSTM的BPTT进行推导,并对long-range dependency问题进行分析


1. RNN的BPTT

假设RNN的基本方程如下所示

损失函数定义如下:

对于一个输入序列

,其整体损失函数为

我们接下来分别对

进行求导

首先对

进行求导,这个比较简单

然后对

进行求导

如下公式可知

的计算涉及到
,而
的计算也涉及到
,同样
的计算涉及到
,而
的计算也涉及到
,以此类推,因此需要回溯到t时刻之前的所有时刻,我们需要对公示(6)中的第三项
进行展开,下面我们单独对其进行展开如下所示:

同样的道理,公示(8)中的第一项

的计算如下所示

将其带入到公示(8)中即可得到

这样我们把公式(6)中的第三项就展开了,现在带入公式(6)中即可得到:

按照同样的方式,我们对

进行求导


2. RNN梯度消失分析

在上面的推导中,我们对

部分的推导公式(11),(12)可以看到,在计算
时刻的损失产生的梯度时,必须回溯之前所有时刻
的信息,并且存在连乘项
,根据公式(1)我们可以计算

sigmoid函数的导数大家都很熟悉了,处于

之间,那么会有以下两种情况:
  1. > 4的时候,那么
    ,此时如果
    距离过大,会导致连乘项过多,产生梯度爆炸,趋近于无穷
  2. <4的时候,那么
    ,此时如果
    距离过大,会导致连乘项过多,产生梯度消失,趋近于0

因此当输入序列过长的时候,在求取一个比较远的时刻

的梯度时,需要回溯到前面的所有时刻的信息,由于连乘项的存在,导致前面时刻的信息会缺失,这就是RNN中的梯度消失问题,也是所谓的long-range dependency问题(这样划一个约等号会不会太草率?);

梯度爆炸问题容易解决,例如采用clip的方式即可。但是梯度消失的问题比较难以解决,我们下面介绍LSTM为什么能够缓解梯度消失问题


3. LSTM BPTT推导及梯度消失分析

LSTM的公式如下所示

8378cd8577a8c5e709fde8b4062b9c48.png

其中

可以看作之前RNN中的
,我们将
的计算公式展开如下所示:

那么需要连乘的部分计算可得:

从之前的

变成了sigmoid函数,范围在[0,1]之间,在实际参数更新中,可以通过控制使得其接近于1,因此多次连乘依然不会产生梯度消失,在
距离较大的情况下,依然能够较好的利用
时刻的信息进行梯度计算。

4. 一些思考

本文由两个问题后续进行提升:

  1. LSTM部分的推导并不十分严谨,在RNN BPTT的基础上进行了类比
  2. 梯度消失问题以及long-range dependency问题的定义需要明确,本文进行了约等于,是否准确还有待商榷
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/爱喝兽奶帝天荒/article/detail/765446
推荐阅读
相关标签
  

闽ICP备14008679号