赞
踩
@TOC
知识蒸馏(KD)是一种减少大语言模型(LLMs)高计算需求的有效技术。然而,之前的知识蒸馏方法主要应用于白盒分类模型或训练小模型来模仿黑盒模型API(如ChatGPT)。如何有效地将白盒大语言模型的知识蒸馏到小模型中仍未得到充分探索,这在开源大语言模型兴起的背景下变得尤为重要。在这项工作中,我们提出了一种知识蒸馏方法,将大语言模型蒸馏到较小的语言模型中。我们首先将标准知识蒸馏方法中的正向Kullback-Leibler散度(KLD)目标替换为反向KLD,这更适用于生成语言模型的知识蒸馏,防止学生模型高估教师分布的低概率区域。然后,我们提出了一种有效的优化方法来学习这一目标。学生模型被命名为MINILLM。在指令跟随设置中的大量实验表明,MINILLM生成的响应更精确,整体质量更高,暴露偏差更低,长文本生成性能更好。
知识蒸馏(KD; Hinton et al., 2015),即使用大模型(教师模型)的监督来训练一个小模型(学生模型)。
通常应用的知识蒸馏有两类:黑盒知识蒸馏,教师生成的文本是唯一可访问的;
白盒知识蒸馏,教师模型的输出分布或中间隐藏状态也是可用的(Jianping et al., 2021)。
给定教师分布 ( p ( y ∣ x ) ) ( p(y|x) ) (p(y∣x))和由参数 ( θ ) ( \theta ) (θ)表示的学生分布 ( q θ ( y ∣ x ) ) (q_{\theta}(y|x)) (qθ(y∣x)),标准的知识蒸馏目标(包括几个序列级模型的变体)本质上是最小化教师和学生分布之间的近似正向Kullback-Leibler散度(KLD),称为 ( K L [ p ∣ ∣ q θ ] ) ( KL[p||q_\theta]) (KL[p∣∣qθ]),这迫使 ( q θ ) ( q_\theta) (qθ) 覆盖 ( p ) ( p ) (p)的所有模式。对于文本分类任务, ( K L [ p ∣ ∣ q θ ] ) ( KL[p||q_\theta] ) (KL[p∣∣qθ]) 运作良好,因为输出空间通常由有限数量的类别组成,使得 ( p ( y ∣ x ) ) ( p(y|x) ) (p(y∣x)) 和 ( q θ ( y ∣ x ) ) ( q_\theta(y|x) ) (qθ(y∣x)) 具有很少的模式。然而,对于开放式文本生成任务,输出空间要复杂得多, ( p ( y ∣ x ) ) (p(y|x)) (p(y∣x)) 可以包含比 ( q θ ( y ∣ x ) ) (q_\theta(y|x)) (qθ(y∣x)) 能表达的更多模式,因为模型容量有限。最小化正向KLD会导致 ( q θ ) ( q_\theta) (qθ) 向 ( p ) ( p) (p) 的空白区域分配不合理的高概率,并在自由运行生成期间生成在 ( p ) ( p ) (p) 下非常不可能的样本(Huszár, 2015)。
为了缓解这个问题,我们提出最小化反向KLD,即
(
K
L
[
q
θ
∣
∣
p
]
)
( KL[q_\theta||p])
(KL[qθ∣∣p]),这在计算机视觉(Lee et al., 2023)和强化学习(Czarnecki et al., 2019)中广泛使用。与
K
L
[
p
∣
∣
q
θ
]
KL[p||q_\theta]
KL[p∣∣qθ] 相比,最小化
K
L
[
q
θ
∣
∣
p
]
KL[q_\theta||p]
KL[qθ∣∣p] 使 $q_\theta $ 寻找
(
p
)
( p )
(p) 的主要模式,并对
(
p
)
( p )
(p) 的空白区域分配低概率。为了优化
min
θ
K
L
[
q
θ
∣
∣
p
]
\min_\theta KL[q_\theta||p]
minθKL[qθ∣∣p],如第2.2节所示,我们使用策略梯度(Policy Gradient)推导了目标的梯度。为了进一步稳定和加速训练,我们提出了
(1)单步分解以减少方差,
(2)教师混合采样以减轻奖励作弊,
(3)长度归一化以消除长度偏差。
我们考虑条件文本生成,其中模型根据从分布
p
x
p_x
px中抽样的提示
x
x
x生成响应
y
=
{
y
t
}
t
=
1
T
y = \{ y_t \}_{t=1}^T
y={yt}t=1T。我们将知识蒸馏表述为一个优化问题,目标是最小化固定教师模型分布
p
(
y
∣
x
)
p(y|x)
p(y∣x)和学生模型分布
q
θ
(
y
∣
x
)
q_\theta(y|x)
qθ(y∣x)由参数
θ
\theta
θ 表示)之间的差异。标准的知识蒸馏方法近似最小化正向KLD:
K
L
[
p
∣
∣
q
θ
]
=
E
x
∼
p
x
,
y
∼
p
′
log
p
(
y
∣
x
)
q
θ
(
y
∣
x
)
KL[p||q_\theta] = \mathbb{E}_{x \sim p_x, y \sim p'} \log \frac{p(y|x)}{q_\theta(y|x)}
KL[p∣∣qθ]=Ex∼px,y∼p′logqθ(y∣x)p(y∣x),其中
p
′
p'
p′可以是真实数据分布(词级别知识蒸馏)或教师分布
p
p
p(序列级别知识蒸馏)。
我们考虑最小化学生和教师模型分布之间的反向KLD作为MINILLM的学习目标:
[ θ = arg min θ L ( θ ) = arg min θ K L [ q θ ∣ ∣ p ] = arg min θ [ − E x ∼ p x , y ∼ q θ log p ( y ∣ x ) q θ ( y ∣ x ) ] ] [ \theta = \arg \min_\theta L(\theta) = \arg \min_\theta KL[q_\theta||p] = \arg \min_\theta \left[ - \mathbb{E}_{x \sim p_x, y \sim q_\theta} \log \frac{p(y|x)}{q_\theta(y|x)} \right] ] [θ=argθminL(θ)=argθminKL[qθ∣∣p]=argθmin[−Ex∼px,y∼qθlogqθ(y∣x)p(y∣x)]]
最小化反向KLD已被证明会导致生成建模中的模式寻求行为,即 q θ q_\theta qθ对 p p p 的主要模式赋予高概率并忽略小模式如图2)。
在本研究中,我们首先研究了这种特性在大语言模型文本生成中的知识蒸馏问题中。最小化正向KLD导致 q θ q_\theta qθ在 p p p的零概率区域上分配大量概率,这对应于实际中低质量的文本生成,而反向KLD则专注于 p p p 的主要模式,这对于确保文本生成的正确性和真实性至关重要。
如图3所示,与序列级别知识蒸馏不同,MINILLM通过最小化反向KLD不会强迫 q θ q_\theta qθ拟合从教师分布 p p p 中采样的所有 y y y。相反,它鼓励学生在其自身能力范围内生成教师偏好的样本,这更有可能实现。有趣的是,我们还发现了理解MINILLM的另一种视角,这一视角源于逆强化学习(Ziebart et al., 2008)。
梯度推导 我们注意到目标函数 L ( θ ) L(\theta) L(θ) 在公式(1)中的梯度可以使用策略梯度定理推导出来:
∇ L ( θ ) = − E x ∼ p x , y ∼ q θ ( ⋅ ∣ x ) ∑ t = 1 T ( R t − 1 ) ∇ log q θ ( y t ∣ y < t , x ) \nabla L(\theta) = - \mathbb{E}_{x \sim p_x, y \sim q_\theta(\cdot|x)} \sum_{t=1}^T (R_t - 1) \nabla \log q_\theta(y_t|y_{<t}, x) ∇L(θ)=−Ex∼px,y∼qθ(⋅∣x)t=1∑T(Rt−1)∇logqθ(yt∣y<t,x)
其中 T = ∣ y ∣ T = |y| T=∣y∣ 且 R t = ∑ t ′ = t T log p ( y t ′ ∣ y < t ′ , x ) q θ ( y t ′ ∣ y < t ′ , x ) R_t = \sum_{t'=t}^T \log \frac{p(y_{t'}|y_{<t'}, x)}{q_\theta(y_{t'}|y_{<t'}, x)} Rt=∑t′=tTlogqθ(yt′∣y<t′,x)p(yt′∣y<t′,x) 是 r t ′ = log p ( y t ′ ∣ y < t ′ , x ) q θ ( y t ′ ∣ y < t ′ , x ) r_{t'} = \log \frac{p(y_{t'}|y_{<t'}, x)}{q_\theta(y_{t'}|y_{<t'}, x)} rt′=logqθ(yt′∣y<t′,x)p(yt′∣y<t′,x) 的累积,度量每一步生成的质量。
直观上,生成的文本应该在教师分布下具有高概率,通过增加 p ( y t ′ ∣ y < t ′ , x ) p(y_{t'}|y_{<t'}, x) p(yt′∣y<t′,x),但同时通过降低 ( q θ ( y t ′ ∣ y < t ′ , x ) ) ( q_\theta(y_{t'}|y_{<t'}, x) ) (qθ(yt′∣y<t′,x)) 保持多样性。公式(2)中的期望通过蒙特卡罗采样计算。
然而,策略梯度在高方差和奖励作弊方面存在问题,尽管有一些后续解决方案。此外,我们注意到 R t R_t Rt 偏爱短句子,这导致学生模型输出空响应。因此,我们提出了三种策略来缓解这些问题。
单步分解 Czarnecki et al. (2019) 发现单步生成质量 r t r_t rt对训练方差至关重要,因为前面词元的错误会累积整个句子。为了更多关注 r t r_t rt,我们重新编写 ∇ L ( θ ) \nabla L(\theta) ∇L(θ) 来分解 r t r_t rt和 R t R_t Rt并直接计算 E y t ∼ q θ ( t ) [ r t ] \mathbb{E}_{y_t \sim q_\theta(t)}[r_t] Eyt∼qθ(t)[rt]的梯度:
∇ L ( θ ) = E x ∼ p x , y ∼ q θ ( ⋅ ∣ x ) [ − ∑ t = 1 T ∇ E y t ∼ q θ ( t ) [ r t ] ] + E x ∼ p x , y ∼ q θ ( ⋅ ∣ x ) [ − ∑ t = 1 T R t + 1 ∇ log q θ ( y t ∣ y < t , x ) ] = ( ∇ L ) Single + ( ∇ L ) Long \nabla L(\theta) = \mathbb{E}_{x \sim p_x, y \sim q_\theta(\cdot|x)} \left[ - \sum_{t=1}^T \nabla \mathbb{E}_{y_t \sim q_\theta(t)}[r_t] \right] + \mathbb{E}_{x \sim p_x, y \sim q_\theta(\cdot|x)} \left[ - \sum_{t=1}^T R_{t+1} \nabla \log q_\theta(y_t|y_{<t}, x) \right] = (\nabla L)_{\text{Single}} + (\nabla L)_{\text{Long}} ∇L(θ)=Ex∼px,y∼qθ(⋅∣x)[−t=1∑T∇Eyt∼qθ(t)[rt]]+Ex∼px,y∼qθ(⋅∣x)[−t=1∑TRt+1∇logqθ(yt∣y<t,x)]=(∇L)Single+(∇L)Long
其中 ( q_\theta(t) = q_\theta(\cdot|y_{<t}, x) )。注意 ( \mathbb{E}{y_t \sim q\theta(t)}[r_t] ) 可以通过对词汇表求和直接计算,而不是使用蒙特卡罗采样,并且可以对 ( \theta ) 求导。此分解提供了更精确和高效的单步生成质量估计,减少了训练中的方差,加速了收敛。
教师混合采样 我们在使用公式(2)进行训练时观察到奖励作弊,因为 q θ q_\theta qθ有时在采样期间生成退化的句子 y y y,这些句子在教师处获得高分(例如重复短语),特别是在小学生模型的情况下。为了创建更好的采样分布,我们在每个时间步将教师和学生分布混合:
p ~ ( y t ∣ y < t , x ) = α ⋅ p ( y t ∣ y < t , x ) + ( 1 − α ) ⋅ q θ ( y t ∣ y < t , x ) \tilde{p}(y_t|y_{<t}, x) = \alpha \cdot p(y_t|y_{<t}, x) + (1 - \alpha) \cdot q_\theta(y_t|y_{<t}, x) p~(yt∣y<t,x)=α⋅p(yt∣y<t,x)+(1−α)⋅qθ(yt∣y<t,x)
其中 α \alpha α控制教师混合的强度。从 p ~ \tilde{p} p~中采样可以在教师的帮助下抑制低质量生成并缓解奖励作弊。我们使用重要性采样重新编写 ( ∇ L ) Single (\nabla L)_{\text{Single}} (∇L)Single和 ( ∇ L ) Long (\nabla L)_{\text{Long}} (∇L)Long以获得梯度的无偏估计(Precup et al., 2000):
( ∇ L ) Single = − E x ∼ p x , y ∼ p ~ ( ⋅ ∣ x ) [ ∑ t = 1 T w t ∇ E y t ∼ q θ ( t ) [ r t ] ] (\nabla L)_{\text{Single}} = - \mathbb{E}_{x \sim p_x, y \sim \tilde{p}(\cdot|x)} \left[ \sum_{t=1}^T w_t \nabla \mathbb{E}_{y_t \sim q_\theta(t)}[r_t] \right] (∇L)Single=−Ex∼px,y∼p~(⋅∣x)[t=1∑Twt∇Eyt∼qθ(t)[rt]]
( ∇ L ) Long = − E x ∼ p x , y ∼ p ~ ( ⋅ ∣ x ) [ ∑ t = 1 T w t R t + 1 ∇ log q θ ( y t ∣ y < t , x ) ] (\nabla L)_{\text{Long}} = - \mathbb{E}_{x \sim p_x, y \sim \tilde{p}(\cdot|x)} \left[ \sum_{t=1}^T w_t R_{t+1} \nabla \log q_\theta(y_t|y_{<t}, x) \right] (∇L)Long=−Ex∼px,y∼p~(⋅∣x)[t=1∑TwtRt+1∇logqθ(yt∣y<t,x)]
其中 w t = ∏ t ′ = 1 t q θ ( y t ′ ∣ y < t ′ , x ) p ~ ( y t ′ ∣ y < t ′ , x ) w_t = \prod_{t'=1}^t \frac{q_\theta(y_{t'}|y_{<t'}, x)}{\tilde{p}(y_{t'}|y_{<t'}, x)} wt=∏t′=1tp~(yt′∣y<t′,x)qθ(yt′∣y<t′,x)是重要性权重。然而,实践中 w t w_t wt会带来高方差,因为它需要在多个时间步上相乘每个词元的重要性权重,因此每步的方差会累积。因此,我们近似设置 w t ≈ q θ ( y t ∣ y < t , x ) p ~ ( y t ∣ y < t , x ) w_t \approx \frac{q_\theta(y_t|y_{<t}, x)}{\tilde{p}(y_t|y_{<t}, x)} wt≈p~(yt∣y<t,x)qθ(yt∣y<t,x)以减少公式(5)中估计量的方差。
长度归一化 我们发现长序列倾向于具有较小的 ( R_{t+1} ),这鼓励模型生成短响应。因此,我们在公式(3)中的 ( R_{t+1} ) 中加入长度归一化:
R t + 1 Norm = 1 T − t − 1 ∑ t ′ = t + 1 T log p ( y t ′ ∣ y < t ′ , x ) q θ ( y t ′ ∣ y < t ′ , x ) R^{\text{Norm}}_{t+1} = \frac{1}{T - t - 1} \sum_{t'=t+1}^T \log \frac{p(y_{t'}|y_{<t'}, x)}{q_\theta(y_{t'}|y_{<t'}, x)} Rt+1Norm=T−t−11t′=t+1∑Tlogqθ(yt′∣y<t′,x)p(yt′∣y<t′,x)
总结 结合上述策略,我们得到最终的优化梯度:
∇ L ( θ ) = − E x ∼ p x , y ∼ p ~ ( ⋅ ∣ x ) [ ∑ t = 1 T w t [ ∇ ∑ y ′ ∈ V q θ ( y ′ ∣ y < t , x ) log p ( y ′ ∣ y < t , x ) q θ ( y ′ ∣ y < t , x ) ] Single + R t + 1 Norm ∇ log q θ ( y t ∣ y < t , x ) ] Long \nabla L(\theta) = - \mathbb{E}_{x \sim p_x, y \sim \tilde{p}(\cdot|x)} \left[ \sum_{t=1}^T w_t \left[ \nabla \sum_{y' \in V} q_\theta(y'|y_{<t}, x) \log \frac{p(y'|y_{<t}, x)}{q_\theta(y'|y_{<t}, x)} \right]_{\text{Single}} + R^{\text{Norm}}_{t+1} \nabla \log q_\theta(y_t|y_{<t}, x) \right]_{\text{Long}} ∇L(θ)=−Ex∼px,y∼p~(⋅∣x) t=1∑Twt ∇y′∈V∑qθ(y′∣y<t,x)logqθ(y′∣y<t,x)p(y′∣y<t,x) Single+Rt+1Norm∇logqθ(yt∣y<t,x) Long
其中 V V V是语言模型的词汇表大小, ( ∇ L ) Long Norm (\nabla L)^{\text{Norm}}_{\text{Long}} (∇L)LongNorm是 R t + 1 Norm R^{\text{Norm}}_{t+1} Rt+1Norm 归一化后的 ( ∇ L ) Long (\nabla L)_{\text{Long}} (∇L)Long。
我们从一个在大规模长文档语料库DPT上预训练的学生模型开始。训练MINILLM的算法是使用数据集D将学生模型调整到一个文本生成任务。我们假设有一个在D上表现良好的教师模型,如在D上微调的大语言模型(Taori et al., 2023; Chiang et al., 2023)或具有良好任务泛化性的模型(Chung et al., 2022; OpenAI, 2023)。
在训练算法中,我们首先在D上微调学生模型,并选择损失最低的检查点作为后续训练的初始化。然后,我们基于公式(5)和公式(6)计算 ( ∇ L ) Single (\nabla L)_{\text{Single}} (∇L)Single和 ( ∇ L ) Long Norm (\nabla L)^{\text{Norm}}_{\text{Long}} (∇L)LongNorm的梯度,并加入剪切策略(Schulman et al., 2017)进一步提高稳定性。与Ouyang等(2022)相同,我们包括语言建模损失 L P T = − E d ∼ D P T log q θ ( d ) L_{PT} = -\mathbb{E}_{d \sim DPT} \log q_\theta(d) LPT=−Ed∼DPTlogqθ(d)以保持模型在标准自然语言处理基准上的表现。学生模型最终使用组合梯度 ( ∇ L ) Single + ( ∇ L ) Long Norm + ∇ L P T (\nabla L)_{\text{Single}} + (\nabla L)^{\text{Norm}}_{\text{Long}} + \nabla L_{PT} (∇L)Single+(∇L)LongNorm+∇LPT进行更新。整个训练过程类似于从人类反馈中学习的强化学习(RLHF;Ouyang等,2022)。MINILLM训练算法的详细信息见附录B。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。