当前位置:   article > 正文

Preparing lessons: Improve knowledge distillation with Better supervision_logits adjustment

logits adjustment

问题:教师的logits进行训练可能出现incorrect和overly uncertain的监督

解决(1) Logits Adjustment(LA)

            (2) Dynamic Temperature Distillation(DTD)

LA:针对错误判断的训练样本,交换GT标签和误判标签的logits值

DTD:一些uncertain soft target是因为过高的温度值,因此可采用动态温度计算soft target,
该温度在训练期间自适应更新----->学生模型在训练中能够获得更多discriminative information,可视为在线的hard examples mining过程(OHEM:根据损失选择样本,重新送入网络)。

文章提出的方法:

3.1回顾知识蒸馏:

 公式(1):之前KD工作中为了将网络的softmax输出概率分布更加soft,便向softmax函数引入了温度参数

 ti 和si分别为教师模型和学生模型的logits

q和p分别为教师模型和学生模型的softmax概率分布。

公式(2):是蒸馏过程中的整体损失构成=学生损失+蒸馏损失;

学生损失是学生模型基于ground truth label进行训练学习得到的损失,缓解学习到的 不正确的logits信息;但这种缓解只是轻微的,仍存在遗传错误的问题。

文中验证:统计在ResNet50指导下ResNet18在CIFAR-10和CIFAR-100的遗传错误率分别有57%和42%。

 

 3.2 Genetic errors and logits adjustment

新概念Genetic errors: 学生模型的错误预测与教师的错误预测一致,称为遗传性错误。

提出LA,试图fix教师的预测,对教师的softened logits执行函数A(·),修改后的损失仍用交叉熵:

 A(·)的3个特点:

(1)fix错误的qt,而对正确的不做任何事,为保证训练稳定性,所作的修改尽可能小一些。

(2)在学生训练期间,从教师网络得到的参数qt不可变,因此优化对象可用交叉熵表示,而不是用KL。

(3)交叉熵计算不用y,因为A(qt)是完全正确的。

3.2.1 Why not LSR?

 Label Smooth Regularization (LSR)属于LA最简单的实现方式,但是有所限制。

LSR label:

 类别数量k   样本x    脉冲信号δ(·)

LSR丢弃教师预测的非true类别的概率,但这在KD中被证明是有帮助的。

因此,另一种简单实现被提出:Probability Shift(PS). 概率转移

思想:交换真实值标签值(理论上最大值)与预测类别值(预测的最大值),以保证最大置信度落在真实值标签。

Fig2. PD on 误判样本的soft target ,The sample is from CIFAR-100 training data, whose ground truth label is leopard but ResNet-50 teacher’s prediction is rabbit.The value of ”leopard” is still large, which indicates that the teacher does not go ridiculous. 转换操作就是交换两个类的值,得到一个leopard的最大预测值,rabbit的第二大预测值。

与LSR相比,PS保留了涉及微小概率的类别间差异,LSR则丢弃了大部分。不正确的预测类别往往与真实类别有一些相似的特征。也就是说,不正确的预测类别可能比其他类别包含更多信息。该方法还保留了软目标的数字分布,这对稳定训练过程是有帮助的。

动态温度蒸馏:

[30,31,44]研究表明学生可以从监督的不确定性中受益,但教师的过度不确定的预测也可能会影响性能。

下图:蒸馏softmax的可视化

图中可以看到随着温度的升高,各个类之间的概率差异变小,而真实值是leopard,另外两种kangroo和rabbit在训练中就是干扰项,因此为更好区分(扩大类间相似度),应当选择更小些的温度值。(但之前有提到过较高的温度值可以让softmax输出更加soft,但这里用较高的温度会让非真实类称为训练干扰项,所以采用动态温度DTD的思想)

 DTD描述:这里用的KL散度,而不是cross entropy,因为ptx是变化的,不是一个常量。

t0和β代表 基础温度和偏差,wx代表样本x的批量归一化权重,描述混淆的程度。当样本x有些混淆且教师预测值不确定时,wx会增加。如此一来tx<t0,soft targets更加有区分性。

======》混乱的样本会有更大的weights,更低的温度值。---样本之间就更加有区分度。

======》DTD更加关注那些confusing examples,就可以视为是一种hard examples mining.

文中提出两种方法计算权重wx,:

一种是FLSW计算sample-wise 权重; 另一种是依据学生预测的最大输出计算wx,称为Confidence Weighted by Student Max(CWSM)

Focal loss style weights:

原来的focal loss:

p为一个样本的分类分数。

本文方法中:学习难度可通过学生的logit v和教师的logit t之间的相似性来衡量。为简便,将r设置为一个常量,得

(wx代表迷惑程度即难分类程度)

v·t∈[-1,1]是两个分布的内积。当学生预测与教师的预测相差甚远时,wx就会变大。 

(回顾:A·B 内积计算是由第一个矩阵的每一行乘以第二个矩阵的每一列得到的)

Confidence Weighted by Student Max

根据学生归一化的logits的最大值给样本加权,在一定程度上可以反应样本的学习情况。

学生模型通常对confusing 样本有着不确定的预测,其logits的最大值也相应小一些,这里计算wx公式描述为:

其中学生的logit v 应该是normalized的,vmax被视为代表学生对样本的置信度 。低置信度的样本有更高的权重,这些样本的梯度在蒸馏过程中也贡献的更多。

Compound loss function and algorithm复式损失函数和算法

结合LA和DTD,整体损失:

 与公式(3)类似,但不同在于(10)采用sample-wise温度来soften logits.

监督张量A(qtx)会随着学习情况不断变化。

此外这里没必要使用真实值交叉熵,因为A(qtx)总是正确的。

 实验:

数据集:CIFAR-10, CIFAR-100, Tiny ImageNet
方法比较:标准蒸馏 (KD),注意力机制 (AT),神经元选择性迁移 (NST)

标准蒸馏KD:

公式(2)中α=0.7,方便起见,用KL散度实现两个分布之间的交叉熵。

AT:KD中引入注意力机制

 

NST:

注:MMD(Maximum Mean Discrepancy)

 

 SP:Similarity Preserving Distillation

 从特征相似性的角度提出了一种新的蒸馏损失,引导学生模仿样本间的相似性,而不是教师对空间的逻辑.

其中b为batch size,||·||F是矩阵的Frobenius范数,矩阵中的元素的平方和再开方。对于向量而言就是L2距离。Gt和Gs分别为教师和学生模型certain layer的相似性矩阵。

CIFAR-100

 

 

 

 

 

 

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/688058
推荐阅读
相关标签
  

闽ICP备14008679号