当前位置:   article > 正文

pytorch dropout_Attention Is All You Need | 源码解析(pytorch)

attention is all you need explained
  • Harvard NLP团队推出的注释版Pytroch代码:http://nlp.seas.harvard.edu/2018/04/03/attention.html
  • 图解Transformer: http://jalammar.github.io/illustrated-transformer/
  • 主要参考资料:https://mlexplained.com/2017/12/29/attention-is-all-you-need-explained/

该论文提出了完全基于attention和全连接的Transformer来替代RNNs,以实现seq-to-seq任务。与基于RNNs的seq-to-seq模型一样,Transformer整体上也是一个encoder-decoder结构。其中encoder以时间序列信息

为输入,输出相应的序列特征
. Decoder将
作为输入,以自回归的形式依次给出相应的预测输出
.

关于Attention

在正式了解Transformer前,先简单回顾下Transformer出现前主宰seq-to-seq任务的基于RNNs的encoder-decoder结构.

v2-1ce36f17d754c486f5b087fd4f09dcc5_b.png

早期基于RNNs的seq-to-seq的结构如上图所示,输入的序列信息(I like cats more than dogs)通过Encoder获得该序列信息的embedding特征向量。Decoder以该特征向量为输入生成相应的序列信息(私は犬より猫が好き)。在这个过程中有一个显而易见的问题,即所有输入的序列信息均包含在一个特征向量中。当序列信息比较短时,上述结构的表现还可以,但一旦序列长度很长时,一个单一的特征向量很难准确的表达所有的信息——“You can't cram the meaning of a whole %&!$# sentence into a single $&!#* vector!" (KyungHyun Cho's talk at ACL 2014 Workshop on Semantic Parsing)。

实际上,decoder在不同时刻需要关注的信息是不同的。例如,在将句子"I like cats more than dogs"翻译成"私は犬より猫が好き"的例子中,输入的第二个词“like”对应输出句子中的最后一个词“好き”,这就需要RNNs携带长时间依赖的信息。而在输出单词“犬”后,decoder只需要知道输出的这个词对应的是“dog”就行,而并不需要继续记住“dog”。受此启发,Attention机制应运而生。直观上,Attention机制的作用就是回顾整个句子,并有选择性地提取decode过程需要的信息。

为了能够有选择性地从encoder的输出中获得重要的信息,decoder需要访问encoder的所有隐藏层的输出。同时,为了方便encoder预测某一个特定的单词,我们仍然需要对所有encoder出来的信息进行处理,以得到一个单一向量的特征表示。这时,就需要利用Attention机制通过加权的方式来选择隐藏层状态向量,并以加权的形式获得特定的向量表示。原始的attention机制示意图如下所示:

v2-f6dd63f6ab3119c0722f8350b803fbdc_b.png

尽管Attention机制的出现在一定程度上解决了RNNs的一些短板,但是由于RNNs的时间序列属性,每个隐藏状态的计算都依赖于上一时刻的隐藏状态,大大限制了RNNs的计算效率。另外,在自然语言的机器翻译过程中,源内容和目标内容之间以及各自内部之间都存在一些依赖关系。而在传统的attention机制中,仅仅考虑了两者之间的依赖,而并没有考虑各自内部之间的依赖关系。Transformer的提出就是为了解决这两个短板:1)序列上前后依赖带来的低效率;2)没有考虑输入内容及输出内容各自内部之间关联的attention机制。

Multi-head attention

在传统seq-to-seq的attention机制中,计算attention的query, keys和values分别来自于:

  • query (Q): decoder hidden state
  • keys (K): encoder hidden states
  • values (V): encoder hidden states

通过decoder的hidden state (query) 查找encoder hidden states中的有用信息。实际上,query也可以完全可以跟keys和values一样来自inputs或者outputs,这就是self-attention机制。作为Transformer中的关键模块,multi-head attention便是利用了同样的embedding来实现attention的,并通过不同的线性变化得到的不同的V,K,Q,从而获得多个 (multi-head) attention.

v2-8ea15b6f631c2d075d55813eefc23e5f_b.png

上图所展示是作者称为“Scaled Dot-Product Attention”的一般attention机制,其计算过程可由下式表示:

其中

是V的向量长度,当其过大时会使得到的点积值过大,从而导致softmax的梯度很小,因此除上
来改善这种情况。

pytorch实现Scaled Dot-Product Attention:

class 

为了得到多样性的特征表示,通过对V,K,Q进行不同的线性变化,即可得到多组变换后的V,K,Q, 进而实现Multi-head attention. Multi-head的示意图如下图所示:

v2-3c4f90bd723d8442c90e4742389ff52f_b.png

每一个单独的attention可以表示为:

其中

,
,分别时计算第
个head的attention时对Q,K,V进行线性变换的参数。在获得
个head的attention后,Multi-head attention的最终输出由下式表示:

pytorch实现Multi-head Attention的代码如下:

class 

整体结构

如下图所示,Transformer也是一个典型的encoder-decoder结构。下图的左侧和右侧分别是encoder和decoder. Encoder的输入是输入序列的embeddings, 而Decoder的输入则是已经预测到的输出的embeddings. 途中的

表示共有N个相关的block (对于encdoer和Decoder均采用了N=6)。

v2-6a871d311d43bb1b13775da6d5f63724_b.png

从图中可以看到,在embeddings输入到encoder或者decoder时均有一个“Positional Encoding” ,关于该Positional Encoding我们稍后介绍,下面先介绍Encoder和Decoder的构成及实现。

Encoder

v2-7c450fcfb40df0add85b385500eb14e9_b.png

如上图所示,Encoder由两个子模块组成,一个是Multi-head attention (MHA), 另外一个是Feed Forward (FFN)。两个sub-layer均与相关输入构成了残差连接并进行了layer normalization, 相关过程可表达为:

Feed Forward Net (FFN) 由两层全连接层和激活函数ReLU构成:

pytorch实现单个Encoder block:

class 

基于此搭建整个Encoder:

class 

Decoder

v2-115a75482c5bffa6eef73b786aba04f8_b.png

相比于Encoder, Decoder中多了一个叫做“masked multi-head attention”的sub-layer。在Decoder预测下一个单词的时候,可以利用的只能是已经预测出的句子。因此,在训练过程中,Decoder在预测某个单词时,需要遮挡掉该单词及以后的ground-truth,只利用该单词以前的embeddings进行训练。如下图所示,在我们训练模型将“I like cats more than dogs”翻译成“私は犬よりも猫が好き”的过程中,当模型预测“犬”时,其可利用的信息是"私は",而之后的信息是不可用的。

v2-d9a3525e71524c3db7605150e20d21a4_b.png

pytorch实现单个Encoder block:

class 

基于此,Decoder的实现:

class 

Positional Encodings

在无任何循环网络和卷积网络的情况下,基于Multi-head attention的结构是无法利用单词的位置信息。换言之,对于multi-head attention而言“I like cats more than dogs” 与 "I like dogs more than cats" 并无差别。因此,有必要在序列中任务加入一些包含位置属性的信息,而Positional encodings的作用就在此。文章中采用了如下公式计算positional encodings:

其中

表示位置,
表示维度。位置编码的维度于embeddings的维度相同,因此可采用直接加合的方式。
class 

可视化添加的位置信息

plt

v2-077ce14ed89a04b1a9fe5ffb65ec44a1_b.png

Pytorch中的Transformer模块

Pytorch 1.2后的版本包含了标准的nn.Transformer模块,模块的各组件可单独使用,相关的API包括:

  • nn.Transformer
  • nn.TransformerEncoder 和 nn.TransformerEncoderLayer
  • nn.TransformerDecoder 和 nn.TransformerDecoderLayer
  1. @inproceedings{vaswani2017attention,
  2. title={Attention is all you need},
  3. author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones, Llion and Gomez, Aidan N and Kaiser, {L}ukasz and Polosukhin, Illia},
  4. booktitle={Advances in neural information processing systems},
  5. pages={5998--6008},
  6. year={2017}
  7. }
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/536094
推荐阅读
相关标签
  

闽ICP备14008679号