赞
踩
背景
当今的SOTA的预训练语言模型,比如BERT,采用Mask language model(MLM)的方式破坏输入的内容,通过双向语言模型进行预测重构;然而这存在一个问题,那就是MASK这个token在训练中存在但是在实际预测中不存在,为了缓解这个问题,BERT采用了选择语料中15%的TOKEN,在其中80%进行MASK,10%随机替换,10%不变,这的确稍微缓解了训练预测不一致的问题(虽然在XLNet利用permutation language model得到解决),但是确使得BERT必须利用更多的训练语料,需要的算力也大幅增加,为此提出了ELECTRA这个模型解决这个问题
对应的解决方案
为了解决上述说的训练慢,数据要求多的问题,ELECTRA中训练不只是用语料中的subset(即BERT中只是MASK的token)进行预测,而是利用全部的token. 为了达成这个目的,作者训练语言模型的时候不是像bert一样把他看作generator(bert中通过重构被MASK的词,某种程度上可以看成为generator),而是看成discriminator,论文中引入另一个generator去生成相似的词进行替换,训练语言模型的任务就是去判断语料中的每个词是不是被替换了,这里有点对抗学习(GAN)的意思,但是这里并不是用GAN(因为GAN在本文和图片不一样不是连续的,将GAN用在文本生成上有难度)
大意
1.MLM的预训练方法类似于Bert的破坏输入,通过用mask标志替换,并且训练模型去重塑最原始的输
入。当它们在下游任务中产生好的结果,但是需要大量的数据。
2.我们提出一个更有效的预训练任务,方法是从小型生成器里采样的合理的替代品替换一些 token
来破坏输入。然后训练一个判别模型,该模型可以预测损坏的输入中的每个 token是否被生成器样本
取代。
3.我们的模型取和Bert相同的数据量、模型参数,所取得结果要优于Bert系列模型,并且训练的时
间大大的减少。
两个都是transformer的encoder结构,只是两个网络的尺寸不同:
generator-生成器:就是一个小的 masked language model(一般是 1/4 discriminator的size),该模块的具体作用是他采用了经典的bert的MLM方式:
即首先随机选取15%的tokens,替代为MASK token,(取消了bert的80%MASK,10%unchange, 10% random replaced 的操作,原因是因为没必要,因为本文中finetuning使用的discriminator)
使用generator去训练模型,使得模型预测masked token,得到corrupted tokens
generator的目标函数和bert一样,都是希望被masked的能够被还原成原本的original tokens
如上图, token,the 和 cooked 被随机选为被masked,然后generator预测得到corrupted tokens,变成了the和ate
discriminator-判别器:discriminator的接收被generator corrupt之后的输入,discriminator的作用是分辨输入的每一个token是original的还是replaced,如果generator生成的token和原始token一致,那么这个token仍然是original的
所以,对于每个token,discriminator都会进行一个二分类,最后获得loss
以上的方式被称为replaced token detection
生成器部分
该模型采用minimize the combined loss的方式进行训练
公式:
输入:
经过generator运行之后得到编码了上下文信息的vector representation
对于位置 t, 其被替换为MASK,那么它的output probability(经过softmax/逻辑回归从而达到二分类的目的)为
对应的LOSS函数
判别器部分
输入X_correct为生成器替换后的sample序列,hD为transformer,w为权重矩阵
对应的loss函数
整个网络的loss
论文中还提出两种训练方式
但是效果没有共同训练效果好
不比Bert差而且速度快
通过RTD的训练方式大大减少预训练时间
在轻量级模型中有着优异的表现
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。