赞
踩
假设有一个普通的神经网络,输入一张图片,预测出来是狗
但是如果这张图片真的是一只狼呢?神经网络怎么才可以知道呢?假如我们在看自然的电视节目,前面出现了熊,出现了狐狸,这个时候我们就猜测最后一张图片是一头狼,因为都是野外的动物嘛。这种情况下,就是前面的信息对后面的预测有帮助。
这基本上就是RNN的原理
但是RNN也存在问题,假如最近出现的两张图片分别是树和松鼠,比较远的地方出现的是熊,那么最后一张图片很可能预测是狗,因为熊出现的地方太远了,这中间经过了很多次sigmoid,信息被稀释掉了,影响就很小了。而且更甚的是,一路回去做反向传播,会出现梯度消失的问题
这就是RNN的问题,并不擅长记忆长期内容,只能记忆短期记忆
LSTM就可以解决这个问题,因为它不仅记忆短期的记忆,还有长期记忆
比方说再刚刚的案例中,长期记忆就是“这是关于自然科学的”,“有很多森林动物”,短期记忆就是“松鼠”“树”,有一个事件是判别是“狗还是狼”,这三个东西一起再用来生成三个东西,分别是output,更新的长期记忆,更新的短期记忆(下图的紫色箭头不仅都指向短期记忆,也都指向长期记忆和output)
在LSTM中,中间的过程是有很多“门”来控制的,分别如下图四个颜色所示
其中长期记忆我们用大象表示,它会进入到forget gate,这个gate决定要忘记哪些东西。
短期记忆用金鱼表示,金鱼和狼都会进入到learn gate来学习。
然后没有遗忘的大象,学习好的金鱼和学习好的狼就会进入到remember gate,形成新的长期记忆。
而use gate也会用这些信息形成新的短期记忆,输出output,也就是狼的预测,和新的短期记忆
以上的过程不断地进行,就形成了下图,t代表时间
先来看一下RNN的架构,就是利用short term memory STM和Event E一起,乘上weights,加上bias,然后用激活函数得到新的memory
LSTM的结构就是类似的
Learn gate是接收短期记忆和event,将两者结合起来,然后忽略一部分,只保留重要的部分
这里就是忽略了“树”,只保留了动物的部分
数学表示是这样的:
①短期记忆STM和事件E进来之后,乘上weights,加上bias,然后用tanh激活函数,就产生了新的信息Nt。这部分用的就是conbine的部分。
②然后需要再忽略一部分,那就乘上遗忘因子it。it是一个向量,进行element wise乘法。it的计算依旧需要用到前面的STM和E的信息,这里又有一个小的神经网络,用到新的权重Wi和偏差bi。
如下图所示,这就是学习门的工作原理
这里遗忘门的目的就是接受long term memory,which 包含了自然和科学,我们需要忘记科学,留下自然。
工作原理和上面差不多,接收的是长期记忆LTM,对长期记忆进行遗忘,遗忘的过程就是乘ft,ft和it一样,计算需要用到前面的STM和E的信息,经过一层小小的神经网络就得到了ft。
这个门就更简单了,就是接收长期记忆和短期记忆,当然是处理过的长期记忆和短期记忆,就是把forget gate和learn gate的内容加起来就好了
这个门就是为了得到输出,也是短期记忆,这两个是一个东西。
比方说在这里,就是将long term memory中的东西里面找到一只熊,然后从short term memory中找到一只松鼠,然后得出“你的图片最有可能是一只狼,当然也涉及其他动物”
工作原理就是将遗忘门的输出结果放到一个小型神经网络里面,使用tanh激活函数;然后把短期记忆和事件放到另一个小型神经网络里面,使用sigmoid激活函数;最后一步就是把两者相乘,得到新的输出结果。
LSTM的结构我们已经知道了,但是为什么有的地方用tanh,有的地方用sigmoid,为什么有的地方要+,有的地方要×。其实就是因为试验下来这样有用。所以用了这样的结构。其实还有很多其他可行的结构,下面就介绍一些
GRU把遗忘门和学习门合并为更新门 update gate,更新门的结果交给合并门 combine gate来处理。他只会翻出一个工作记忆,而不是一对长期记忆和短期记忆。
下面列举了 GRUs 的一些参考文献
这里是另一种结构,叫做窥视孔连接。
回忆一下遗忘门的结构,短期以及和事件一起来决定该以往什么,也就是ft的产生,那么为什么长期记忆不来决定哪些内容该遗忘呢?所以peephole连接就加入了长期记忆,也来做决策。这样ft的生成就需要更大的数据。
可以把LSTM的所有遗忘门的部分都换成窥视孔连接,就成了具备窥视孔连接的LSTM
代码:
git clone https://github.com/udacity/deep-learning-v2-pytorch.git
转到
recurrent-neural-networks > time-series
代码分析请看【Pytorch】21. RNN代码分析
其他参考资料
本系列笔记来自Udacity课程《Intro to Deep Learning with Pytorch》
全部笔记请关注微信公众号【阿肉爱学习】,在菜单栏点击“利其器”,并选择“pytorch”查看
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。