当前位置:   article > 正文

全面解析RNN,LSTM,Seq2Seq,Attention注意力机制_rnn编码器最后输出的是什么

rnn编码器最后输出的是什么

原文

本文将会使用大量的图片和公式推导通俗易懂地讲解RNN,LSTM,Seq2Seq和attention注意力机制(结合colah’s blog 和CS583),希望帮助初学者更好掌握且入门,若有已经掌握RNN,LSTM的读者自行跳过阅读即可,更详细的讲解在Seq2Seq,Attention那。

目录

  • RNN
  • LSTM
  • Seq2Seq
  • 注意力机制
  • 参考

1 RNN(递归神经网络)

我们知道人类并不是从零开始思考东西,就像你读这篇文章的时候,你对每个字的理解都是建立在前几个字上面。你读完每个字后并不是直接丢弃然后又从零开始读下一个字,因为你的思想是具有持续性的,很多东西你要通过上下文才能理解。

然而传统的神经网络并不能做到持续记忆理解这一点,这是传统神经网络的主要缺点。举个例子,你打算使用传统的神经网络去对电影里每个时间点发生的事情进行分类的时候,传统的神经网络先让不能使用前一个事件去推理下一个事件。

RNN(递归神经网络)可以解决这个问题。他们是带有循环的神经网络,允许信息在其中保留。

img

在上图中,A代表神经网络主体,[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-IBnbXn3K-1641869995158)(https://www.zhihu.com/equation?tex=X_t)]表示网络的输入,[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-mYAiWVQK-1641869995160)(https://www.zhihu.com/equation?tex=h_t)]表示网络的输出。循环结构允许信息从当前输出传递到下一次(下个时间点)的网络输入。

这些循环让递归神经网络看起来有点神秘,然而如果你再思考一下,RNN其实和传统的神经网络并没有太多的不同。RNN可以看作是一个网络的多次拷贝,其中每次网络的输出都是下一次的输入。我们可以思考一下我们如果展开这个循环结构会是什么样的:

img

这种像是链状的网络结构表明RNN和序列以及列表有着天然的联系,他们是处理这些序列数据的天然的神经网络。而且很明显我们可以看出,输入输出的序列是具有相同的时间长度的,其中的每一个权值都是共享的(不要被链式形状误导,本质上只有一个cell)。

在最近的几年,RNN在很多问题上都取得了成功:比如语音识别,语音模型,翻译,图片注释等等,但是RNN存在着梯度消息/爆炸以及对长期信息不敏感的问题,所以LSTM就被提出来了。现在很多问题的成功都必须归功于LSTM,它是递归神经网络的一种,它在许多的任务中表现都比普通的RNN更好,所以接下来我们来探索一下这个神奇的网络。

2 LSTM

2.1 长期依赖问题

人们希望RNN可以将一些之前的信息连接到当前的任务中来,比如使用之前的视频帧来帮助理解当前帧。如果RNN可以做到将会非常有用。那实际RNN能做到吗?这要视情况而定。

有时候,我们只需要当前的信息来完成当前的任务。举个例子,一个语音模型试图基于之前的单词去预测下一个单词。如果我们尝试预测“the clouds are in the sky”,我们不需要太多的上下文信息——很明显最后一个单词会是sky。在像这样不需要太多的相关信息的场合下,RNN可以学习到之前使用的信息。

img

但是我们要注意,也有很多场景需要使用更多的上下文。当我们试图去预测“I grew up in France… I speak fluent French”这句话的最后一个单词,最近的信息会表明这应该是一种语言的名字,但是如果我们需要知道具体是哪一种语语言,我们需要France这个在句子中比较靠前的上下文信息,相关信息和需要预测的点的间隔很大的情况是经常发生的。

不幸的是,随着间隔变大,RNN变得无法连接到太前的信息。

img

理论上RNN完全可以处理这种长期依赖(long-term dependencies)的问题。人们可以通过小心地选择参数来解决这个问题。令人悲伤的是,实践表明RNN并不能很好地解决这个问题,Hochreiter (1991) [German] and Bengio, et al. (1994)发现了RNN为什么在这些问题上学习很困难的原因。

而LSTM则没有这个问题。

2.2 LSTM网络

长期短期记忆网络-通常叫做LSTM-是一种特殊结构的RNN,它能够学习长期依赖。它在大量的问题有惊人的效果,现在已经被广泛使用。

LSTM被明确设计来避免长期依赖问题,记住长时间的信息对LSTM来说只是常规操作,不像RNN那样费力不讨好。

所有的RNN都有不断重复网络本身的链式形式。在标准的RNN中,这个重复复制的模块只有一个非常简单的结果。例如一个tanh层:

img

LSTM也有这样的链式结构,但是这个重复的模块和上面RNN重复的模块结构不同:LSTM并不是只是增加一个简单的神经网络层,而是四个,他们以一种特殊的形式进行交互:

img

读者不需要担心看不懂,接下来我们将会一步步理解这个LSTM图。首先我们先了解一下图中的符号:

img

在上图中,每条线表示一个向量,从一个输出节点到其他节点的输入节点。粉红色的圆圈表示逐点式操作,就像向量加法。黄色的盒子是学习好的神经网络层。线条合代表联结,线条分叉则表示内容被复制到不同的地方。

2.3 LSTM背后的核心思想

LSTM的核心之处就是它的cell state(神经元状态),在下图中就是那条贯穿整个结果的水平线。这个cell state就像是一个传送带,他只有很小的线性作用,但却贯穿了整个链式结果。信息很容易就在这个传送带上流动但是状态却不会改变。cell state上的状态相当于长期记忆,而下面的[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-mUDq4noq-1641869995170)(https://www.zhihu.com/equation?tex=h_t)]则代表短期记忆。

img

LSTM有能力删除或者增加cell state中的信息,这一个机制是由被称为门限的结构精心设计的。

门限是一种让信息选择性通过的方式,它们是由sigmoid神经网络层和逐点相乘器做成的。

img

sigmoid层输出0和1之间的数字来描述一个神经元有多少信息应该被通过。输出0表示这些信息全部不能通过,而输出1则表示让所有信息都通过。

一个LSTM有三个这样的门限,去保护和控制神经元的状态。

2.4 一步步推导LSTM

LSTM的第一步就是决定什么信息应该被神经元遗忘。这是一个被称为“遗忘门层”的sigmod层组成。他输入[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-0XSLlGva-1641869995175)(https://www.zhihu.com/equation?tex=h_%7Bt-1%7D)]和[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-LvCJGxVu-1641869995177)(https://www.zhihu.com/equation?tex=X_t)](上一次的输出以及这轮的输入),然后在[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3mbo1TqR-1641869995179)(https://www.zhihu.com/equation?tex=C_%7Bt-1%7D)]的每个神经元状态输出0和1之间的数字。同理1表示完全保留这些信息,0表示完全遗忘这个信息。

让我们再次回到一开始举的例子:根据之前的词语去预测下一个单词的语言模型。在这个问题中,cell state或许包括当前主语中的性别信息,所以我们可以使用正确的代词。而当我们看到一个新的主语(输入),我们会去遗忘之前的性别信息。我们使用下图中的公式计算我们的“遗忘系数”[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fUta9fnp-1641869995180)(https://www.zhihu.com/equation?tex=f_t)]

img

下一步就是决定我们要在cell state中保留什么信息。这包括两个部分。首先,一个被称为“输入门层”的sigmoid层会决定我们要更新的数值。然后一个tanh层生成一个新的候选数值[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-B71VCV2N-1641869995183)(https://www.zhihu.com/equation?tex=C_t%5E%EF%BD%9E)],它会被增加到cell state中。在下一步中,我们将会组合这两步去生成一个新的更新状态值。

在那个语言模型例子中,我们想给cell state增加主语的性别,来替换我们将要遗忘的旧的主语。

img

现在是时候去更新旧的神经元状态[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-T2rDQAxy-1641869995186)(https://www.zhihu.com/equation?tex=C_%7Bt-1%7D)]到新的神经元状态[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-t62eY9RV-1641869995187)(https://www.zhihu.com/equation?tex=C_t)]。之前我们已经决定了要做什么,下一步我们就去做。

我们给旧的状态乘一个遗忘系数[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-H75MqHo2-1641869995189)(https://www.zhihu.com/equation?tex=f_t)],来遗忘掉我们之前决定要遗忘的信息,然后我们增加[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1Gn5m9Jc-1641869995190)(https://www.zhihu.com/equation?tex=i_t+%2A+C_t)]。这是新的候选值,由我们想多大程度更新每个状态的值决定。

在语言模型中,就像上面描述的,这是我们实际上要丢弃之前主语的性别信息,增加新的主语的性别信息的地方。

img

最后,我们需要决定我们要输出什么。这个输出是建立在我们的cell state的基础上,但是这里会有一个滤波器。首先,我们使用sigmoid层决定哪一部分的神经元状态需要被输出;然后我们让cell state经过tanh(让输出值变成-1到1之间)层并且乘上sigmod门限的输出,这样我们就只输出我们想要输出的。

对于那个语言模型的例子,当我们看到一个新的主语的时候,或许我们想输出相关动词的信息,因为动词是跟在主语后面的。例如,它或许要输出主语是单数还是复数的,然后我们就知道主语后动词的语态了。

2.5 LSTM的一些变体

上面讲的都是一些常规的LSTM,但并不是所有的LSTM都是上面这种形式。实际上现在很多包含LSTM的论文都有小的差异,但是它值得一提。

**Gers & Schmidhuber (2000)**引入了一个流行的LSTM变体,它增加了一个窥视孔连接。这意味着我们让门限层监视cell state的状态。

img

上图中给每一个门限都增加了窥视孔,但是有些论文只是给一部分的门限增加窥视孔,并不是全部都加上。

另外一个变体是使用组合遗忘和输入门,而不是分开决定哪些神经元需要遗忘信息,哪些需要增加新的信息,我们组合起来决定。我们只遗忘那些需要被放入新信息的状态,同样我们旨在旧信息被遗忘之后才输入新的信息。

img

一个更神奇的LSTM变体是门递归单元(也就是大家常说的GRU),它组合遗忘门和输入门为一个更新门,它合并了cell state和隐层状态,并且做了一些其他的改变。最终这个模型比标准的LSTM更简单,并且变得越来越流行。

img

这里只介绍了几个最有名的LSTM的变体,还有更多变体没有介绍,就像Yao, et al.(2015)深度门递归神经网络(Depth Gated RNNs)。这里也有一些处理长期依赖问题问题的完全不同的方法,就像Koutnik, et al(2014)提出的时钟机递归神经网络(Clockwork RNNs)。

2.6 结论

我们一开始提到人们使用RNN取得了卓越的成果,但其实本质上都是使用LSTM取得的,他们的确在多数任务上表现得更好。

写下来一系列等式以后,LSTM看起来挺吓人,但在文中一步步解释后它变得可以理解了。我们不禁想问:是否有比LSTM更好的模型?学者一致认为:那就是attention注意力机制。核心观点就是让RNN每一步都监视一个更大的信息集合并从中挑选信息。例如:如果你使用RNN去为一个图像生成注释,它会从图像中挑选一部分去预测输出的单词。接下来在讲解attention之前,我们会先聊聊Seq2Seq。

3 Seq2Seq

我将会结合一个机器翻译的例子来给大家形象地介绍Seq2Seq

img

在这个例子中,我们试图将英语转换为德语,这里要注意这里是一个多对多的模型,而且输入和输出的长度都不固定。

3.1 准备数据

img

因为只是做一个例子,所以我们在http://www.manythings.org/anki/这个网站选一个小规模的数据来训练一个简单的Seq2Seq即可,我们可以看到左边是英语句子,右边则是翻译的德语句子。

img

我们先进行一下预处理,比如把大写字母变成小写,把标点符号去掉等等。

img

预处理完之后我们要做tokenization,即把一句话分成很多个单词或者字符,这里要注意做tokenization的时候要用两个tokenization,英语用一个,德语用一个;tokenization之后要建立两个字典,一个英语字典,一个德语字典,后面会解释我为什么要这么做。

img

tokenization可以是char-level,也可以是word-level,顾名思义前者就是会把一句话分为一个个字符,而后者则会把一句话分成一个个单词,为了简单方便,我们使用char-level来说明。

img

经过tokenization之后一句话变成了一个list,每个元素都是一个字符,但实际中一般都使用word-level,因为他们的数据集足够大,这在之后会解释。

img

我们前面说了tokenization要用两个不同的字典,这是因为不同的语言它的字母表不同,无法进行统一的映射,如上图所示。

img

如果你使用word-level,那就更有必要使用两个不同的字典,比如很多德语单词在英语字典中是找不到的,而且不同语言分词方便也是不一样的。

img

左边是英语字典,包括26个字母和一个空格符,德语字典删去了一些不常用字母后再加入空格符,另外可以发现德语字典多了一个起始符和一个终止符,这里用什么都行,只要别跟字典字符冲突就可以,后面大家就知道这两个符号的作用。

img

tokenization结束之后每句话就变成了一个字符字典,然后原字符经过字典映射后就变成了下面这个序列,对于德语也是一样。

img

接下来我们还可以把这些数字变成One-hot向量表示,黑色表示1,白色表示0。经过One-hot每个字符就变成了一个向量,每句话就变成了一个矩阵,这就是我们的输入,现在数组准备好了,我们来搭建我们的Seq2Seq模型。

3.2 搭建并训练Seq2Seq模型

Seq2Seq有一个编码器和一个解码器,编码器一般是LSTM或者其他模型用于提取特征,它的最后一个输出就是从这句话得出的最后的特征,而其他的隐层输出都被丢弃。

img

编码器提取特征之后就到了解码器,解码器靠编码器最后输出的特征也就是[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jX6zMnUX-1641869995225)(https://www.zhihu.com/equation?tex=%28h%2Cc%29)]来知道这句话是"go away",这里要强调一下Decoder的初始状态就是Encoder的最后一个状态,现在Decoder开始输出德语字母,这里Decoder也是一个LSTM模型,他每次接受一个输入然后输出下一个字母的概率,第一个输入必须是起始符,这就是我们为什么要在德语字典中要加入起始符的原因。Decoder会输出一个概率分布p向量,起始符后面的第一个字母是m,我们将m做一个one-hot编码作为y标签,用标签y和预测p做一个CrossEntropy来作为我们的损失函数优化,梯度从Decoder传回Encoder。

img

然后输入是两个字符,起始符和m,下一个字母是a,我们将a做one-hot编码作为y标签,将它与我们输出的概率分布做一个CrossEntropy来作为损失函数,一直进行这个循环,应该就很好理解了。

img

最后一轮将整句德语作为输入,将停止符做标签y,再进行CrossEntropy,拿所有的英语和德语来训练我们的编码器和解码器,这就是我们的训练过程了。

img

总结一下,我们使用英语句子的one-hot矩阵作为encoder的输入,encoder网络由LSTM组成来提取特征,它的输出是最后一个状态[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-IExacI1e-1641869995235)(https://www.zhihu.com/equation?tex=h)]和传送带[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-i5Z9ro9c-1641869995237)(https://www.zhihu.com/equation?tex=c)],decoder网络的初始状态是[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-cceoqH1X-1641869995239)(https://www.zhihu.com/equation?tex=%28h%2Cc%29)],decoder网络的输入是德语句子,decoder输出当前状态[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Lt55vJI4-1641869995240)(https://www.zhihu.com/equation?tex=h%5E%60)]`,然后全连接层输出下一个字符的预测,这样我们的训练阶段就结束了。

img

3.3 预测阶段

同样,我们先把句子输入到我们的Encoder里面,Encoder会输入最后状态[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ALDUC551-1641869995244)(https://www.zhihu.com/equation?tex=%28h_0%2Cc_0%29)],作为这句话的特征送给Decoder。

img

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ysi4okqN-1641869995248)(https://www.zhihu.com/equation?tex=%28h_0%2Cc_0%29)]作为Decoder的初始状态,这样解码器就知道这句话是go away,首先把起始符输入,有了新的状态解码器就会把状态更新为[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Pm1JZDuP-1641869995249)(https://www.zhihu.com/equation?tex=%28h_1%2Cc_1%29)]并且预测下一个字符,decoder输出的是每个字符的概率值,我们可以根据这个概率值进行预测,比如我们可以选取概率值最大的字符,也可以对概率进行随机抽样,我可能会得到字符m,于是我把m记录下来。

img

现在状态是[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6AyxXyHI-1641869995256)(https://www.zhihu.com/equation?tex=%28h_1%2Cc_1%29)],把新生成的字符m作为LSTM的输入,接下来再更新状态为[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6Pw96b8s-1641869995259)(https://www.zhihu.com/equation?tex=%28h_2%2Cc_2%29)],并且输出一个概率分布,根据概率分布抽样我们得到字符a,记录下字符a,并一直进行这个循环。

img

运行14轮了状态是[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-TnkT1nmR-1641869995262)(https://www.zhihu.com/equation?tex=%28h_%7B14%7D%2Cc_%7B14%7D%29)],再结合上一轮生成的字符e,根据decoder输出的概率分布抽样,我们抽到了终止符,一旦抽到了终止符,就终止文本生成,并返回记录下的字符串,德语也就被成功翻译了。

3.4 总结

Seq2Seq模型有一个encoder网络和一个Decoder网络,在我们的例子中encoder的输入是英语句子,每输入一个词RNN就会更新状态并记录下来,encoder最后一个状态就是这个句子的特征,并把之前的状态丢弃。把这个状态作为decoder的初始状态,初始化后decoder就知道这个句子了,首先把起始符作为decoder的输入,然后一步步更新,输出状态和概率分布预测下一个字符,再把预测的字符作为下一个输入,重复这个过程,最后直到预测终止符就返回输出的这个序列。

3.5 如何提升?

我们的encoder和decoder都是LSTM,encoder把所有句子的特征压缩到最后一个状态,理想情况下encoder最后一个状态包含完整的信息,假如句子很长,那么句子有些信息就会被遗忘,那么Decoder就没有完整的句子信息,那decoder输出的德语句子就不完整。

img

一种简单方法就是使用双向LSTM,双向LSTM简单来说就是用两条链,从左到右这条链可能会遗忘最左边的信息,而从右往左的这条链可能会遗忘右边的信息,这样结合起来就不容易遗忘句子信息,这里要注意只是encoder用双向LSTM,decoder是单向LSTM,他要生成正确顺序的序列。

img

这次我们用的是char-level比较方便,但是最好还是使用word-level,因为用单词代替字母,序列就会短大概4.5倍,就不容易遗忘,但是用word-level需要大的数据集,得到的单词大概就是一万,one-hot之后向量的维度也就是一万,太大了,需要embedding进行降维,因为embedding参数很多,所以如果数据集不够很容易过拟合。

img

另外一种方法改进就是multi-Task learning,我们还可以多加入几个任务,比如让英语句子让他自己翻译成英语句子,这样encoder只有一个但是数据多了一倍,这样encoder就能被训练的更好,当然你还可以添加其他语言的任务,通过借助其他语言更好训练encoder,这样虽然decoder没有变得更好,但是因为encoder提取的更好最后效果也会变好。

当然还有一个方法就是使用注意力机制,这个对机器翻译提高作用很大,我们接下来就讲解这个注意力机制。

4 注意力机制

我们知道Seq2Seq模型有一个缺点就是句子太长的话encoder会遗忘,那么decoder接受到的句子特征也就不完全,我们看一下下面这个图,纵轴BLUE是机器翻译的指标,横轴是句子的单词量,我们可以看出用了attention之后模型的性能大大提升。

img

用了注意力机制,Decoder每次更新状态的时候都会再看一遍encoder所有状态,还会告诉decoder要更关注哪部分,这也是attention名字的由来。但是缺点就是计算量很大。

img

4.1 attention原理

在encoder结束之后,attention和decoder同时工作,回忆一下,decoder的初始状态[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-irxph1y4-1641869995275)(https://www.zhihu.com/equation?tex=s_0)]是encoder最后一个状态,不同于常规的Seq2Seq,encoder所有状态都要保留,这里需要计算[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-UdEJ10Ln-1641869995277)(https://www.zhihu.com/equation?tex=s_0)]与每个状态的相关性,我使用[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-L2Uxy3CK-1641869995278)(https://www.zhihu.com/equation?tex=%5Calpha_i+%3D+aligh%28h_i%2Cs_0%29)]这个公式表示计算两者相关性,把结果即为[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-KJwGFKba-1641869995280)(https://www.zhihu.com/equation?tex=%5Calpha_i)],记做Weight,encoder有m个状态,所以一共有m个[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-TENGrHi9-1641869995281)(https://www.zhihu.com/equation?tex=%5Calpha)],这里所有的值都是介于0和1的实数,全部加起来为1。

img

下面看一下怎么计算这个相似性。第一种方法是把[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-S7ZLpN66-1641869995286)(https://www.zhihu.com/equation?tex=h_i)]和[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-peKHzf4E-1641869995287)(https://www.zhihu.com/equation?tex=S_o)]做concat得到更高的向量,然后求矩阵W与这个向量的乘积,得到一个向量,然后再将tanh作用于向量每一个元素,将他压到-1和1之间,最后计算向量V与刚才计算出来的向量的内积,这里的向量V和矩阵W都是参数,需要从训练数据里学习,算出m个[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-b5OgOKrW-1641869995289)(https://www.zhihu.com/equation?tex=%5Calpha)]后,需要对他们做一个softmax变换,把输出结果记做[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yzm7LRdz-1641869995291)(https://www.zhihu.com/equation?tex=%5Calpha_1)]到[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-oESygOTZ-1641869995292)(https://www.zhihu.com/equation?tex=%5Calpha_m)],因为是softmax输出,所以他们都大于0相加为1,这是第一篇attention论文提出计算的方法,往后有很多其他计算的方法,我们来介绍一种更常用的方法。

img

输入还是[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kJpEk4Oo-1641869995295)(https://www.zhihu.com/equation?tex=h_i)]和[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-qItWUJZ5-1641869995296)(https://www.zhihu.com/equation?tex=S_0)],第一步是分别使用两个参数矩阵[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-uVLT843D-1641869995298)(https://www.zhihu.com/equation?tex=W_k)],[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-yUeiMIM5-1641869995300)(https://www.zhihu.com/equation?tex=W_q)]做线性变换,得到[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-FxbTsHen-1641869995301)(https://www.zhihu.com/equation?tex=k_i)]和[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ews4UWoH-1641869995303)(https://www.zhihu.com/equation?tex=q_0)]这两个向量,这两个参数矩阵要从训练数据中学习。第二步是计算[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vFyBC5o2-1641869995309)(https://www.zhihu.com/equation?tex=k_i)]与[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-iTF9J8G7-1641869995310)(https://www.zhihu.com/equation?tex=q_0)]的内积,由于有m个K向量,所以得到L个[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-uTFtcShP-1641869995311)(https://www.zhihu.com/equation?tex=%5Calpha_i)]。第三步就是对这些值做一个softmax变换,[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3b6yrKuq-1641869995313)(https://www.zhihu.com/equation?tex=%5Calpha_1)]到[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-pTrlYgfe-1641869995315)(https://www.zhihu.com/equation?tex=%5Calpha_m)],因为是softmax输出,所以他们都大于0相加为1。这种计算方法被Transformer模型采用,Transformer模型是当前很多nlp问题采用的先进模型。

img

刚才讲了两种方法来计算[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-I7GF5zGZ-1641869995318)(https://www.zhihu.com/equation?tex=h_i)]和[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-cWJfyEud-1641869995319)(https://www.zhihu.com/equation?tex=S_0)]的相关性,现在我们得到了m个相关性[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-czDyCga1-1641869995321)(https://www.zhihu.com/equation?tex=%5Calpha)],每个[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-IsZcMk0Y-1641869995323)(https://www.zhihu.com/equation?tex=%5Calpha)]对应每个状态[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2Z2aVWLQ-1641869995324)(https://www.zhihu.com/equation?tex=h_i)],有了这些权重[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Hk5EaG7R-1641869995326)(https://www.zhihu.com/equation?tex=%5Calpha)]我们可以对m个状态计算加权平均,得到一个Context vector [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-t45FxLwZ-1641869995327)(https://www.zhihu.com/equation?tex=C_0)]。每一个Context vector都会对应一个decoder状态[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RWMIWyvv-1641869995329)(https://www.zhihu.com/equation?tex=s_i)]

img

接下来我们来看一下decoder是怎么计算新的状态的。我们来回顾一下,假如不用attention,我们是这样更新状态的,新的状态[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-XQdq7Wqq-1641869995332)(https://www.zhihu.com/equation?tex=S_1)]是旧状态[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6hALjeWi-1641869995334)(https://www.zhihu.com/equation?tex=S_0)]与新输入[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Rsf18c6s-1641869995335)(https://www.zhihu.com/equation?tex=X_1%5E%60)]`的函数,看一下下图左边的公式,将两者做concat,然后乘上权重矩阵加上偏置b,最后通过tanh就是我们的新状态,也就是说状态的更新仅仅是根据上一个状态,并不会看encoder的状态。用attention的话更新状态还要用到我们计算出的Context vector [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-7qFoLwfj-1641869995337)(https://www.zhihu.com/equation?tex=C_0)],把三个参数一起做concat后更新。

img

回忆一下,[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Jf9RDBCH-1641869995341)(https://www.zhihu.com/equation?tex=C_0)]是所有encoder状态[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-j9RtlnN9-1641869995342)(https://www.zhihu.com/equation?tex=h_i)]的加权平均,所以[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jyyobC9I-1641869995343)(https://www.zhihu.com/equation?tex=C_0)]知道输入[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-7uQRb7Qx-1641869995345)(https://www.zhihu.com/equation?tex=X_1)]到[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vLx7qczX-1641869995347)(https://www.zhihu.com/equation?tex=X_m)]的完整信息,decoder新的状态[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8ha4p34k-1641869995348)(https://www.zhihu.com/equation?tex=S_1)]依赖于[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-hMnvHCIo-1641869995351)(https://www.zhihu.com/equation?tex=C_0)],这样RNN遗忘的问题就解决了。下一步则是计算context vector [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SzsC0L2h-1641869995353)(https://www.zhihu.com/equation?tex=C_1)],跟之前一样,先计算权重[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-cZFgzUkD-1641869995354)(https://www.zhihu.com/equation?tex=%5Calpha_i)],这里是计算[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Zydqw1Cd-1641869995356)(https://www.zhihu.com/equation?tex=S_1)]跟之前encoder所有状态的相关性,得到了m个[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-eALeg1EU-1641869995358)(https://www.zhihu.com/equation?tex=%5Calpha)],注意一下这里的权重也是要更新的,上一轮算的是跟[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PujJRlBN-1641869995360)(https://www.zhihu.com/equation?tex=s_0)]的相关性现在算的是跟[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-O13A2esZ-1641869995361)(https://www.zhihu.com/equation?tex=S_1)]的相关性,这样就可以通过加权平均计算出新的[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-CCwIdp4E-1641869995364)(https://www.zhihu.com/equation?tex=C_1)]。

img

Decoder接受新的输入[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-n1x26wW1-1641869995368)(https://www.zhihu.com/equation?tex=X_2)],还是用那个公式计算出新状态,然后一直循环下去直到结束。

img

我们知道在这个过程中我们会计算出很多权重[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-h1NsmNRr-1641869995371)(https://www.zhihu.com/equation?tex=%5Calpha_i)],我们思考一下我们究竟计算了多少个[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-emthMK1A-1641869995373)(https://www.zhihu.com/equation?tex=%5Calpha)]?想要计算出一个context vector[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-VxO27zLL-1641869995375)(https://www.zhihu.com/equation?tex=C_j)],我们要计算出m个相似性权重[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-qtLBYnNM-1641869995376)(https://www.zhihu.com/equation?tex=%5Calpha)],所以每轮更新都需要计算m个权重,假如一共有t个state,那么一共就要计算m×t个权重,也就是encoder和decoder数量的乘积。attention为了不遗忘,代价就是高数量级的计算。

img

4.2 权重[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-M4mAQ4l3-1641869995379)(https://www.zhihu.com/equation?tex=%5Calpha)]的实际意义

这张图下面是encoder,上面是decoder,attention会把decoder所有状态与encoder所有状态计算相似性,也就是[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-c6TEwbOV-1641869995380)(https://www.zhihu.com/equation?tex=%5Calpha)].在这张图中每条线就对应一个[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lcztSDcv-1641869995382)(https://www.zhihu.com/equation?tex=%5Calpha)],线越粗说明相关性越高。

img

比如下面,法语中的zone就是英语的Area,所以两者的线就很粗。

img

4.3 总结

这次仅仅是从机器翻译的角度介绍了attention的一个应用,attention在业界还是有很多应用的,比如self-attention,Transformer应用,希望以此为印子能够打开读者attention的大门。

5 参考

1.colah’s blog**http://colah.github.io/posts/2015-08-Understanding-LSTMs/**

2.CS583 https://github.com/wangshusen/DeepLearning

原文

https://zhuanlan.zhihu.com/p/135970560

写在最后

欢迎大家关注鄙人的公众号【麦田里的守望者zhg】,让我们一起成长,谢谢。
微信公众号

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/475023
推荐阅读
相关标签
  

闽ICP备14008679号