Original post: https://blog.csdn.net/qq_29695701/article/details/88096455
This is mostly a machine translation with light manual corrections, so read it with that in mind.
Original paper: "Attention Is All You Need"
Source code: https://github.com/tensorflow/tensor2tensor
For more on attention mechanisms, see:
https://blog.csdn.net/qq_29695701/article/details/88896227
http://nlp.seas.harvard.edu/2018/04/03/attention.html (recommended)
Abstract
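The dominant sequence transduction models are based on complex recurrent or convolutional neural networks that include an encoder and a decoder. The best-performing models also connect the encoder and decoder through an attention mechanism. We propose a new, simple network architecture, the Transformer, based solely on attention mechanisms and dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles, by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models in the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing with both large and limited training data.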
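Recurrent neural networks, in particular long short-term memory [13] and gated recurrent [7] neural networks, have been firmly established as state-of-the-art approaches to sequence modeling and transduction problems such as language modeling and machine translation [35, 2, 5]. Numerous efforts have since continued to push the boundaries of recurrent language models and encoder-decoder architectures [38, 24, 15].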
Recurrent models typically factor computation along the symbol positions of the input and output sequences. Aligning the positions to steps in computation time, they generate a sequence of hidden states $h_t$, each a function of the previous hidden state $h_{t-1}$ and the input for position $t$. This inherently sequential nature precludes parallelization within training examples, which becomes critical at longer sequence lengths, as memory constraints limit batching across examples. Recent work has achieved significant improvements in computational efficiency through factorization tricks [21] and conditional computation [32], the latter also improving model performance. The fundamental constraint of sequential computation, however, remains.
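To make the sequential bottleneck concrete, here is a minimal sketch of a scalar toy RNN, assuming the common update $h_t = \tanh(w_h h_{t-1} + w_x x_t)$. The function name and weights are illustrative and not taken from the paper; the point is only that step $t$ cannot begin until step $t-1$ has produced its hidden state.

```python
import math
from typing import List, Sequence


def rnn_forward(xs: Sequence[float], w_h: float, w_x: float, h0: float = 0.0) -> List[float]:
    """Toy scalar RNN (illustrative): h_t = tanh(w_h * h_{t-1} + w_x * x_t)."""
    hidden_states = []
    h = h0
    for x in xs:
        # Step t reads h from step t-1, so positions are processed strictly in order.
        h = math.tanh(w_h * h + w_x * x)
        hidden_states.append(h)
    return hidden_states


hs = rnn_forward([1.0, 0.0, -1.0], w_h=0.5, w_x=1.0)
print(len(hs))  # 3
```

Across a batch of sequences the same step can run in parallel, but within one sequence the loop stays serial, which is the constraint described above.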
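Attention mechanisms have become an integral part of compelling sequence modeling and transduction models in various tasks, allowing dependencies to be modeled without regard to their distance in the input or output sequences [2, 19]. In all but a few cases [27], however, such attention mechanisms are used in conjunction with a recurrent network.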
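In this work we propose the Transformer, a model architecture that eschews recurrence and instead relies entirely on an attention mechanism to draw global dependencies between input and output. The Transformer allows for significantly more parallelization and can reach a new state of the art in translation quality after being trained for as little as twelve hours on eight P100 GPUs.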
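The goal of reducing sequential computation also forms the foundation of the Extended Neural GPU [16], ByteNet [18] and ConvS2S [9], all of which use convolutional neural networks as the basic building block, computing hidden representations in parallel for all input and output positions. In these models, the number of operations required to relate signals from two arbitrary input or output positions grows with the distance between the positions: linearly for ConvS2S and logarithmically for ByteNet. This makes it more difficult to learn dependencies between distant positions [12]. In the Transformer this is reduced to a constant number of operations, albeit at the cost of reduced effective resolution due to averaging attention-weighted positions, an effect we counteract with Multi-Head Attention as described in section 3.2.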
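The growth rates above can be illustrated by counting how many stacked layers are needed before two positions at distance $d$ fall within one receptive field. This is a minimal sketch under simplifying assumptions: kernel width $k$ for both convolution types, and, for the dilated case, dilation $k^l$ at layer $l$ (the setting behind an $O(\log_k n)$ bound, not necessarily ByteNet's exact schedule). The function names are hypothetical.

```python
def layers_contiguous_conv(distance: int, k: int) -> int:
    """Stacked contiguous convolutions of width k (ConvS2S-like): each layer widens
    the receptive field by k - 1 positions, so the count grows linearly."""
    return -(-distance // (k - 1))  # ceil(distance / (k - 1))


def layers_dilated_conv(distance: int, k: int) -> int:
    """Dilated convolutions of width k with dilation k**l at layer l (simplified,
    ByteNet-like): after L layers the receptive field spans k**L positions,
    so the count grows logarithmically."""
    layers, field = 0, 1
    while field < distance + 1:
        field *= k
        layers += 1
    return layers


def layers_self_attention(distance: int) -> int:
    """A single self-attention layer connects every pair of positions directly."""
    return 1 if distance > 0 else 0


for d in (10, 100, 1000):
    print(d, layers_contiguous_conv(d, 3), layers_dilated_conv(d, 3), layers_self_attention(d))
# 10 5 3 1
# 100 50 5 1
# 1000 500 7 1
```

The constant count for self-attention is what the text means by "a constant number of operations"; the price, as noted, is that each position's output is an average over many positions.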
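Self-attention, sometimes called intra-attention, is an attention mechanism that relates different positions of a single sequence in order to compute a representation of that sequence. Self-attention has been used successfully in a variety of tasks, including reading comprehension, abstractive summarization, textual entailment and learning task-independent sentence representations [4, 27, 28, 22].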
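End-to-end memory networks are based on a recurrent attention mechanism instead of sequence-aligned recurrence, and have been shown to perform well on simple-language question answering and language modeling tasks [34].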
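To the best of our knowledge, however, the Transformer is the first transduction model relying entirely on self-attention to compute representations of its input and output without using sequence-aligned RNNs or convolution. In the following sections, we describe the Transformer, motivate self-attention, and discuss its advantages over models such as [17], [18] and [9].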
Most competitive neural sequence transduction models have an encoder-decoder structure [5, 2, 35]. Here, the encoder maps an input sequence of symbol representations $(x_1, \dots, x_n)$ to a sequence of continuous representations $z = (z_1, \dots, z_n)$. Given $z$, the decoder then generates an output sequence $(y_1, \dots, y_m)$ of symbols one element at a time. At each step the model is auto-regressive [10], consuming the previously generated symbols as additional input when generating the next.
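The auto-regressive loop can be sketched as follows. This is a minimal illustration, not the tensor2tensor implementation: `greedy_decode`, `toy_step`, `TARGET` and the `<s>`/`</s>` symbols are hypothetical, and `toy_step` stands in for a trained decoder that scores candidate next symbols given $z$ and the prefix generated so far. Greedy selection is used only for simplicity; another search strategy such as beam search would plug into the same loop.

```python
from typing import Callable, Dict, List, Sequence

BOS, EOS = "<s>", "</s>"
StepFn = Callable[[Sequence[float], List[str]], Dict[str, float]]


def greedy_decode(z: Sequence[float], step: StepFn, max_len: int = 20) -> List[str]:
    """Auto-regressive generation: each step conditions on z and on every
    symbol produced so far, then appends the highest-scoring next symbol."""
    prefix = [BOS]
    while len(prefix) <= max_len:
        scores = step(z, prefix)
        next_symbol = max(scores, key=scores.get)
        if next_symbol == EOS:
            break
        prefix.append(next_symbol)
    return prefix[1:]  # drop the start symbol


TARGET = ["ich", "bin", "hier"]  # hypothetical output sentence


def toy_step(z: Sequence[float], prefix: List[str]) -> Dict[str, float]:
    """Hypothetical stand-in for a trained decoder: it favors the next word of TARGET."""
    i = len(prefix) - 1
    wanted = TARGET[i] if i < len(TARGET) else EOS
    return {w: (1.0 if w == wanted else 0.0) for w in TARGET + [EOS]}


print(greedy_decode([0.1, 0.2], toy_step))  # ['ich', 'bin', 'hier']
```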
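The Transformer follows this overall architecture, using stacked self-attention and point-wise, fully connected layers for both the encoder and the decoder, shown in the left and right halves of Figure 1, respectively.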
Encoder: The encoder is composed of a stack of $N = 6$ identical layers. Each layer has two sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position-wise, fully connected feed-forward network. We employ a residual connection [11] around each of the two sub-layers, followed by layer normalization [1]. That is, the output of each sub-layer is $\mathrm{LayerNorm}(x + \mathrm{Sublayer}(x))$, where $\mathrm{Sublayer}(x)$ is the function implemented by the sub-layer itself. To facilitate these residual connections, all sub-layers in the model, as well as the embedding layers, produce outputs of dimension $d_{model} = 512$.
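The formula $\mathrm{LayerNorm}(x + \mathrm{Sublayer}(x))$ can be sketched for a single position as below. This is a simplified illustration: the learned gain and bias of real layer normalization are omitted, the sub-layer is a hypothetical lambda, and $d_{model}$ is shrunk to 4. The dimension check shows why every sub-layer must output $d_{model}$ values: otherwise the residual addition is undefined.

```python
import math
from typing import Callable, List

Vector = List[float]


def layer_norm(x: Vector, eps: float = 1e-6) -> Vector:
    """Normalize one position's features to zero mean and unit variance
    (the learned gain and bias of real layer normalization are omitted)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def residual_sublayer(x: Vector, sublayer: Callable[[Vector], Vector]) -> Vector:
    """LayerNorm(x + Sublayer(x)); the addition needs len(Sublayer(x)) == len(x)."""
    y = sublayer(x)
    if len(y) != len(x):
        raise ValueError("sub-layer must preserve the model dimension d_model")
    return layer_norm([a + b for a, b in zip(x, y)])


x = [1.0, 2.0, 3.0, 4.0]  # one position, toy d_model = 4
out = residual_sublayer(x, lambda v: [0.5 * t for t in v])  # hypothetical sub-layer
print([round(t, 3) for t in out])  # [-1.342, -0.447, 0.447, 1.342]
```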
Decoder: The decoder is also composed of a stack of $N = 6$ identical layers. In addition to the two sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head attention over the output of the encoder stack. Similar to the encoder, we employ residual connections around each of the sub-layers, followed by layer normalization. We also modify the self-attention sub-layer in the decoder stack to prevent positions from attending to subsequent positions. This masking, combined with the fact that the output embeddings are offset by one position, ensures that the predictions for position $i$ can depend only on the known outputs at positions less than $i$.
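Both ingredients, the mask and the one-position offset, can be sketched as follows. This is a minimal illustration with hypothetical helper names: future positions get a score of $-\infty$ before the softmax, so their weights are exactly zero, and the decoder input is the target sequence shifted right behind an assumed start symbol `<s>`. All-zero scores are used so the resulting weights are easy to read.

```python
import math
from typing import List

Matrix = List[List[float]]


def causal_mask(n: int) -> List[List[bool]]:
    """mask[i][j] is True when query position i may attend to key position j (j <= i)."""
    return [[j <= i for j in range(n)] for i in range(n)]


def masked_softmax_rows(scores: Matrix, mask: List[List[bool]]) -> Matrix:
    """Set future positions to -inf before the softmax, so their weight is exactly 0."""
    out = []
    for row, allowed in zip(scores, mask):
        masked = [s if ok else -math.inf for s, ok in zip(row, allowed)]
        m = max(masked)  # finite, because position j = 0 is always allowed
        exps = [math.exp(s - m) for s in masked]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def shift_right(targets: List[str], bos: str = "<s>") -> List[str]:
    """Offset the decoder input by one position: input i is target i - 1."""
    return [bos] + targets[:-1]


weights = masked_softmax_rows([[0.0] * 3 for _ in range(3)], causal_mask(3))
for row in weights:
    print([round(w, 3) for w in row])
# [1.0, 0.0, 0.0]
# [0.5, 0.5, 0.0]
# [0.333, 0.333, 0.333]
print(shift_right(["ich", "bin", "hier"]))  # ['<s>', 'ich', 'bin']
```

Because of the shift, the decoder input at position $i$ carries only targets before $i$, and the mask keeps position $i$ from looking further ahead.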
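An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.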
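This definition can be sketched with a pluggable compatibility function. The sketch assumes the scores are normalized with a softmax, as the Transformer does; `attend`, `dot` and the toy vectors are illustrative names, not an API from the paper or from tensor2tensor.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def softmax(xs: Sequence[float]) -> Vector:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query: Vector, keys: Sequence[Vector], values: Sequence[Vector],
           compat: Callable[[Vector, Vector], float]) -> Vector:
    """Output = sum_i w_i * v_i, with w = softmax over compat(query, k_i)."""
    weights = softmax([compat(query, k) for k in keys])
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]


def dot(q: Vector, k: Vector) -> float:
    return sum(a * b for a, b in zip(q, k))


keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0]]
print([round(t, 3) for t in attend([1.0, 0.0], keys, values, dot)])  # [7.311, 2.689]
```

The query matches the first key more closely, so the first value receives the larger weight, $e/(e+1) \approx 0.731$.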
We call our particular attention "Scaled Dot-Product Attention" (Figure 2). The input consists of queries and keys of dimension $d_k$, and values of dimension $d_v$. We compute the dot products of the query with all keys, divide each by $\sqrt{d_k}$, and apply a softmax function to obtain the weights on the values.
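In practice, we compute the attention function on a set of queries simultaneously, packed together into a matrix $Q$. The keys and values are also packed together into matrices $K$ and $V$. We compute the matrix of outputs as: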
$$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V$$
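The matrix formula can be sketched in plain Python, with row $i$ of $Q$ holding query $i$ and rows of $K$ and $V$ holding the key-value pairs. The helper names are illustrative; a real implementation would use an optimized tensor library, which is exactly the advantage the next paragraph attributes to dot-product attention.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax_rows(a: Matrix) -> Matrix:
    out = []
    for row in a:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(x - m) for x in row]
        s = sum(exps)
        out.append([e / s for e in exps])
    return out


def scaled_dot_product_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """softmax(Q K^T / sqrt(d_k)) V for Q: n_q x d_k, K: n_k x d_k, V: n_k x d_v."""
    d_k = len(k[0])
    scores = matmul(q, transpose(k))                            # n_q x n_k
    scores = [[s / math.sqrt(d_k) for s in row] for row in scores]
    return matmul(softmax_rows(scores), v)                      # n_q x d_v


Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0], [2.0], [3.0]]
print(scaled_dot_product_attention(Q, K, V))  # 2 x 1 output, one row per query
```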
The two most commonly used attention functions are additive attention [2] and dot-product (multiplicative) attention. Dot-product attention is identical to our algorithm, except for the scaling factor of $\frac{1}{\sqrt{d_k}}$. Additive attention computes the compatibility function using a feed-forward network with a single hidden layer. While the two are similar in theoretical complexity, dot-product attention is much faster and more space-efficient in practice, since it can be implemented using highly optimized matrix multiplication code.
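The two compatibility functions can be compared side by side. The additive form below, $v^{T}\tanh(W_q q + W_k k)$, is the usual single-hidden-layer formulation; the parameter names `w_q`, `w_k`, `v` are hypothetical and the identity weights are toy values.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Vector) -> Vector:
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def additive_score(q: Vector, k: Vector, w_q: Matrix, w_k: Matrix, v: Vector) -> float:
    """Additive compatibility: a feed-forward net with one tanh hidden layer,
    score = v . tanh(W_q q + W_k k)."""
    hidden = [math.tanh(a + b) for a, b in zip(matvec(w_q, q), matvec(w_k, k))]
    return sum(vi * hi for vi, hi in zip(v, hidden))


def dot_product_score(q: Vector, k: Vector) -> float:
    """Multiplicative compatibility: a plain dot product with no parameters."""
    return sum(a * b for a, b in zip(q, k))


identity = [[1.0, 0.0], [0.0, 1.0]]
q, k = [1.0, 0.0], [0.0, 1.0]
print(additive_score(q, k, identity, identity, [1.0, 1.0]))  # 2 * tanh(1)
print(dot_product_score(q, k))  # 0.0
```

For a whole batch, all dot-product scores come from one matrix product $QK^T$. The additive scores instead need a $\tanh$ hidden vector for every (query, key) pair, an extra $n_q \times n_k \times$ hidden-size intermediate, which is one concrete reason for the speed and memory gap described above.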
While for small values of $d_k$ the two mechanisms perform similarly, additive attention outperforms dot-product attention without scaling for larger values of $d_k$ [3]. We suspect that for large values of $d_k$, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. To counteract this effect, we scale the dot products by $\frac{1}{\sqrt{d_k}}$.
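The growth in magnitude follows from a short calculation: if the components of $q$ and $k$ are independent with mean 0 and variance 1, then $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ has mean 0 and variance $d_k$, and dividing by $\sqrt{d_k}$ brings the variance back to 1. The simulation below is a sketch of that argument under those distributional assumptions (Gaussian components, fixed seed); the sample count is arbitrary.

```python
import math
import random


def dot_product_stats(d_k: int, samples: int = 1000, seed: int = 0):
    """Draw q, k with i.i.d. N(0, 1) components and return the sample variance
    of q.k before and after scaling by 1/sqrt(d_k)."""
    rng = random.Random(seed)
    raw = []
    for _ in range(samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        raw.append(sum(a * b for a, b in zip(q, k)))
    scaled = [x / math.sqrt(d_k) for x in raw]

    def var(xs):
        m = sum(xs) / len(xs)
        return sum((x - m) ** 2 for x in xs) / len(xs)

    return var(raw), var(scaled)


for d_k in (4, 64, 512):
    raw_var, scaled_var = dot_product_stats(d_k)
    print(f"d_k={d_k:4d}  var(q.k)={raw_var:8.1f}  var(q.k/sqrt(d_k))={scaled_var:.2f}")
```

The unscaled variance tracks $d_k$ (roughly 4, 64, 512), while the scaled variance stays near 1, so the softmax inputs keep a similar spread regardless of $d_k$.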
Instead of performing a single attention function with $d_{model}$-dimensional keys, values and queries, we found it beneficial to linearly project the queries, keys and values $h$ times with different, learned linear projections to $d_k$, $d_k$ and $d_v$ dimensions, respectively. On each of these projected versions of queries, keys and values we then perform the attention function in parallel, yielding $d_v$-dimensional output values. As shown in Figure 2, these are concatenated and once again projected, resulting in the final values.
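Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.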
$$\mathrm{MultiHead}(Q,K,V)=\mathrm{Concat}(\mathrm{head}_1,\dots,\mathrm{head}_h)W^{O}$$
$$\text{where}\ \ \mathrm{head}_i=\mathrm{Attention}(QW_i^{Q},\,KW_i^{K},\,VW_i^{V})$$
Here the projections are parameter matrices $W_i^{Q}\in\mathbb{R}^{d_{model}\times d_k}$, $W_i^{K}\in\mathbb{R}^{d_{model}\times d_k}$, $W_i^{V}\in\mathbb{R}^{d_{model}\times d_v}$ and $W^{O}\in\mathbb{R}^{hd_v\times d_{model}}$.
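The MultiHead definition can be sketched in NumPy as follows. This is a minimal illustration, not the tensor2tensor implementation. It assumes random matrices stand in for the learned parameters, and it uses small hypothetical sequence lengths (5 queries, 7 key/value positions). It follows the per-head formulation literally, with one $W_i^Q, W_i^K, W_i^V$ per head and a final $W^O$.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row-wise max for numerical stability.
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


def attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d_k)) @ V


def multi_head(Q, K, V, W_q, W_k, W_v, W_o):
    """Concat(head_1, ..., head_h) W^O with
    head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)."""
    heads = [attention(Q @ wq, K @ wk, V @ wv)
             for wq, wk, wv in zip(W_q, W_k, W_v)]
    # Each head is (n_q, d_v); concatenation gives (n_q, h * d_v).
    return np.concatenate(heads, axis=-1) @ W_o


h, d_model = 8, 512
d_k = d_v = d_model // h
rng = np.random.default_rng(0)

# Randomly initialised stand-ins for the learned parameter matrices.
W_q = [rng.normal(size=(d_model, d_k)) / np.sqrt(d_model) for _ in range(h)]
W_k = [rng.normal(size=(d_model, d_k)) / np.sqrt(d_model) for _ in range(h)]
W_v = [rng.normal(size=(d_model, d_v)) / np.sqrt(d_model) for _ in range(h)]
W_o = rng.normal(size=(h * d_v, d_model)) / np.sqrt(h * d_v)

# Hypothetical inputs: 5 queries attending over 7 key/value positions.
Q = rng.normal(size=(5, d_model))
K = rng.normal(size=(7, d_model))
V = rng.normal(size=(7, d_model))

out = multi_head(Q, K, V, W_q, W_k, W_v, W_o)
print(out.shape)  # (5, 512)
```

With $h=1$ and identity matrices for every projection, `multi_head` reduces to plain `attention`. This is a quick sanity check that the concatenate-then-project structure adds nothing beyond the projections themselves.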
In this work we employ $h=8$ parallel attention layers, or heads. For each of these we use $d_k=d_v=d_{model}/h=64$. Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.
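The cost claim can be checked with a little arithmetic and a shape check. Because $h\cdot d_k=d_{model}$, the $h$ per-head projection matrices stacked side by side are exactly the size of one full-width projection. The multiply-adds for the score matrices $QK^T$ also come out the same. Fusing the heads into one matrix and reshaping is a common implementation trick that this section does not itself describe. The sequence length `n` is an arbitrary assumption.

```python
import numpy as np

d_model, h = 512, 8
d_k = d_model // h  # 512 / 8 = 64
n = 10              # hypothetical sequence length

rng = np.random.default_rng(0)
X = rng.normal(size=(n, d_model))
W_heads = [rng.normal(size=(d_model, d_k)) for _ in range(h)]

# Stacking the h projection matrices side by side gives one
# (d_model, h * d_k) = (512, 512) matrix, the size of a full single-head projection.
W_fused = np.concatenate(W_heads, axis=1)

# One matmul followed by a reshape reproduces the h separate projections.
fused = (X @ W_fused).reshape(n, h, d_k).transpose(1, 0, 2)  # (h, n, d_k)
separate = np.stack([X @ w for w in W_heads])                # (h, n, d_k)

# Multiply-adds needed for the attention scores Q K^T.
single_head_full = n * n * d_model   # one head of width d_model
multi_head_total = h * n * n * d_k   # h heads of width d_k

print(d_k, W_fused.shape, np.allclose(fused, separate))  # 64 (512, 512) True
print(single_head_full, multi_head_total)                # 51200 51200
```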
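The Transformer uses multi-head attention in three different ways:
- In "encoder-decoder attention" layers, the queries come from the previous decoder layer, and the memory keys and values come from the output of the encoder. This allows every position in the decoder to attend over all positions in the input sequence. It mimics the typical encoder-decoder attention mechanisms in sequence-to-sequence models such as [38, 2, 9].
- The encoder contains self-attention layers. In a self-attention layer all of the keys, values and queries come from the same place, in this case the output of the previous layer in the encoder. Each position in the encoder can attend to all positions in the previous layer of the encoder.
- Similarly, self-attention layers in the decoder allow each position in the decoder to attend to all positions in the decoder up to and including that position. We need to prevent leftward information flow in the decoder to preserve the auto-regressive property. We implement this inside scaled dot-product attention by masking out (setting to $-\infty$) all values in the input of the softmax that correspond to illegal connections. See Figure 2.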
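The decoder-side masking in the last item can be sketched in NumPy as follows. This minimal version assumes a single head with $Q=K=V=X$ for brevity. Scores for positions $j>i$ are replaced by $-\infty$ before the softmax, so they receive exactly zero weight. The final check perturbs only the last position and confirms that all earlier outputs are unchanged, i.e. there is no leftward information flow.

```python
import numpy as np


def softmax(x, axis=-1):
    # Row max is finite because the diagonal is never masked; exp(-inf) == 0.
    x = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(x)
    return e / np.sum(e, axis=axis, keepdims=True)


def causal_mask(n):
    # mask[i, j] is True when query position i may attend to key position j (j <= i).
    return np.tril(np.ones((n, n), dtype=bool))


def masked_attention(Q, K, V, mask):
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    scores = np.where(mask, scores, -np.inf)  # illegal connections -> -inf
    weights = softmax(scores)
    return weights @ V, weights


rng = np.random.default_rng(0)
n, d = 4, 8
X = rng.normal(size=(n, d))
out, weights = masked_attention(X, X, X, causal_mask(n))
print(np.round(weights, 2))  # upper triangle is all zeros

# Perturb only the last position: earlier outputs must not change.
X2 = X.copy()
X2[-1] += 10.0
out2, _ = masked_attention(X2, X2, X2, causal_mask(n))
print(np.allclose(out[:-1], out2[:-1]))  # True
```

The first position can only attend to itself, so its output is exactly its own value vector. The encoder-decoder case uses the same `masked_attention` shape logic with queries from the decoder and keys/values from the encoder, usually without the causal mask.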
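In addition to attention sub-layers, each of the layers in our encoder and decoder contains a fully connected feed-forward network, which is applied to each position separately and identically. This consists of two linear transformations with a ReLU activation in between.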
F F N ( x ) = m a x ( 0 , x W 1 + b 1 ) W 2 + b 2 FFN(x)=max(0,xW_{1}+b_{1})W_{2}+b_{2} </span><span class="katex-html"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathit" style="margin-right: 0.13889em;">F</span><span class="mord mathit" style="margin-right: 0.13889em;">F</span><span class="mord mathit" style="margin-right: 0.10903em;">N</span><span class="mopen">(</span><span class="mord mathit">x</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathit">m</span><span class="mord mathit">a</span><span class="mord mathit">x</span><span class="mopen">(</span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathit">x</span><span class="mord"><span class="mord mathit" style="margin-right: 0.13889em;">W</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span class="" style="top: -2.55em; margin-left: -0.13889em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord"><span class="mord mathit">b</span><span 
$$\mathrm{FFN}(x)=\max(0,\ xW_1+b_1)W_2+b_2 \tag{2}$$
While the linear transformations are the same across different positions, they use different parameters from layer to layer. Another way of describing this is as two convolutions with kernel size 1. The dimensionality of the input and output is $d_{model}=512$, and the inner layer has dimensionality $d_{ff}=2048$.
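The position-wise feed-forward network can be sketched in plain Python as follows. This is a minimal illustration, not the tensor2tensor implementation: it assumes row-vector inputs, toy sizes $d_{model}=2$ and $d_{ff}=3$ in place of 512 and 2048, and hypothetical hand-picked weights `W1`, `B1`, `W2`, `B2`. The point it shows is that one set of weights is applied to every position independently.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # indexed as [row][col]


def matvec(x: Vector, w: Matrix) -> Vector:
    """Row vector x (length = rows of w) times matrix w -> row vector (length = cols of w)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def ffn(x: Vector, w1: Matrix, b1: Vector, w2: Matrix, b2: Vector) -> Vector:
    """FFN(x) = max(0, x W1 + b1) W2 + b2 for a single position."""
    hidden = [max(0.0, h + b) for h, b in zip(matvec(x, w1), b1)]  # ReLU
    return [o + b for o, b in zip(matvec(hidden, w2), b2)]


def position_wise_ffn(seq: List[Vector], w1: Matrix, b1: Vector,
                      w2: Matrix, b2: Vector) -> List[Vector]:
    """Apply the same FFN to every position; positions never interact here."""
    return [ffn(x, w1, b1, w2, b2) for x in seq]


# Toy sizes: d_model = 2, d_ff = 3 (hypothetical weights, not trained values).
W1 = [[1.0, -1.0, 0.5],
      [0.0, 2.0, -0.5]]
B1 = [0.0, 0.0, 0.0]
W2 = [[1.0, 0.0],
      [0.0, 1.0],
      [1.0, 1.0]]
B2 = [0.1, -0.1]

seq = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
out = position_wise_ffn(seq, W1, B1, W2, B2)  # approximately [[1.6, 0.4], [0.1, 1.9], [1.6, 0.4]]
```

Because each position is processed separately, the identical inputs at positions 0 and 2 produce identical outputs, which is what a kernel-size-1 convolution sliding over the sequence would also compute.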
Similar to other sequence transduction models, we use learned embeddings to convert the input tokens and output tokens to vectors of dimension $d_{model}$. We also use the usual learned linear transformation and softmax function to convert the decoder output to predicted next-token probabilities. In our model, we share the same weight matrix between the two embedding layers and the pre-softmax linear transformation, similar to [30]. In the embedding layers, we multiply those weights by $\sqrt{d_{model}}$.
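The weight sharing described above can be illustrated with a small sketch. It assumes a hypothetical shared matrix `E` of shape (vocab_size, d_model) with toy values and no bias term in the output projection. The embedding lookup returns a row of `E` scaled by $\sqrt{d_{model}}$, and the pre-softmax projection reuses the same `E` (effectively transposed) to turn a decoder output vector into next-token probabilities.

```python
import math
from typing import List

Vector = List[float]


def embed(token_id: int, E: List[Vector]) -> Vector:
    """Embedding lookup: the token's row of the shared matrix, times sqrt(d_model)."""
    d_model = len(E[0])
    scale = math.sqrt(d_model)
    return [scale * v for v in E[token_id]]


def softmax(logits: Vector) -> Vector:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def next_token_probs(decoder_out: Vector, E: List[Vector]) -> Vector:
    """Pre-softmax linear layer tied to E: logit[v] = dot(decoder_out, E[v])."""
    logits = [sum(h * e for h, e in zip(decoder_out, row)) for row in E]
    return softmax(logits)


# Hypothetical shared matrix: vocab_size = 3, d_model = 4.
E = [[0.1, 0.2, 0.0, -0.1],
     [0.0, -0.3, 0.4, 0.2],
     [0.5, 0.0, 0.1, 0.0]]

x = embed(2, E)                 # [1.0, 0.0, 0.2, 0.0], since sqrt(4) = 2
probs = next_token_probs(x, E)  # a probability distribution over the 3 tokens
```

Tying means the input embedding, output embedding and output projection are a single parameter matrix, so the vocabulary-sized weights are stored only once.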
Since our model contains no recurrence and no convolution, in order for the model to make use of the order of the sequence, we must inject some information about the relative or absolute position of the tokens in the sequence. To this end, we add "positional encodings" to the input embeddings at the bottoms of the encoder and decoder stacks. The positional encodings have the same dimension $d_{model}$ as the embeddings, so that the two can be summed. There are many choices of positional encodings, both learned and fixed [9].
<img src="https://img-blog.csdnimg.cn/20190303235721466.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3FxXzI5Njk1NzAx,size_16,color_FFFFFF,t_70" alt="figure">
In this work, we use sine and cosine functions of different frequencies:
$$PE_{(pos,2i)}=\sin\left(pos/10000^{2i/d_{model}}\right)$$
$$PE_{(pos,2i+1)}=\cos\left(pos/10000^{2i/d_{model}}\right)$$
where $pos$ is the position and $i$ is the dimension. That is, each dimension of the positional encoding corresponds to a sinusoid. The wavelengths form a geometric progression from $2\pi$ to $10000\cdot2\pi$. We chose this function because we hypothesized it would allow the model to easily learn to attend by relative positions, since for any fixed offset $k$, $PE_{pos+k}$ can be represented as a linear function of $PE_{pos}$.
We also experimented with learned positional embeddings [9] instead, and found that the two versions produced nearly identical results (see Table 3, row (E)). We chose the sinusoidal version because it may allow the model to extrapolate to sequence lengths longer than the ones encountered during training.
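The two formulas above, and the claim that $PE_{pos+k}$ is a linear function of $PE_{pos}$, can be checked with a short sketch. It assumes an even $d_{model}$ and uses toy sizes (`max_len=50`, `d_model=16`). The helper `shift` is a hypothetical name: for each sin/cos pair it applies a 2x2 rotation whose angle depends only on the offset $k$, never on $pos$.

```python
import math
from typing import List


def positional_encoding(max_len: int, d_model: int) -> List[List[float]]:
    """PE[pos][2i] = sin(pos / 10000^(2i/d_model)), PE[pos][2i+1] = cos(same angle).
    Assumes d_model is even."""
    pe = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(d_model // 2):
            angle = pos / (10000 ** (2 * i / d_model))
            pe[pos][2 * i] = math.sin(angle)
            pe[pos][2 * i + 1] = math.cos(angle)
    return pe


def shift(pe_pos: List[float], k: int, d_model: int) -> List[float]:
    """Compute PE[pos + k] from PE[pos] alone via a fixed linear map
    (one 2x2 rotation per sin/cos pair) that depends only on k."""
    out = [0.0] * d_model
    for i in range(d_model // 2):
        w = 1.0 / (10000 ** (2 * i / d_model))  # angular frequency of pair i
        s, c = pe_pos[2 * i], pe_pos[2 * i + 1]
        # sin(a + b) = sin a cos b + cos a sin b; cos(a + b) = cos a cos b - sin a sin b
        out[2 * i] = s * math.cos(w * k) + c * math.sin(w * k)
        out[2 * i + 1] = c * math.cos(w * k) - s * math.sin(w * k)
    return out


pe = positional_encoding(max_len=50, d_model=16)
predicted = shift(pe[10], k=7, d_model=16)  # matches pe[17] up to rounding
```

Pair $i=0$ has frequency 1 and wavelength $2\pi$; higher pairs rotate more slowly, which gives the geometric progression of wavelengths mentioned in the text.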
In this section we compare various aspects of self-attention layers to the recurrent and convolutional layers commonly used for mapping one variable-length sequence of symbol representations $(x_1,\dots,x_n)$ to another sequence of equal length $(z_1,\dots,z_n)$, with $x_i, z_i\in\mathbb{R}^d$, such as a hidden layer in a typical sequence transduction encoder or decoder. To motivate our use of self-attention, we consider three desiderata.
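One is the total computational complexity per layer. Another is the amount of computation that can be parallelized, as measured by the minimum number of sequential operations required.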
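The third is the path length between long-range dependencies in the network. Learning long-range dependencies is a key challenge in many sequence transduction tasks. One key factor affecting the ability to learn such dependencies is the length of the paths that forward and backward signals have to traverse in the network. The shorter these paths between any combination of positions in the input and output sequences, the easier it is to learn long-range dependencies [12]. Hence we also compare the maximum path length between any two input and output positions in networks composed of the different layer types.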
As noted in Table 1, a self-attention layer connects all positions with a constant number of sequentially executed operations, whereas a recurrent layer requires $O(n)$ sequential operations. In terms of computational complexity, self-attention layers are faster than recurrent layers when the sequence length $n$ is smaller than the representation dimensionality $d$, which is most often the case with sentence representations used by state-of-the-art machine translation models, such as word-piece [38] and byte-pair [31] representations. To improve computational performance for tasks involving very long sequences, self-attention could be restricted to considering only a neighborhood of size $r$ in the input sequence centered around the respective output position. This would increase the maximum path length to $O(n/r)$. We plan to investigate this approach further in future work.
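To make the $n<d$ condition concrete, the per-layer orders from Table 1 can be evaluated with all constant factors set to 1. This is a back-of-the-envelope sketch, not a FLOP count: the kernel width `k=3` and neighborhood size `r=8` are hypothetical defaults, and `n`, `d` are example values.

```python
def ceil_log(n: int, k: int) -> int:
    """Smallest integer L with k**L >= n (integer arithmetic, avoids float log error)."""
    layers, reach = 0, 1
    while reach < n:
        reach *= k
        layers += 1
    return layers


def table1_row(layer: str, n: int, d: int, k: int = 3, r: int = 8):
    """(complexity per layer, sequential ops, max path length), constants set to 1."""
    if layer == "self-attention":
        return n * n * d, 1, 1
    if layer == "recurrent":
        return n * d * d, n, n
    if layer == "convolutional":
        return k * n * d * d, 1, ceil_log(n, k)  # path length for dilated kernels
    if layer == "self-attention (restricted)":
        return r * n * d, 1, -(-n // r)  # ceil(n / r)
    raise ValueError(f"unknown layer type: {layer}")


n, d = 50, 512  # example sizes with n < d
for name in ("self-attention", "recurrent", "convolutional", "self-attention (restricted)"):
    print(name, table1_row(name, n, d))
```

With $n=50$ and $d=512$ the self-attention term $n^2 d$ is about ten times smaller than the recurrent term $n d^2$; once $n$ exceeds $d$ the ordering flips, which is the motivation for the restricted variant.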
A single convolutional layer with kernel width $k<n$ does not connect all pairs of input and output positions. Doing so requires a stack of $O(n/k)$ convolutional layers in the case of contiguous kernels, or $O(\log_k(n))$ in the case of dilated convolutions [18], increasing the length of the longest paths between any two positions in the network. Convolutional layers are generally more expensive than recurrent layers, by a factor of $k$. Separable convolutions [6], however, decrease the complexity considerably, to $O(k\cdot n\cdot d+n\cdot d^2)$. Even with $k=n$, however, the complexity of a separable convolution is equal to the combination of a self-attention layer and a point-wise feed-forward layer, the approach we take in our model.
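The layer counts above follow from how the receptive field of a convolution stack grows. The sketch below assumes stride-1 convolutions, reads "connecting all positions" as the receptive field covering all $n$ positions, and, for the dilated case, uses dilation $k^l$ at layer $l$ (a WaveNet-style schedule, assumed here). The cost helpers drop constant factors.

```python
def layers_contiguous(n: int, k: int) -> int:
    """Contiguous kernels: receptive field after L layers is 1 + L*(k-1).
    Return the smallest L covering n positions, i.e. O(n/k)."""
    assert k >= 2, "kernel width must be at least 2 for the field to grow"
    layers, field = 0, 1
    while field < n:
        field += k - 1
        layers += 1
    return layers


def layers_dilated(n: int, k: int) -> int:
    """Dilation k**l at layer l: receptive field after L layers is k**L.
    Return the smallest L with k**L >= n, i.e. O(log_k n)."""
    assert k >= 2, "kernel width must be at least 2 for the field to grow"
    layers, field = 0, 1
    while field < n:
        field *= k
        layers += 1
    return layers


def separable_conv_cost(n: int, d: int, k: int) -> int:
    return k * n * d + n * d * d  # O(k*n*d + n*d^2)


def self_attention_plus_ffn_cost(n: int, d: int) -> int:
    return n * n * d + n * d * d  # self-attention O(n^2*d) + point-wise FFN O(n*d^2)


print(layers_contiguous(64, 3), layers_dilated(64, 3))  # 32 vs 4 layers
```

Setting $k=n$ in `separable_conv_cost` reproduces `self_attention_plus_ffn_cost` exactly, matching the statement that even a full-width separable convolution costs as much as self-attention plus the feed-forward layer.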
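As a side benefit, self-attention could yield more interpretable models. We inspect attention distributions from our models and present and discuss examples in the appendix. Not only do individual attention heads clearly learn to perform different tasks, many appear to exhibit behavior related to the syntactic and semantic structure of the sentences.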
This section describes the training regime for our models.
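We trained on the standard WMT 2014 English-German dataset consisting of about 4.5 million sentence pairs. Sentences were encoded using byte-pair encoding [3], which has a shared source-target vocabulary of about 37000 tokens. For English-French, we used the significantly larger WMT 2014 English-French dataset consisting of 36M sentences and split tokens into a 32000 word-piece vocabulary [38]. Sentence pairs were batched together by approximate sequence length. Each training batch contained a set of sentence pairs containing approximately 25000 source tokens and 25000 target tokens.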
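Batching by approximate length under a token budget can be sketched with a simple greedy scheme. This is a hypothetical illustration, not tensor2tensor's actual bucketing code: `batch_by_length` and `max_tokens` are made-up names, `max_tokens` stands in for the roughly 25000 source and 25000 target tokens per batch, and padding is not counted.

```python
from typing import List, Tuple

Pair = Tuple[List[str], List[str]]  # (source tokens, target tokens)


def batch_by_length(pairs: List[Pair], max_tokens: int) -> List[List[Pair]]:
    """Sort pairs by length, then greedily fill batches so that neither the
    source side nor the target side exceeds max_tokens. A single pair longer
    than max_tokens still gets its own batch."""
    ordered = sorted(pairs, key=lambda p: (len(p[0]), len(p[1])))
    batches: List[List[Pair]] = []
    current: List[Pair] = []
    src_tokens = tgt_tokens = 0
    for src, tgt in ordered:
        over_budget = (src_tokens + len(src) > max_tokens
                       or tgt_tokens + len(tgt) > max_tokens)
        if current and over_budget:
            batches.append(current)
            current, src_tokens, tgt_tokens = [], 0, 0
        current.append((src, tgt))
        src_tokens += len(src)
        tgt_tokens += len(tgt)
    if current:
        batches.append(current)
    return batches


# Toy example with a tiny budget instead of 25000 tokens.
pairs = [("a b c".split(), "x y".split()),
         ("a".split(), "x".split()),
         ("a b c d e".split(), "x y z w".split()),
         ("a b".split(), "x y".split())]
batches = batch_by_length(pairs, max_tokens=6)  # batch sizes: [3, 1]
```

Sorting first keeps sentences of similar length together, so little computation is wasted on padding inside each batch.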
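We trained our models on one machine with 8 NVIDIA P100 GPUs. For our base models using the hyperparameters described in this paper, each training step took about 0.4 seconds. We trained the base models for a total of 100,000 steps, or 12 hours. For our big models (described on the bottom line of Table 3), the step time was 1.0 seconds. The big models were trained for 300,000 steps (3.5 days).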
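A quick arithmetic check that the reported step times and step counts agree with the stated wall-clock durations (the helper name `training_time` is hypothetical):

```python
def training_time(steps: int, seconds_per_step: float):
    """Return (hours, days) of pure step time."""
    total_seconds = steps * seconds_per_step
    return total_seconds / 3600, total_seconds / 86400


base_hours, _ = training_time(100_000, 0.4)  # 40,000 s, about 11.1 hours
_, big_days = training_time(300_000, 1.0)    # 300,000 s, about 3.47 days
```

Both results are consistent with the roughly 12 hours and 3.5 days given in the text.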
我们使用Adam优化方案[20],其中
β 1 = 0.9 , β 2 = 0.98 , ϵ = 1 0 − 9 \beta_{1}=0.9,\ \beta_{2}=0.98,\ \epsilon=10^{-9} </span><span class="katex-html"><span class="base"><span class="strut" style="height: 0.88888em; vertical-align: -0.19444em;"></span><span class="mord"><span class="mord mathit" style="margin-right: 0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span class="" style="top: -2.55em; margin-left: -0.05278em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.88888em; vertical-align: -0.19444em;"></span><span class="mord">0</span><span class="mord">.</span><span class="mord">9</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mspace"> </span><span class="mord"><span class="mord mathit" style="margin-right: 0.05278em;">β</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span class="" style="top: -2.55em; margin-left: -0.05278em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">2</span></span></span></span></span><span class="vlist-s"></span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span class=""></span></span></span></span></span></span><span class="mspace" style="margin-right: 
We varied the learning rate over the course of training according to the formula:

$$lrate = d_{model}^{-0.5} \cdot \min\left(step\_num^{-0.5},\ step\_num \cdot warmup\_steps^{-1.5}\right) \qquad (3)$$

This corresponds to increasing the learning rate linearly for the first warmup_steps training steps, and decreasing it thereafter proportionally to the inverse square root of the step number. We used warmup_steps = 4000.
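Eq. (3) is easy to misread, so here is a minimal Python sketch of it. The defaults `d_model=512` (base model width) and `warmup_steps=4000` are assumptions taken from the base configuration; the function name is hypothetical and this is not the tensor2tensor implementation.

```python
def transformer_lrate(step_num: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    """Learning rate from Eq. (3): linear warmup, then inverse-square-root decay."""
    if step_num < 1:
        # step_num ** -0.5 is undefined at step 0, so schedules start counting at 1.
        raise ValueError("step_num must be >= 1")
    return d_model ** -0.5 * min(step_num ** -0.5, step_num * warmup_steps ** -1.5)


if __name__ == "__main__":
    for step in (1, 1000, 4000, 8000, 100000):
        print(step, f"{transformer_lrate(step):.6f}")
```

Before `warmup_steps` the second term of the `min` is smaller, so the rate grows linearly with the step; the two terms meet exactly at `step_num == warmup_steps`, which is therefore the peak, and after that the `step_num ** -0.5` term takes over.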
We employ three types of regularization during training:
Residual Dropout: We apply dropout [33] to the output of each sub-layer, before it is added to the sub-layer input and normalized. In addition, we apply dropout to the sums of the embeddings and the positional encodings in both the encoder and decoder stacks. For the base model, we use a rate of $P_{drop} = 0.1$.
Label Smoothing: During training, we employed label smoothing of value $\epsilon_{ls} = 0.1$ [36]. This hurts perplexity, as the model learns to be more unsure, but improves accuracy and BLEU score.
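To make label smoothing concrete, the sketch below builds the smoothed target distribution and the cross-entropy the model would be trained on. It assumes the uniform-prior form from [36], $q'(k) = (1-\epsilon)\,\delta_{k,y} + \epsilon/K$; the function names are hypothetical, and some implementations instead spread $\epsilon$ over only the $K-1$ non-target classes.

```python
import math


def smooth_labels(target: int, num_classes: int, eps: float = 0.1) -> list[float]:
    """Smoothed one-hot target: (1 - eps) on the gold class plus eps / K on every class."""
    if not 0 <= target < num_classes:
        raise ValueError("target out of range")
    dist = [eps / num_classes] * num_classes
    dist[target] += 1.0 - eps
    return dist


def smoothed_cross_entropy(target: int, log_probs: list[float], eps: float = 0.1) -> float:
    """Cross-entropy between the smoothed target and the model's log-probabilities."""
    q = smooth_labels(target, len(log_probs), eps)
    return -sum(qk * lp for qk, lp in zip(q, log_probs))


if __name__ == "__main__":
    print(smooth_labels(2, 5))
    print(smoothed_cross_entropy(2, [math.log(0.2)] * 5))
```

Because the target never puts all its mass on one class, the loss is minimized by a less peaked prediction, which is why perplexity gets worse even as accuracy and BLEU improve.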
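On the WMT 2014 English-to-German translation task, the big Transformer model (Transformer (big) in Table 2) outperforms the best previously reported models (including ensembles) by more than 2.0 BLEU, establishing a new state-of-the-art BLEU score of 28.4. The configuration of this model is listed in the bottom line of Table 3. Training took 3.5 days on 8 P100 GPUs. Even our base model surpasses all previously published models and ensembles, at a fraction of the training cost of any of the competitive models.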
On the WMT 2014 English-to-French translation task, our big model achieves a BLEU score of 41.0, outperforming all previously published single models at less than 1/4 the training cost of the previous state-of-the-art model. The Transformer (big) model trained for English-to-French used dropout rate $P_{drop} = 0.1$, instead of 0.3.
For the base models, we used a single model obtained by averaging the last 5 checkpoints, which were written at 10-minute intervals. For the big models, we averaged the last 20 checkpoints. We used beam search with a beam size of 4 and length penalty $\alpha = 0.6$ [38]. These hyperparameters were chosen after experimentation on the development set. We set the maximum output length during inference to input length + 50, but terminate early when possible [38].
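A minimal sketch of the two decoding-time ingredients above, using plain dictionaries of float lists in place of real checkpoint tensors. The length penalty is assumed to be the GNMT form from [38], $lp(Y) = ((5 + |Y|)/6)^{\alpha}$, with hypotheses ranked by $\log P(Y \mid X) / lp(Y)$; all function names here are hypothetical.

```python
def average_checkpoints(checkpoints: list[dict[str, list[float]]]) -> dict[str, list[float]]:
    """Element-wise mean of every parameter across checkpoints (e.g. the last 5 or 20)."""
    if not checkpoints:
        raise ValueError("need at least one checkpoint")
    n = len(checkpoints)
    averaged = {}
    for name in checkpoints[0]:
        # Group the i-th value of this parameter from every checkpoint, then average it.
        columns = zip(*(ckpt[name] for ckpt in checkpoints))
        averaged[name] = [sum(vals) / n for vals in columns]
    return averaged


def length_penalty(length: int, alpha: float = 0.6) -> float:
    """GNMT-style penalty (assumed from [38]): ((5 + |Y|) / 6) ** alpha."""
    return ((5.0 + length) / 6.0) ** alpha


def beam_score(log_prob: float, length: int, alpha: float = 0.6) -> float:
    """Length-normalized score used to compare finished hypotheses of different lengths."""
    return log_prob / length_penalty(length, alpha)
```

With `alpha = 0` the penalty is 1 and beam search falls back to raw log-probability, which favors short outputs; a positive `alpha` divides longer hypotheses' (negative) scores by a larger number, offsetting that bias.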
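Table 2 summarizes our results and compares our translation quality and training costs to other model architectures from the literature. We estimate the number of floating point operations used to train a model by multiplying the training time, the number of GPUs used, and an estimate of the sustained single-precision floating-point capacity of each GPU.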
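The FLOP estimate is a single product; the sketch below spells it out. The inputs in the example (3.5 days on 8 GPUs, the budget quoted in the abstract) are illustrative, and the 9.5 TFLOPS per-GPU figure is an assumed sustained FP32 throughput, not a measured one.

```python
SECONDS_PER_DAY = 86_400


def estimate_training_flops(train_days: float, num_gpus: int, sustained_flops_per_gpu: float) -> float:
    """Training FLOPs ~= wall-clock seconds x number of GPUs x sustained FP32 FLOP/s per GPU."""
    return train_days * SECONDS_PER_DAY * num_gpus * sustained_flops_per_gpu


if __name__ == "__main__":
    # Assumed inputs: 3.5 days on 8 GPUs at 9.5 TFLOPS each.
    print(f"{estimate_training_flops(3.5, 8, 9.5e12):.2e}")  # prints 2.30e+19
```

This is an upper-bound style estimate: it assumes every GPU runs at its sustained rate for the whole run, so it measures budget rather than useful arithmetic.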
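To evaluate the importance of different components of the Transformer, we varied our base model in different ways, measuring the change in performance on English-to-German translation on the development set, newstest2013. We used beam search as described in the previous section, but no checkpoint averaging. We present these results in Table 3.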
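In Table 3 rows (A), we vary the number of attention heads and the attention key and value dimensions, keeping the amount of computation constant, as described in Section 3.2.2. While single-head attention is 0.9 BLEU worse than the best setting, quality also drops off with too many heads.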
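Keeping computation constant while varying the head count works because each head gets a proportionally smaller key/value size. A small sketch, assuming the base width $d_{model} = 512$ and the split $d_k = d_v = d_{model}/h$ from Section 3.2.2; the function name is hypothetical.

```python
def per_head_dim(d_model: int, num_heads: int) -> int:
    """d_k = d_v = d_model / h, so h * d_k (the total projection width) stays equal to d_model."""
    if d_model % num_heads:
        raise ValueError("d_model must be divisible by num_heads")
    return d_model // num_heads


if __name__ == "__main__":
    d_model = 512  # assumed base-model width
    for h in (1, 4, 8, 16, 32):
        d_k = per_head_dim(d_model, h)
        print(f"h={h:2d}  d_k=d_v={d_k:3d}  h*d_k={h * d_k}")
```

Every row prints `h*d_k=512`: the projection matrices have the same total size regardless of `h`, so row (A) isolates the effect of how that capacity is split, not how much of it there is.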
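In Table 3 rows (B), we observe that reducing the attention key size $d_k$ hurts model quality. This suggests that determining compatibility is not easy and that a more sophisticated compatibility function than dot product may be beneficial. We further observe in rows (C) and (D) that, as expected, bigger models are better, and dropout is very helpful in avoiding over-fitting. In row (E) we replace our sinusoidal positional encoding with learned positional embeddings [9], and observe nearly identical results to the base model.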
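For reference, the sinusoidal encoding that row (E) replaces can be sketched in a few lines of pure Python, assuming the paper's definition $PE_{(pos,2i)} = \sin(pos/10000^{2i/d_{model}})$ and $PE_{(pos,2i+1)} = \cos(pos/10000^{2i/d_{model}})$. A learned positional embedding would instead be a trainable `max_len x d_model` table initialized randomly; the function name below is hypothetical.

```python
import math


def sinusoidal_positional_encoding(max_len: int, d_model: int) -> list[list[float]]:
    """Fixed (non-learned) positional encodings: sin on even dims, cos on odd dims."""
    table = []
    for pos in range(max_len):
        row = []
        for j in range(d_model):
            two_i = j - (j % 2)  # the "2i" shared by each sin/cos pair of dimensions
            angle = pos / (10000 ** (two_i / d_model))
            row.append(math.sin(angle) if j % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


if __name__ == "__main__":
    pe = sinusoidal_positional_encoding(max_len=4, d_model=8)
    for pos, row in enumerate(pe):
        print(pos, [round(v, 3) for v in row])
```

The fixed version needs no parameters and is defined for any position, while the learned table only covers positions seen in training; row (E) suggests the choice barely matters for quality on this task.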
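To evaluate whether the Transformer can generalize to other tasks, we performed experiments on English constituency parsing. This task presents specific challenges: the output is subject to strong structural constraints and is significantly longer than the input. Furthermore, RNN sequence-to-sequence models have not been able to attain state-of-the-art results in small-data regimes [37].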
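Treating parsing as sequence transduction requires turning a tree into a token sequence. The sketch below uses one simple depth-first bracketed linearization in the spirit of [37]; the exact token scheme, the tree representation (`(label, [children])` tuples), and the function name are assumptions for illustration. It shows why the output is both longer than the input and structurally constrained (brackets must balance).

```python
from typing import Union


def linearize(tree: Union[str, tuple]) -> list[str]:
    """Depth-first bracketed linearization; a tree is a word or (label, [children])."""
    if isinstance(tree, str):
        return [tree]
    label, children = tree
    tokens = ["(" + label]
    for child in children:
        tokens.extend(linearize(child))
    tokens.append(")")
    return tokens


if __name__ == "__main__":
    tree = ("S", [("NP", ["John"]), ("VP", ["has", ("NP", ["a", "dog"])])])
    out = linearize(tree)
    print(" ".join(out))
    print("input words: 4, output tokens:", len(out))
```

The four-word sentence becomes twelve target tokens, and any prediction with an unmatched bracket is not a valid tree, which is the structural constraint the paragraph above refers to.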
We trained a 4-layer Transformer with $d_{model} = 1024$ on the Wall Street Journal (WSJ) portion of the Penn Treebank [25], about 40K training sentences. We also trained it in a semi-supervised setting, using the larger high-confidence and BerkeleyParser corpora, with approximately 17M sentences [37]. We used a vocabulary of 16K tokens for the WSJ-only setting and a vocabulary of 32K tokens for the semi-supervised setting.
We performed only a small number of experiments to select the dropout (both attention and residual, Section 5.4), learning rates and beam size on the Section 22 development set; all other parameters remained unchanged from the English-to-German base translation model. During inference, we increased the maximum output length to input length + 300. We used a beam size of 21 and $\alpha = 0.3$ for both the WSJ-only and the semi-supervised setting.
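The parsing settings from the two paragraphs above can be collected in one place. This is a hypothetical configuration container, not the tensor2tensor hparams object; writing 16K/32K as 16_000/32_000 is an assumption, since the exact vocabulary counts are not given.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ParsingConfig:
    """Hypothetical container for the constituency-parsing settings described above."""
    num_layers: int = 4
    d_model: int = 1024
    vocab_size: int = 16_000         # WSJ-only; assumed exact value for "16K"
    beam_size: int = 21
    length_penalty_alpha: float = 0.3
    max_extra_output_len: int = 300  # max output length = input length + 300

    def max_output_len(self, input_len: int) -> int:
        return input_len + self.max_extra_output_len


WSJ_ONLY = ParsingConfig()
SEMI_SUPERVISED = ParsingConfig(vocab_size=32_000)  # assumed exact value for "32K"
```

Only the vocabulary differs between the two settings; beam size, `alpha`, and the +300 output budget (versus +50 for translation) are shared, reflecting how much longer a linearized tree is than its sentence.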
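Our results in Table 4 show that despite the lack of task-specific tuning, our model performs surprisingly well, yielding better results than all previously reported models with the exception of the Recurrent Neural Network Grammar [8].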
In contrast to RNN sequence-to-sequence models [37], the Transformer outperforms the BerkeleyParser [29] even when trained only on the WSJ training set of 40K sentences.
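In this work, we presented the Transformer, the first sequence transduction model based entirely on attention, replacing the recurrent layers most commonly used in encoder-decoder architectures with multi-headed self-attention.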
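For translation tasks, the Transformer can be trained significantly faster than architectures based on recurrent or convolutional layers. On both the WMT 2014 English-to-German and WMT 2014 English-to-French translation tasks, we achieve a new state of the art. In the former task, our best model outperforms even all previously reported ensembles.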
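We are excited about the future of attention-based models and plan to apply them to other tasks. We plan to extend the Transformer to problems involving input and output modalities other than text, and to investigate local, restricted attention mechanisms to efficiently handle large inputs and outputs such as images, audio and video. Making generation less sequential is another research goal of ours.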
The code we used to train and evaluate our models is available at https://github.com/tensorflow/tensor2tensor.
References
[1] Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.
[2] Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate. CoRR, abs/1409.0473, 2014.
[3] Denny Britz, Anna Goldie, Minh-Thang Luong, and Quoc V. Le. Massive exploration of neural machine translation architectures. CoRR, abs/1703.03906, 2017.
[4] Jianpeng Cheng, Li Dong, and Mirella Lapata. Long short-term memory-networks for machine reading. arXiv preprint arXiv:1601.06733, 2016.
[5] Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Fethi Bougares, Holger Schwenk, and Yoshua Bengio. Learning phrase representations using rnn encoder-decoder for statistical machine translation. CoRR, abs/1406.1078, 2014.
[6] Francois Chollet. Xception: Deep learning with depthwise separable convolutions. arXiv preprint arXiv:1610.02357, 2016.
[7] Junyoung Chung, Çaglar Gülçehre, Kyunghyun Cho, and Yoshua Bengio. Empirical evaluation of gated recurrent neural networks on sequence modeling. CoRR, abs/1412.3555, 2014.
[8] Chris Dyer, Adhiguna Kuncoro, Miguel Ballesteros, and Noah A. Smith. Recurrent neural network grammars. In Proc. of NAACL, 2016.
[9] Jonas Gehring, Michael Auli, David Grangier, Denis Yarats, and Yann N. Dauphin. Convolutional sequence to sequence learning. arXiv preprint arXiv:1705.03122v2, 2017.
[10] Alex Graves. Generating sequences with recurrent neural networks. arXiv preprint arXiv:1308.0850, 2013.
[11] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 770–778, 2016.
[12] Sepp Hochreiter, Yoshua Bengio, Paolo Frasconi, and Jürgen Schmidhuber. Gradient flow in recurrent nets: the difficulty of learning long-term dependencies, 2001.
[13] Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory. Neural computation, 9(8):1735–1780, 1997.
[14] Zhongqiang Huang and Mary Harper. Self-training PCFG grammars with latent annotations across languages. In Proceedings of the 2009 Conference on Empirical Methods in Natural Language Processing, pages 832–841. ACL, August 2009.
[15] Rafal Jozefowicz, Oriol Vinyals, Mike Schuster, Noam Shazeer, and Yonghui Wu. Exploring the limits of language modeling. arXiv preprint arXiv:1602.02410, 2016.
[16] Łukasz Kaiser and Samy Bengio. Can active memory replace attention? In Advances in Neural Information Processing Systems, (NIPS), 2016.
[17] Łukasz Kaiser and Ilya Sutskever. Neural GPUs learn algorithms. In International Conference on Learning Representations (ICLR), 2016.
[18] Nal Kalchbrenner, Lasse Espeholt, Karen Simonyan, Aaron van den Oord, Alex Graves, and Koray Kavukcuoglu. Neural machine translation in linear time. arXiv preprint arXiv:1610.10099v2, 2017.
[19] Yoon Kim, Carl Denton, Luong Hoang, and Alexander M. Rush. Structured attention networks. In International Conference on Learning Representations, 2017.
[20] Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. In ICLR, 2015.
[21] Oleksii Kuchaiev and Boris Ginsburg. Factorization tricks for LSTM networks. arXiv preprint arXiv:1703.10722, 2017.
[22] Zhouhan Lin, Minwei Feng, Cicero Nogueira dos Santos, Mo Yu, Bing Xiang, Bowen Zhou, and Yoshua Bengio. A structured self-attentive sentence embedding. arXiv preprint arXiv:1703.03130, 2017.
[23] Minh-Thang Luong, Quoc V. Le, Ilya Sutskever, Oriol Vinyals, and Lukasz Kaiser. Multi-task sequence to sequence learning. arXiv preprint arXiv:1511.06114, 2015.
[24] Minh-Thang Luong, Hieu Pham, and Christopher D Manning. Effective approaches to attention-based neural machine translation. arXiv preprint arXiv:1508.04025, 2015.
[25] Mitchell P Marcus, Mary Ann Marcinkiewicz, and Beatrice Santorini. Building a large annotated corpus of english: The penn treebank. Computational linguistics, 19(2):313–330, 1993.
[26] David McClosky, Eugene Charniak, and Mark Johnson. Effective self-training for parsing. In Proceedings of the Human Language Technology Conference of the NAACL, Main Conference, pages 152–159. ACL, June 2006.
[27] Ankur Parikh, Oscar Täckström, Dipanjan Das, and Jakob Uszkoreit. A decomposable attention model. In Empirical Methods in Natural Language Processing, 2016.
[28] Romain Paulus, Caiming Xiong, and Richard Socher. A deep reinforced model for abstractive summarization. arXiv preprint arXiv:1705.04304, 2017.
[29] Slav Petrov, Leon Barrett, Romain Thibaux, and Dan Klein. Learning accurate, compact, and interpretable tree annotation. In Proceedings of the 21st International Conference on Computational Linguistics and 44th Annual Meeting of the ACL, pages 433–440. ACL, July 2006.
[30] Ofir Press and Lior Wolf. Using the output embedding to improve language models. arXiv preprint arXiv:1608.05859, 2016.
[31] Rico Sennrich, Barry Haddow, and Alexandra Birch. Neural machine translation of rare words with subword units. arXiv preprint arXiv:1508.07909, 2015.
[32] Noam Shazeer, Azalia Mirhoseini, Krzysztof Maziarz, Andy Davis, Quoc Le, Geoffrey Hinton, and Jeff Dean. Outrageously large neural networks: The sparsely-gated mixture-of-experts layer. arXiv preprint arXiv:1701.06538, 2017.
[33] Nitish Srivastava, Geoffrey E Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. Dropout: a simple way to prevent neural networks from overfitting. Journal of Machine Learning Research, 15(1):1929–1958, 2014.
[34] Sainbayar Sukhbaatar, Arthur Szlam, Jason Weston, and Rob Fergus. End-to-end memory networks. In C. Cortes, N. D. Lawrence, D. D. Lee, M. Sugiyama, and R. Garnett, editors, Advances in Neural Information Processing Systems 28, pages 2440–2448. Curran Associates, Inc., 2015.
[35] Ilya Sutskever, Oriol Vinyals, and Quoc VV Le. Sequence to sequence learning with neural networks. In Advances in Neural Information Processing Systems, pages 3104–3112, 2014.
[36] Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, and Zbigniew Wojna. Rethinking the inception architecture for computer vision. CoRR, abs/1512.00567, 2015.
[37] Vinyals & Kaiser, Koo, Petrov, Sutskever, and Hinton. Grammar as a foreign language. In Advances in Neural Information Processing Systems, 2015.
[38] Yonghui Wu, Mike Schuster, Zhifeng Chen, Quoc V Le, Mohammad Norouzi, Wolfgang Macherey, Maxim Krikun, Yuan Cao, Qin Gao, Klaus Macherey, et al. Google’s neural machine translation system: Bridging the gap between human and machine translation. arXiv preprint arXiv:1609.08144, 2016.
[39] Jie Zhou, Ying Cao, Xuguang Wang, Peng Li, and Wei Xu. Deep recurrent models with fast-forward connections for neural machine translation. CoRR, abs/1606.04199, 2016.
[40] Muhua Zhu, Yue Zhang, Wenliang Chen, Min Zhang, and Jingbo Zhu. Fast and accurate shift-reduce constituent parsing. In Proceedings of the 51st Annual Meeting of the ACL (Volume 1: Long Papers), pages 434–443. ACL, August 2013.
Appendix