当前位置:   article > 正文

【Pytorch】21. 循环神经网络RNN和LSTM_chris olah rnn lstm

chris olah rnn lstm


本节我们学习recurrent neural networks (RNNs)和 long short-term memory (LSTM)

RNN

假设有一个普通的神经网络,输入一张图片,预测出来是狗
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4vimcEea-1618547689413)(en-resource://database/16973:1)]

但是如果这张图片真的是一只狼呢?神经网络怎么才可以知道呢?假如我们在看自然的电视节目,前面出现了熊,出现了狐狸,这个时候我们就猜测最后一张图片是一头狼,因为都是野外的动物嘛。这种情况下,就是前面的信息对后面的预测有帮助。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kzGjKILd-1618547689418)(en-resource://database/16975:1)]

这基本上就是RNN的原理

LSTM

但是RNN也存在问题,假如最近出现的两张图片分别是树和松鼠,比较远的地方出现的是熊,那么最后一张图片很可能预测是狗,因为熊出现的地方太远了,这中间经过了很多次sigmoid,信息被稀释掉了,影响就很小了。而且更甚的是,一路回去做反向传播,会出现梯度消失的问题

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-CziTosQ9-1618547689421)(en-resource://database/16979:1)]

这就是RNN的问题,并不擅长记忆长期内容,只能记忆短期记忆

LSTM就可以解决这个问题,因为它不仅记忆短期的记忆,还有长期记忆

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-hp0LWKYr-1618547689425)(en-resource://database/16981:1)]

比方说再刚刚的案例中,长期记忆就是“这是关于自然科学的”,“有很多森林动物”,短期记忆就是“松鼠”“树”,有一个事件是判别是“狗还是狼”,这三个东西一起再用来生成三个东西,分别是output,更新的长期记忆,更新的短期记忆(下图的紫色箭头不仅都指向短期记忆,也都指向长期记忆和output)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-zuMlm0Mp-1618547689426)(en-resource://database/17005:1)]

在LSTM中,中间的过程是有很多“门”来控制的,分别如下图四个颜色所示

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-DU7TNzVb-1618547689428)(en-resource://database/17009:1)]

其中长期记忆我们用大象表示,它会进入到forget gate,这个gate决定要忘记哪些东西。

短期记忆用金鱼表示,金鱼和狼都会进入到learn gate来学习。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ACF0pyis-1618547689430)(en-resource://database/17007:1)]

然后没有遗忘的大象,学习好的金鱼和学习好的狼就会进入到remember gate,形成新的长期记忆。

而use gate也会用这些信息形成新的短期记忆,输出output,也就是狼的预测,和新的短期记忆

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vBCoZwD7-1618547689431)(en-resource://database/17011:1)]

以上的过程不断地进行,就形成了下图,t代表时间

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YdsDByHT-1618547689432)(en-resource://database/17013:1)]

LSTM架构

先来看一下RNN的架构,就是利用short term memory STM和Event E一起,乘上weights,加上bias,然后用激活函数得到新的memory
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-BWOjL5k5-1618547689434)(en-resource://database/17015:1)]

LSTM的结构就是类似的

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1T4esUoA-1618547689436)(en-resource://database/17017:1)]

Learn Gate学习门

Learn gate是接收短期记忆和event,将两者结合起来,然后忽略一部分,只保留重要的部分

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OQaPCaXC-1618547689438)(en-resource://database/17019:1)]

这里就是忽略了“树”,只保留了动物的部分

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4BpzD9uA-1618547689439)(en-resource://database/17021:1)]

数学表示是这样的:

①短期记忆STM和事件E进来之后,乘上weights,加上bias,然后用tanh激活函数,就产生了新的信息Nt。这部分用的就是conbine的部分。

②然后需要再忽略一部分,那就乘上遗忘因子it。it是一个向量,进行element wise乘法。it的计算依旧需要用到前面的STM和E的信息,这里又有一个小的神经网络,用到新的权重Wi和偏差bi。

如下图所示,这就是学习门的工作原理

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-CxBMsJ1s-1618547689439)(en-resource://database/17023:1)]

Forget Gate 遗忘门

这里遗忘门的目的就是接受long term memory,which 包含了自然和科学,我们需要忘记科学,留下自然。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2cadXfIx-1618547689441)(en-resource://database/17027:1)]

工作原理和上面差不多,接收的是长期记忆LTM,对长期记忆进行遗忘,遗忘的过程就是乘ft,ft和it一样,计算需要用到前面的STM和E的信息,经过一层小小的神经网络就得到了ft。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-D7oYEs1H-1618547689441)(en-resource://database/17025:1)]

Remember Gate记忆门

这个门就更简单了,就是接收长期记忆和短期记忆,当然是处理过的长期记忆和短期记忆,就是把forget gate和learn gate的内容加起来就好了

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-l2HVgXYG-1618547689442)(en-resource://database/17029:1)]

Use Gate 应用门

这个门就是为了得到输出,也是短期记忆,这两个是一个东西。

比方说在这里,就是将long term memory中的东西里面找到一只熊,然后从short term memory中找到一只松鼠,然后得出“你的图片最有可能是一只狼,当然也涉及其他动物”

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-DDr760TL-1618547689444)(en-resource://database/17033:1)]

工作原理就是将遗忘门的输出结果放到一个小型神经网络里面,使用tanh激活函数;然后把短期记忆和事件放到另一个小型神经网络里面,使用sigmoid激活函数;最后一步就是把两者相乘,得到新的输出结果。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ojITqARR-1618547689445)(en-resource://database/17037:1)]

其他架构

LSTM的结构我们已经知道了,但是为什么有的地方用tanh,有的地方用sigmoid,为什么有的地方要+,有的地方要×。其实就是因为试验下来这样有用。所以用了这样的结构。其实还有很多其他可行的结构,下面就介绍一些

Gated Recurrent Unit(GRU)

GRU把遗忘门和学习门合并为更新门 update gate,更新门的结果交给合并门 combine gate来处理。他只会翻出一个工作记忆,而不是一对长期记忆和短期记忆。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-e1lBXJBF-1618547712737)(en-resource://database/17043:1)]

下面列举了 GRUs 的一些参考文献

Peephole Connections

这里是另一种结构,叫做窥视孔连接

回忆一下遗忘门的结构,短期以及和事件一起来决定该以往什么,也就是ft的产生,那么为什么长期记忆不来决定哪些内容该遗忘呢?所以peephole连接就加入了长期记忆,也来做决策。这样ft的生成就需要更大的数据。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-TWHt2KFY-1618547712739)(en-resource://database/17045:1)]

可以把LSTM的所有遗忘门的部分都换成窥视孔连接,就成了具备窥视孔连接的LSTM

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8bYNWNpG-1618547712742)(en-resource://database/17047:1)]


代码:

git clone https://github.com/udacity/deep-learning-v2-pytorch.git

转到
recurrent-neural-networks > time-series

代码分析请看【Pytorch】21. RNN代码分析


其他参考资料

本系列笔记来自Udacity课程《Intro to Deep Learning with Pytorch

全部笔记请关注微信公众号【阿肉爱学习】,在菜单栏点击“利其器”,并选择“pytorch”查看

在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/373273?site
推荐阅读
相关标签
  

闽ICP备14008679号