当前位置:   article > 正文

几个知识蒸馏相关的BERT变体_软目标和硬目标的区别

软目标和硬目标的区别

引言

预训练的BERT模型具有大量的参数,导致它们无法在边缘设备如智能手机上应用。

为了解决这个问题,我们可以基于知识蒸馏(knowledge distillation,KD)从一个大的预训练BERT迁移学习到一个小的BERT模型上。

知识蒸馏简介

知识蒸馏是一种模型压缩(model compression)技术,用于训练小模型来重现大的预训练模型的表现。也成为教师-学生学习(teacher-student learning),其中大的预训练模型就是教师(模型),而小的模型则为学生(模型)。

假设我们预训练了一个大模型来预测句子中下一个单词。我们称该大模型为教师网络。如果我们输入一个句子让该模型来预测该句子的下一个单词,那么它会返回词表中所有单词作为下一个单词的概率分布,如下图所示。这里为了简化,假设词表中只有5个单词。

img

该概率分布主要通过应用输出层的logits到Softmax中得到,然后我们可以选择概率最大的单词作为预测的下一个单词。这里Homework的概率最大,所以输出的下一个单词为Homework

除了选择概率最大的单词外,我们还能从该概率分布中得到什么有用的信息吗?答案是肯定的。基于下图,我们可以看到,除了概率最大的单词外,还有一些单词的概率与其他单词相比也是较大的。即,单词BookAssignment与其他单词如CakeCar相比,它们的概率更大。

img

这说明,除了单词Homework外,BookAssignment也和给定的句子有相关性。这称为暗知识(dark knowledge)。在知识蒸馏期间,我们希望学生能从教师中学到这些暗知识。

听起来不错,但是我们知道,一个好的模型通常对于正确类别会返回一个接近于1的概率,而对于其他类别返回接近与0的概率。确实是这样,考虑下面的例子。假设我们的模型返回了下面这样的概率分布。

img

其中对于单词Homework返回了一个很高的概率,而其他单词返回的概率接近于0。除了正确答案之外,我们无法从该概率分布中获得更多的信息。所以,我们如何在这里抽取暗知识呢?

此时,我们可以使用带有温度的Softmax函数(与蒸馏相呼应),该温度通常成为Softmax温度。我们在输出层中使用该Softmax温度。它用于平滑概率分布,公式如下:
P i = exp ⁡ ( z i / T ) ∑ j exp ⁡ ( z j / T ) P_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)} Pi=jexp(zj/T)exp(zi/T)
在上面的公式中, T T T就是温度。当 T = 1 T=1 T=1时,就是标准的Softmax函数。增加 T T T值会使分布更平滑,同时带来更多其他类的信息。

比如,下图中,当 T = 1 T=1 T=1时,我们得到了标准Softmax输出的概率分布。当 T = 2 T=2 T=2时,输出的概率分布更加平滑,当 T = 5 T=5 T=5时,概率分布更加平滑。所以通过增加 T T T值,我们可以得到一个平滑的概率分布,它可以给其他类别更多的信息。

img

一定程度内,增加 T T T不会影响输出概率的相对大小,比如Homework T = 5 T=5 T=5时也是最高概率。但是不能无限增大,比如另 T T T接近无穷大,那么就没有意义了。

这样我们通过Softmax温度获取了暗知识。首先我们会预训练教师模型获得暗知识。然后,在知识蒸馏时,我们从教师模型中转移这些暗知识到学生模型。

训练学生网络

上小节中,我们了解了一个预测句子中下一个单词的预训练网络。该预训练网络就是教师网络。现在,我们来学习如何从教师网络中迁移知识到学生网络。注意学生网络不是预训练的,只有教师网络是预训练的,同时是带有Softmax温度的预训练。

正如下图,我们将输入句子喂给教师和学生网络,然后得到概率分布作为输出。我们知道教师网络是预训练的,所以它输出的概率分布就是我们的目标输出。教师网络的输出成为软目标(soft target),由学生网络做的预测成为软预测(soft prediction)。

img

现在,我们计算软目标和软预测之间的交叉熵,然后训练学生网络以最小化该交叉熵损失,该损失称为蒸馏损失(distillation loss)。从下图可以看到,我们将教师和学生网络中的温度 T T T设为同一个大于 1 1 1的值。

img

这样通过反向传播我们就可以最小化蒸馏损失来训练学生网络。除了蒸馏损失外,我们还使用另一个损失,称为学生损失(student loss)。

为了理解学生损失,我们先来理解软目标和硬目标(hard target)之间的区别。如下图所示,由教师网络返回的概率分布成称软目标,而硬目标,我们将最大概率设为1,其他单词设为0。

img

现在,我们来理解软预测和硬预测(hard prediction)的区别。软预测是由基于大于 1 1 1温度的学生网络得到的概率分布,而硬预测是由基于温度 T = 1 T=1 T=1 的学生网络得到的概率分布。即硬预测时使用的就是标准的Softmax函数。

学生损失基本上就是硬目标和硬预测之间的交叉熵损失。下图可以帮助我们理解如何计算学生损失和蒸馏损失。首先,我们来看学生损失。为了计算学生损失,我们在学生网络中使用 T = 1 T=1 T=1的Softmax函数,得到硬预测。而硬目标软目标中概率最大的位置设为 1 1 1,其他位置设为 0 0 0得到的。然后我们将学生损失计算为硬预测硬目标之间的交叉熵。

img

为了计算蒸馏损失,我们使用大于 1 1 1的Softmax函数温度,我们将蒸馏损失计算为软预测软目标之间的交叉熵损失。

我们最终的损失函数是学生损失和蒸馏损失之间的加权和:
L = α ⋅ student loss + β ⋅ distillation loss L = \alpha \cdot \text{student loss} + \beta \cdot \text{distillation loss} L=αstudent loss+βdistillation loss
α \alpha α β \beta β是用于计算学生损失和蒸馏损失之间加权平均的超参数。我们通过最小化上面的损失函数来训练学生网络。

这样,在知识蒸馏中,我们把预训练的网络作为教师网络。我们通过蒸馏训练学生网络获得教师网络的知识。通过最小化上面的损失函数来训练学生网络。

DistilBERT - 蒸馏版的BERT

DistilBERT是一个更小、更快、更便宜、轻量级版本的BERT。它使用了知识蒸馏。DistilBERT的最终思想是,我们采用一个预先训练好的大型BERT模型,通过知识蒸馏将其知识转移到一个小型BERT。

大型预训练BERT称为教师BERT(teacher BERT),而小型BERT称为学生BERT(student BERT)。

DistilBERT比大型BERT快60%,同时小40%。

教师-学生结构

我们先来理解这种教师-学生结构。

教师BERT

教师BERT是一个大型预训练BERT模型。我们使用预训练的BERT-base模型作为教师。

因为BERT是使用掩码语言建模任务进行预训练的,我们可以使用预训练的BERT模型来预测掩码单词。

img

上图就是BERT做掩码建模任务的过程,输入一句话,它可以输出被掩码单词属于词表中每个单词的概率分布。该概率分布包含我们需要转移到学生BERT中的暗知识。

学生BERT

与教师BERT不同,学生BERT不是预训练好的。学生BERT需要从教师BERT中学习。相比教师BERT,学生BERT包含的网络层数更少。教师BERT包含110M个参数,而学生BERT只包含66M个参数。

因为学生BERT中包含更少的网络层,与教师BERT(BERT-base)相比,它能训练得更快。

DistilBERT的作者将学生BERT隐藏状态维度设为768,与教师BERT相同。他们发现减少学生BERT的维度对于计算性能没有太大的影响。所以他们关注于减少网络层数。

训练学生BERT

我们可以使用和预训练教师BERT时一样的数据集来训练学生BERT。

这里我们从RoBERTa中借鉴一些策略,比如我们只训练掩码语言建模任务,并在该任务中,我们使用动态掩码(dynamic masking),同时我们也采用较大的批大小。

如下图所示,我们将掩码句子喂给教师BERT和学生BERT,分别得到一个基于词表的概率分布输出。接着,我们计算软目标和软预测之间的蒸馏损失和交叉熵损失。

2b3bc2b5-eb3e-4676-8194-38c7c6659b41

除了蒸馏损失,我们还计算学生损失,它是掩码语言建模损失,即,硬目标(真实标签)和硬预测( T = 1 T=1 T=1的标准Softmax预测)之间的交叉熵损失,如下图所示:

3c055ff5-f726-46f7-8f84-172052fac1cd

除此之外,我们还计算余弦嵌入损失(cosine embedding loss)。它基本上是教师和学生BERT所学的表示之间的距离度量。最小化余弦嵌入损失使得学生的表示更加准确,更接近于教师的嵌入。

这样,我们最终的损失函数为下列三个损失之和:

  • 蒸馏损失
  • 掩码语言建模损失(学生损失)
  • 余弦嵌入损失

通过最小化上面三个损失之和来训练我们的学生BERT(DistilBERT)。在训练之后,我们的学生BERT会获得教师BERT的知识。

DistilBERT为我们提供了接近97%的原始BERT-Base模型的准确结果,同时推理速度快了60%。因为DistilBERT更轻量,所以我们可以很容易地将它部署在边缘设备上。

DistilBERT在8块16G V100 GPU上训练了近90个小时。预训练好的DistilBERT已经由声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】

推荐阅读
相关标签