赞
踩
听说过数据扩增(Data Augmentation),也听说过虚拟对抗训练(Virtual Adversarial Traning),但是我没想到会有人将其结合,谓之虚拟数据扩增(Virtual Data Augmentation)。这篇文章主要讲解EMNLP2021上的一篇论文Virtual Data Augmentation: A Robust and General Framework for Fine-tuning Pre-trained Models,该论文提出了一种鲁棒且通用的数据扩增方法,论文源码在https://github.com/RUCAIBox/VDA
论文开篇提到目前数据扩增存在的主要问题:产生数据多样性的同时如何保证其仍然在同一个语义空间中?简单地说,增强数据扩增的多样性很容易,核心就一个字:“乱”,例如许多数据扩增方法会随机打乱一个句子中token的位置,或者是随机删除某些token,随机插入某些token。这样虽然增强了样本的多样性,但是语义可能也会产生非常大的变化,甚至不再与原样本的语义相同。保持语义不变,或者说保证扩增后的样本和原样本在同一个语义空间中很容易,核心就是:“不要太乱”,例如通过同义词替换等,这种方法可以做到几乎不改变语义,但是数据多样性却不够,因为本质上还是同一句话
这两个需求实际上是矛盾的,我们所能做的只是尽力达到某种平衡。具体来说,作者所提出的方法包含两个重要部分:Embedding Augmentation以及Regularized Training
假设现在我们有句子「Time is enough for test」,对于每个位置的token,我们都可以将其替换为[MASK],然后通过MLM预测Vocabulary中所有token在该位置的概率,例如
[MASK] is enough for test
[MASK]位置输出的token及其概率为
Time p=0.5
Day p=0.3
Hours p=0.15
...
再比如
Times is enough for [MASK]
[MASK]位置输出的token及其概率为
test p=0.5
evaluation p=0.3
experiment p=0.1
...
看到这里大家脑海中可能已经有了一个数据扩增的想法,就是利用MLM任务对句子中每个位置的token进行预测,然后根据预测概率随机挑选出一个token进行替换,例如上面的句子可能就会被替换为「Hours is enough for evaluation」。这确实是一种还不错的数据扩增方法,但是论文作者却并不是这么做的
为了描述简单,我们仅讨论对于给定句子
S
S
S中的一个token
w
~
\tilde{w}
w~进行扩增的情况(实际上句子
S
S
S中的所有token都会进行该操作),通过MLM任务我们可以预测出Vocabulary中所有单词在
w
~
\tilde{w}
w~位置的概率
{
p
(
w
^
1
∣
S
)
,
.
.
.
,
p
(
w
^
V
∣
S
)
}
(1)
\{p(\hat{w}_1\mid S),...,p(\hat{w}_V\mid S)\}\tag{1}
{p(w^1∣S),...,p(w^V∣S)}(1)
其中,
V
V
V是Vocabulary中的token数量
为了增强数据扩增的多样性,或者说引入某些噪声以增强抗干扰性,我们从高斯分布中随机采样出一个向量
ϵ
∼
N
(
0
,
σ
2
)
(2)
\epsilon \sim \mathcal{N}(0, \sigma^2)\tag{2}
ϵ∼N(0,σ2)(2)
将该向量与公式(1)的概率分布进行混合,我们可以得到一个新的概率分布
p
′
(
w
^
i
∣
S
)
=
Softmax
(
p
(
w
^
i
∣
S
)
+
ϵ
)
(3)
p'(\hat{w}_i\mid S) = \text{Softmax}(p(\hat{w}_i\mid S) + \epsilon)\tag{3}
p′(w^i∣S)=Softmax(p(w^i∣S)+ϵ)(3)
然后对于每个即将被替换的token
w
~
\tilde{w}
w~,我们根据概率
p
′
(
w
^
i
∣
S
)
p'(\hat{w}_i\mid S)
p′(w^i∣S)加权融合所有token
w
^
i
\hat{w}_i
w^i的Embedding向量
e
^
w
~
=
p
w
~
⋅
M
E
(4)
\hat{\mathbf{e}}_{\tilde{w}}=\mathbf{p}_{\tilde{w}}\cdot\mathbf{M}_E\tag{4}
e^w~=pw~⋅ME(4)
其中,
p
w
~
=
{
p
′
(
w
^
i
∣
S
)
}
i
=
1
V
\mathbf{p}_{\tilde{w}}=\{p'(\hat{w}_i\mid S)\}_{i=1}^V
pw~={p′(w^i∣S)}i=1V,
M
E
∈
R
V
×
d
\mathbf{M}_E\in \mathbb{R}^{V\times d}
ME∈RV×d是MLM模型的词向量矩阵
举个简单的例子解释一下,为了方便,同样还是以替换一个token为例,并且整个Vocabulary只有4个token,词向量的维度为2。首先我们有一句话「She is a good student」,将「good」进行MASK,然后通过MLM模型,预测出概率分布为
p
(
w
^
i
∣
S
)
=
[
0.5
,
0.1
,
0.1
,
0.3
]
p(\hat{w}_i\mid S)=[0.5, 0.1, 0.1, 0.3]
p(w^i∣S)=[0.5,0.1,0.1,0.3]
从左到右分别是good, perfect, excellent, smart的概率,根据高斯分布
N
(
0
,
σ
2
)
\mathcal{N}(0, \sigma^2)
N(0,σ2)随机产生的向量为
ϵ
=
[
−
0.1
,
0.1
,
0.1
,
−
0.1
]
\epsilon = [-0.1, 0.1, 0.1, -0.1]
ϵ=[−0.1,0.1,0.1,−0.1]
这里我并没有具体指明方差 σ 2 \sigma^2 σ2到底是多少,因为我懒得算
将
p
(
w
^
i
∣
S
)
p(\hat{w}_i\mid S)
p(w^i∣S)与
ϵ
\epsilon
ϵ混合后进行Softmax得到新的概率分布为
p
′
(
w
^
i
∣
S
)
=
[
0.4
,
0.2
,
0.2
,
0.2
]
p'(\hat{w}_i\mid S) = [0.4, 0.2, 0.2, 0.2]
p′(w^i∣S)=[0.4,0.2,0.2,0.2]
假设Embedding矩阵为
M
E
=
[
0.2
,
0.3
0.1
,
0.5
0.4
,
0.2
0.1
,
0.4
]
\mathbf{M}_E =
那么最终「good」这个位置对应的embedding为
e
^
w
~
=
p
′
(
w
^
i
∣
S
)
⋅
M
E
=
[
0.4
0.2
0.2
0.2
]
T
⋅
[
0.2
,
0.3
0.1
,
0.5
0.4
,
0.2
0.1
,
0.4
]
=
[
0.2
,
0.34
]
到此为止,不知道大家有没有体会到什么叫「Virtual Data Augmentation」,Virtual本质上就是不用一个真实的token去替换,而是使用一个embedding去替换,而如果你用这个embedding去反查
M
E
\mathbf{M}_E
ME矩阵一般是找不到对应的索引的,也就是说我们生成的这个embedding并不对应一个实际存在的token
标题起的很有故事,但本质上就是多引入了一个损失函数,具体来说,现在我们的优化目标为
arg
min
θ
∑
i
=
1
n
L
c
(
f
(
x
i
)
,
y
i
)
+
λ
∑
j
=
1
k
L
r
e
g
(
f
(
x
i
)
,
f
(
x
^
j
)
)
(5)
\underset{\theta}{\arg \min } \sum_{i=1}^{n} \mathcal{L}_{c}\left(f\left(x_{i}\right), y_{i}\right)+\lambda \sum_{j=1}^{k} \mathcal{L}_{\mathrm{reg}}\left(f\left(x_{i}\right), f\left(\hat{x}_{j}\right)\right)\tag{5}
θargmini=1∑nLc(f(xi),yi)+λj=1∑kLreg(f(xi),f(x^j))(5)
其中
f
f
f表示含有参数
θ
\theta
θ的预训练模型,
n
n
n为样本个数,
k
k
k表示由一条句子扩增出了
k
k
k条句子。具体来说,如果是分类任务,则
L
c
(
θ
)
=
1
n
∑
i
=
1
n
CE
(
f
(
E
i
;
θ
)
,
y
i
)
(6)
\mathcal{L}_c(\theta) = \frac{1}{n}\sum_{i=1}^n \text{CE}(f(\mathbf{E}_i;\theta), y_i)\tag{6}
Lc(θ)=n1i=1∑nCE(f(Ei;θ),yi)(6)
其中,
CE
(
⋅
,
⋅
)
\text{CE}(\cdot ,\cdot)
CE(⋅,⋅)是Cross-Entropy Loss,可以根据具体任务替换的,
E
i
\mathbf{E}_i
Ei表示第
i
i
i条句子通过Word2Vec之后生成的向量,其维度为[seq_len, emd_dim]
为了防止扩增后的样本与原始样本间的语义产生巨大差距,换句话说,我们希望扩增后的样本与原样本间的分布是接近的,因此论文引入了KL散度作为第二项损失
L
reg
(
θ
)
=
1
k
∑
i
=
1
k
D
s
K
L
(
f
(
E
i
;
θ
)
,
f
(
E
^
i
;
θ
)
)
(7)
\mathcal{L}_{\text{reg}}(\theta)=\frac{1}{k}\sum_{i=1}^k D_{sKL}(f(\mathbf{E}_i;\theta), f(\hat{\mathbf{E}}_i;\theta))\tag{7}
Lreg(θ)=k1i=1∑kDsKL(f(Ei;θ),f(E^i;θ))(7)
其中,
k
k
k指的是原样本扩增出了
k
k
k个样本,
D
s
K
L
D_{sKL}
DsKL是对称的KL散度,具体来说
D
s
K
L
(
p
,
q
)
=
D
K
L
(
p
,
q
)
+
D
K
L
(
q
,
p
)
2
(8)
D_{sKL}(p, q) = \frac{D_{KL}(p, q) + D_{KL}(q, p)}{2}\tag{8}
DsKL(p,q)=2DKL(p,q)+DKL(q,p)(8)
实际上这种方法可以看作是多任务,我们希望模型参数训练到一种境界,这种境界是,不论模型对原样本进行下游任务,还是让模型判断原样本与扩增样本的差距,模型都能做的很好。最后给出论文中的一张图结束这部分(图中一个样本扩增了3条样本)
如果单看原始的准确率对比,似乎提升并不是很大,感觉我随便引入一些trick都能达到甚至超过Virtual Data Augmentation的效果。关键在于第二列「Att Acc」,这代表模型受到攻击时的结果,这部分的提升特别大,表明VDA这种方法确实有很强的抗干扰性,或者说鲁棒性很强
实际上前面已经把这篇论文讲的很清楚了,这里没有什么好总结的,但我倒是有一点个人拙见想和大家讨论一下,因为他做MLM任务时,将整个Vocabulary都作为候选集,这样无论是对计算速度还是显存占用都不是很友好,我觉得可以将其改为取出概率最大的前Top k个token,这个k可以取的稍微大一点,例如200, 300等,这样可以保证取到后面一些语义上不那么相近的token的同时,避免对整个Vocabulary进行运算,至少不会生成几万几十万那么夸张的概率分布
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。