当前位置:   article > 正文

XLNet:通用自回归预训练方法_自回归训练

自回归训练


BERT(前文有 介绍)火了以后 XLNet算是首个真正意义上能和其叫板的工作了。在20个任务上都超过BERT,其中很多还是大幅的超越。

AR和AE

作者首先对今天NLP的主流预训练方法进行了分类:自回归语言模型(AR)和自编码(AE)。这样就把ULMFit,ELMo, GPT,GPT2这些依靠传统的语言模型进行预训练的方法分成了一类(AR)。大名鼎鼎的BERT在它的独创性的MLM(Masked LM)中,利用corrupted版本的输入(即用[MASK]来遮住一些token的输入)来恢复原来的tokens,这本质上是denoising autoencoder,所以BERT属于自编码流派。

这两种方法到底孰优孰劣?作者提到了BERT在利用上下文的信息上有很大的灵活性,它不像AR语言模型只能利用单向的信息(或者forward或者backward),所以在很多下游的语言任务中有很大的提升。

然而作者也指出了BERT的两个缺陷:

  1. [MASK]在fine tuning的时候不存在,这导致了预训练-微调不一致性(pretrain-finetune discrepancy)。(注:BERT的原始论文也提到过这个问题,并且相应的有采取处理手段,所以我个人觉得这未必是个大问题。
  2. 被预测的tokens都在输入中用[MASK]替代了,这意味着BERT假设这些tokens(在其他unmasked tokens存在的情况下)是相互独立的。(这个有点拗口,其实意思就是说如果同时mask了多个tokens,那么除了被预测的那个词,其他的被mask的词在训练时也用不上。)

我个人的看法:AE方法最大的一个限制是不方便当做decoder来使用(从名字上看,它是个auto “encoder”而不是decoder),所以在文本生成类的任务上不好用。比如BERT,它在预训练的时候利用了所有的环境信息,但是在生成文本的时候不可能去“向后看”得到未来的信息,这也算是一种训练和推断的不一致。但是像AR这一类利用LM预训练出的模型就很容易拿来作为decoder使用。

XLNet的比较优势

XLNet作为一种广义上的自回归方法,融合了AR和AE,取长补短,融汇贯通,成功的保留了两个流派的优点,并且避免了它们的局限。
下面是作者总结的XLNet的优点:

  1. XLNet把一个序列所有可能的排列都拿来作为LM的输入,这使得每一个位置上都能够利用到所有其他的位置的信息,从而真正的捕获了上下文。
  2. XLNet作为AR语言模型,不再依赖于data corruption。从而避免了上面提到的BRRT的两个缺陷。

另外,XLNet在架构上利用了Transformer-XLTransformer-XL的创新之处在于它的segment recurrence机制和相对位置编码方法,这带来了它在处理长文本上效果的提升。

Transformer-XL

Transformer-XL值得单独拿出来讲一讲,我觉得在XLNet的成功一定会带动未来更多的工作采用Transformer-XL。它作为Transformer的改进版有逐渐取而代之的可能。篇幅所限,这里仅从high level上解释一下直觉上的意义。以下用XL代替Transformer-XL。

Transformer的特点和缺陷

首先我们得重新审视一下Transformer。它利用self-attenetion机制来产生long-range dependency,从而避免了LSTM里的recurrent的机制带来的vanishing/exploding gradient的问题。为了同时保有序列性,它引入了位置编码。而这些都完美的避开了LSTM里的序列性的计算,使其更易于并行化。

然而魔鬼在细节,如果我们审视一下Transformer的计算复杂度:(下图来自Transformer的原始论文attention is all you need)
在这里插入图片描述这里的n可以理解为输入的长度,d是每个token对应的表征维度,那么Self-Attention对应的复杂度是 O ( n 2 d ) O(n^2 d) O(n2d) (因为每个token都要attend to每一个其他的token),看上去竟然远大于Recurrent(即LSTM)类的 O ( n d 2 ) O(nd^2) O(nd2)
看样子,Transformer的计算效率比Recurrent类型的神经网络还要低!

在实际应用中,情况不是这样。因为d的取值往往远大于n。句子的长度一般是64或者128,但是d往往可以很大比如512或者1024,这样的话 O ( n 2 d ) O(n^2d) O(n2d)远小于 O ( n d 2 ) O(nd^2) O(nd2),Transformer的复杂度的确小于Recurrent Networks。

Github上发布的BERT Base和Large应该是用512的序列长度做的预训练,这已经是非常巨大的参数了。即使是直接把发布的BERT模型拿来(或经过fine tuning后)使用,在对inference的速度有要求的工业界,相信绝大部分人会有针对性的选择小的多的max sequence length。

如果我们能够理解Transformer在复杂度上的特点,我们也很容易理解它的缺陷了。那就是context fragmentation。简单说,就是Transformer只能选择固定长度的连续tokens做计算(根据前面的分析,这个固定长度往往有限),不能考虑到句子或者其他任何语义边界,从而缺乏必要的语境信息,这必然带来优化问题。看下面的例子。

语言模型里的Transformer

在这里插入图片描述这是一个语言模型。在训练阶段,信息不能在不同的分段(segment)之间流动。
作者提到这种训练方式会导致两个问题:

首先,最大可能的dependency length被分段长度给限制住了,这导致了模型不能够充分利用self-attention机制的优势。

注意图1(a)里的 x 5 x_5 x5,它和前一个分段里的 x 1 x_1 x1 x 4 x_4 x4没有任何连接。前面的任何内容,在这个分段都不会存在任何记忆。

第二,如前所述,这种做法没有照顾到句子或者其他形式的语义边界,带来了语境碎片化问题(context fragmentation)。

如果图1(a)中的 x 1 x_1 x1 x 8 x_8 x8恰好是一个独立的语义单位,比如说是一个完整的句子,那么上面的分段就导致了语境的碎片化。

在evaluation阶段,该语言模型每次向右移动一个单位,这种方式效率非常低。因为每一次移动都要重新进行处理当前的segment,而前面提到过,每层的计算复杂度是 O ( n 2 d ) O(n^2d) O(n2d)

XL是什么

XL就是extremly long的意思。它旨在克服上面提到的Transformer的困难,从而使之能够处理非常长的信息。一句话总结XL:它是带有recurrence机制的Transformer。

原始论文里的图:
在这里插入图片描述图(a)中的New Segment部分指的是当前正在进行处理的部分,阴影部分指的是在上一个时间步处理过的部分。显然这个处理过的部分也参与到了当前的计算。绿色的连接代表参与的方式:即隐藏层序列被固定在内存中,作为extended context传给当前部分。

用数学公式表示更清楚一些:
假设有两个连续的长度都为L的分段 s τ = [ x τ , 1 , . . . , x τ , L ] s_{\tau}=[x_{\tau,1},...,x_{\tau,L}] sτ=[xτ,1,...,xτ,L] s τ + 1 = [ x τ + 1 , 1 , . . . , x τ + 1 , L ] s_{\tau+1}=[x_{\tau+1,1},...,x_{\tau+1,L}] sτ+1=[xτ+1,1

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/337017
推荐阅读
相关标签
  

闽ICP备14008679号