当前位置:   article > 正文

Transformer-XL全解读_transformerxl

transformerxl

Motivation

Transformer最大的问题在于没有办法建模超过最大长度的序列,例如base bert其支持的序列最大长度是512,超过了该长度的序列需要进行截取,再把截取后的片段分别用bert进行编码,该方法虽然可行,但是存在上下文碎片化的问题,也就是说每个片段是单独建模的,互相之间没有上下文信息,并且,不同的片段位置编码都是从0开始,明显是有问题的。

可见Transformer对于较长的序列建模能力有限,如何解决该弊端就该Transformer-XL大显身手了。

Transformer-XL

Transformer-XL主要提出了两个优化点

  • Segment-Level Recurrence Mechanism 段级递归
  • Relative Positional Encodings 相对位置编码

接下来我们分别看下两个优化点是如何做的

1、Segment-Level Recurrence Mechanism

在讲解第一个优化点之前,我们简单回顾下vanilla transformer,在训练阶段如果要对多个片段编码,其训练过程如下图,可以看到,两个片段没有相互依赖,上下文信息会丢失,不同的片段位置编码一样,因此也不准确。

在这里插入图片描述
再看下inference阶段,对于第一个segment,预测和vanilla版本一样的,跨段预测时(大于第一个片段的序列),由于依赖的上下文长度是固定的,可以理解为使用了一个滑动窗口,每次窗口的值都不一样,所以每次只能预测一个字/词,并且每次都要完整的计算,例如下图中,每个segment长度是4,超过4的部分只能逐字/词计算。

在这里插入图片描述

为了解决固定长度的限制,Transformer-XL提出了一种递归机制,如下图,第一个segment计算完成后,把计算的结果保存下来,在计算第二个片段的时候,把第一个片段的hidden state和第二个片段的hidden state拼接在一起,再进行后续的计算。

在这里插入图片描述
我们看下具体的计算公式,其中h表示的是hidden state, τ \tau τ表示第 τ \tau τ个segment,SG函数表示的是不更新梯度,[]表示的是向量的拼接,第一个公式的意思即:第 τ + 1 \tau +1 τ+1个segment第n-1层的hidden state 等于第 τ \tau τ个segment第n-1层的hidden state拼接上第 τ + 1 \tau +1 τ+1个segment第n-1层的hidden state,后续两个公式和vanilla版本类似,但要注意,q是未拼接的hidden state,k、v是拼接过后的,因为q表示的是当前的segment,所以不需要拼接。

在这里插入图片描述
可以看到,对于第一个segment来说,hidden state是没有额外需要拼接的值的,从第二个segment开始才需要拼接,在论文中,每次都是和上一个segment进行拼接,理论上来说每次可以拼接多个segment,第n个segment可以和前n-1个segment进行拼接,不过这个就取决于你自己的显存了,并且一个segment通常来说不会像上图中的这么短(一个segment可能长度就512了),文本自身的上下文依赖一般也不会超过一个segment的长度。

再看下inference阶段,大于第一个segment的序列,均可以进行批计算,每个批的长度是segment的长度,并且,由于每次都会保存前一个segment的hidden state,所以不需要像vanilla版本重新计算。论文中对比了一下,Transformer-XL在enwiki8数据集上的inference速度是Vanilla Transformer的1800+倍

在这里插入图片描述

2、Relative Positional Encodings

接下来我们来看第二个优化点,相对位置编码。Vanilla Transformer使用的是绝对位置编码,其计算方式如下,pos表示的是token的下标, d m o d e l d_{model} dmodel表示的是hidden size,i表示的是具体的某个维度。

在这里插入图片描述
可见,不同的片段的同一个位置其位置编码都是一样的,模型没办法正确区分不同片段的位置信息,我们再看下Transformer-XL的位置编码是怎么做的。

Vanilla的位置编码是和embedding相加后输入到下一层的,Transformer-XL的位置编码没有在输入上做处理,而是对attention score进行了修改,先回顾下vanilla版本attention score的计算

A a b s = Q W q K W k A^{abs}=QW_q KW_k Aabs=QWqKWk

把Q和K展开,E表示embedding,U表示位置编码

A a b s = ( E q + U q ) W q ( E k + U k ) W k = ( E q W q + U q W q ) ( E k W k + U k W k ) = E q W q E k W k + E q W q U k W k + U q W q E k W k + U q W q U k W k

Aabs=(Eq+Uq)Wq(Ek+Uk)Wk=(EqWq+UqWq)(EkWk+UkWk)=EqWqEkWk+EqWqUkWk+UqWqEkWk+UqWqUkWk
Aabs=(Eq+Uq)Wq(Ek+Uk)Wk=(EqWq+UqWq)(EkWk+UkWk)=EqWqEkWk+EqWqUkWk+UqWqEkWk+UqWqUkWk

即论文中下图的公式

在这里插入图片描述

考虑一下,当query与key进行计算时,实际上并不需要知道key的绝对位置编码,模型实际上需要的是一个“时间线索”即字词的一个先后顺序,因此,知道query与key的相对位置即可。根据以上的思路,Transformer-XL做了三个方面的改进,分别如下

在这里插入图片描述

  • 把b与d中的key的绝对位置编码 U j U_j Uj替换为相对位置编码 R i − j R_{i-j} Rij,表示的是i和j的相对距离, R i − j R_{i-j} Rij是sinusoid encoding matrix,没有额外的训练参数。实际上和vanilla的位置编码一样的,关键是这里的位置编码只给key用,而key的长度,在第一个片段和query的长度一样,之后的片段,key长度=上一个片段hidden state长度+当前片段key的长度,因此 R i − j R_{i-j} Rij是能够表示出key的相对距离的。
  • 因为无论query在序列中的绝对位置如何,其相对于自身的相对位置都是一样的与在序列中的绝对位置无关,应当保持不变.。用两个可训练的参数u、v分别替换c、d中的 U i T W q T U_i^TW_q^T UiTWqT U i T W q T U_i^TW_q^T UiTWqT
  • vanilla版本的key位置编码与embedding都是采用同样的变化矩阵,xl中,把key的位置编码和embedding分别用了不同的线性变化,其中 W k , R W_{k,R} Wk,R对应位置编码, W k , E W_{k,E} Wk,E对应embedding。

在新的参数下,每一项都有了一个具体的含义,a表示的是query与key的内容相关性,b表示的是query的内容和key的位置的相关性,c表示的是query的位置与key的内容的相关性,d表示的是quey与key的位置的相关性

总结一下,对于一个N层1个head的Transformer-XL,其完整步骤如下

在这里插入图片描述
除此之外,论文中对b与d的计算做了一定的优化, R i − j R_{i-j} Rij需要分别计算i与j的值,时间复杂度是 O ( n 2 ) O(n^2) O(n2),优化后能达到 O ( n ) O(n) O(n)

首先定义一个Q矩阵,表示相对位置编码,注意,R是反着来的从 M + L − 1 M+L−1 M+L1到0
在这里插入图片描述

把b的结果展开,实际上是一个 L × ( M + L ) L × (M + L) L×(M+L)的matrix,其中L表示segment的长度,M表示memory的长,结果为

在这里插入图片描述
如果我们定义一个 B ~ = q Q T \widetilde B=qQ^T B =qQT则有

在这里插入图片描述
对比下 B B B B ~ \widetilde B B ,第i行的 B B B实际上就是第i行的 B ~ \widetilde B B 进行了左移,因此,计算 B B B只需要先计算出 B ~ \widetilde B B 然后按行左移。同理,d也可按照相同的方法进行计算

在这里插入图片描述

References

Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/451210
推荐阅读
相关标签
  

闽ICP备14008679号