当前位置:   article > 正文

Rdrop技术(Regularized Dropout)

rdrop

Rdrop理论

每个数据样本重复经过带有Dropout的同一个模型,再使用KL散度约束两次的输出,使得尽可能一致,而由于 Dropout的随机性,可以近似把输入X走过两次的路径网络当作两个略有不同的模型,如下图所示:
请添加图片描述

【补充知识点一】-- 损失函数
一部分的损失函数是常规的交叉熵

训练数据为 ( x i , y i ) ({x_i,y_i}) (xi,yi) ,模型为 P θ ( y ∣ x ) P_{\theta}(y|x) Pθ(yx), 每个样本的交叉熵都是
L i = − l o g P θ ( y i ∣ x i ) L_i = -logP_{\theta}(y_i|x_i) Li=logPθ(yixi)

在“Dropout两次”的情况下,其实我们可以认为样本已经通过了两个略有不同的模型,分别记为: P θ ( 1 ) ( y i ∣ x i ) , P θ ( 2 ) ( y i ∣ x i ) P_{\theta}^{(1)}(y_i|x_i), P_{\theta}^{(2)}(y_i|x_i) Pθ(1)(yixi),Pθ(2)(yixi)

所以其中一部分的损失函数是
L 1 = − l o g P θ ( 1 ) ( y i ∣ x i ) − l o g P θ ( 2 ) ( y i ∣ x i ) L1 = -logP_{\theta}^{(1)}(y_i|x_i) -logP_{\theta}^{(2)}(y_i|x_i) L1=logPθ(1)(yixi)logPθ(2)(yixi)

另一部分的损失函数是KL散度,目的是让两次的输出,使得尽可能一致。
L 2 = 1 2 [ K L ( P θ ( 1 ) ( y i ∣ x i ) ∣ P θ ( 2 ) ( y i ∣ x i ) ) + K L ( P θ ( 2 ) ( y i ∣ x i ) ∣ P θ ( 1 ) ( y i ∣ x i ) ) ] L2 = \frac{1}{2} [KL(P_{\theta}^{(1)}(y_i|x_i) | P_{\theta}^{(2)}(y_i|x_i)) + KL(P_{\theta}^{(2)}(y_i|x_i) | P_{\theta}^{(1)}(y_i|x_i)) ] L2=21[KL(Pθ(1)(yixi)Pθ(2)(yixi))+KL(Pθ(2)(yixi)Pθ(1)(yixi))]

最终loss就是两个loss的加权和.
L = L 1 + α L 2 L=L1 + \alpha L2 L=L1+αL2

【补充知识点二】-- KL散度的定义
KL散度从信息损耗的角度来度量两个数据分布的差异程度
它是用来描述两个概率分布P和Q的差异的一种方法。

在信息系统中称为相对熵
在连续时间序列中称为随机性
在统计模型推断中称为信息增益,也称信息散度

【补充】这个指标不能用作距离衡量,因为该指标不具有对称性,即D(P||Q) ≠ D(Q||P)

D(P||Q) 表示用概率分布Q 来拟合 真实分布P时,产生的信息损耗
【其中P表示真实分布,Q表示P的拟合分布】请添加图片描述
物理意义:
在信息论中,它是用来度量使用基于Q分布的编码来编码来自P分布的样本平均所需的额外的比特(bit)个数。
在机器学习领域,是用来度量两个函数的相似程度或者相近程度

几个常见的应用

(1)机器学习领域中的生成模型往往涉及 所产生的能够尽可能反映到真实情况分布模型。
生成对抗网络(GAN)在图片上的应用往往执行的是类似基于黑白图片生成看起来尽量真实的彩色图片这样的任务。
在这类似的应用中,输入往往是图像或像素。网络会学习这些像素之间的依赖关系(比如临近像素通常有相似的颜色
然后使用它来创建看起来尽量真实的图像。因此,生成器的目标就是最小化所学到的像素分布于真实图像像素分布之间的散度
(2)用户画像的刻画。在电商场景,可以使用KL散度去计算同一类型商品不同用户群体之间的金额KL散度,如果都很接近,说明,这类型的商品不能体现不同用户的差异点,可以进行剔除。只保留有差异性的商品类型(KL散度较大)。这样就可以基于结果去对用户进行更好的画像。

具体操作方法

其实Rdrop的实际操作和应用非常简单,对比于对抗训练和梯度惩罚这些方法。
Rdrop主要是实施两部
第一:在数据生成器中作修改,将产生的数据、标签重复生成两次
第二:在loss部分,重新再定义一个新的loss,即KL散度
code部分

def KL(input, target, reduction="sum"):
    input = input.float()
    target = target.float()
    loss = F.kl_div(F.log_softmax(input, dim=-1, dtype=torch.float32),
                    F.softmax(target, dim=-1, dtype=torch.float32), reduction=reduction)
    return loss
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

F.kl_div函数中 第一个参数传入的是一个对数概率矩阵,第二个参数传入的是概率矩阵。这里很重要,不然求出来的kl散度可能是个负值
【补充 KL散度的细节理解】
当时先自己定义了KL散度,使用了F.kl_div函数来计算。
其中比较纠结的细节是 reduction的参数选择,以及在大规模训练的时候,其选择的影响。
(1)batchmean: 先把结果相加 再除 batchsize
(2)mean:先把结果相加,再除 个数
(3)sum:把结果相加
【 Default: ‘mean’】
例子:

import torch
import torch.nn.functional as F
# 定义两个矩阵
x = torch.randn((4, 5))
y = torch.randn((4, 5))
# 因为要用y指导x,所以求x的对数概率,y的概率
logp_x = F.log_softmax(x, dim=-1)
p_y = F.softmax(y, dim=-1)
kl_sum = F.kl_div(logp_x, p_y, reduction='sum')
kl_mean = F.kl_div(logp_x, p_y, reduction='mean')
kl_batchmean = F.kl_div(logp_x, p_y, reduction='batchmean')
print(kl_sum, kl_mean)
print(kl_sum/kl_mean)
output:
tensor(3.6756) tensor(0.1838)
tensor(20.)
print(kl_sum,kl_mean,kl_batchmean)
tensor(3.6756) tensor(0.1838) tensor(0.9189)
kl_sum/kl_batchmean
tensor(4.)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

kl_sum/kl_mean = 20(个values)刚好20个values 因为x和y是 4x5的矩阵
kl_sum/kl_batchmean = 4(batch)4代表四组数, 每组5个values

当时纠结的是应该用mean还是用batchmean来计算,肯定不能用sum,因为sum的数值太大,造成kl散度损失值过大,另一部分交叉熵的数值过小。
(1)使用mean:除于所有的数,即batch_num * seq_len
【方法一可以避免 batch num和seq_len的影响】

(2)使用batchmean:除于所有的example数,即batch_num
【方法一可以避免 batch num的影响】

由于Rdrop一般应用在后面任务的微调训练中,在此过程中,会根据不同的任务设置不同的seq len,因此应该避免seq_len的影响。 所以采用men的方法比较好

个人思考

(1)关注非目标类的稳定性,提高模型的鲁棒性
在损失函数多加的KL散度为什么会发挥这么好的作用呢?
交叉熵的训练目标主要是:
让目标类的得分大于非目标类的得分,这样模型就能正确地把目标类预测出来了
也就是说 只有交叉熵的损失函数只能做到 在不同的dropout下,目标类的得分都大于非目标类的得分
而做不到 不同的Dropout下,每个类的得分一致。这样有利于模型的鲁棒性
而新添加的KL散度可以一定程度解决这个问题。
从公式上来看,交叉熵只跟目标类别有关,不关心非目标类的分布,假如目标类为第一个类别,那么预测结果是[0.5,0.2,0.3]或[0.5,0.3,0.2],对它来说都没区别。但对于KL散度项来说就不一样了,每个类的得分都要参与计算,[0.5,0.2,0.3]或[0.5,0.3,0.2]是有非零损失的。
(2)连续对模型的每一层添加干扰
R-Drop的扰动则可以施加到模型的每一层中,并且扰动是随机的。
虽然R-Drop的扰动是随机的,但是R-Drop的扰动更多,所以它造成的扰动也会放大

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/625943
推荐阅读
相关标签
  

闽ICP备14008679号