当前位置:   article > 正文

知识蒸馏(Knowledge Distillation)

知识蒸馏

本文主要罗列与知识蒸馏相关的一些算法与应用。但首先需要明确的是,教师网络或给定的预训练模型中包含哪些可迁移的知识?基于常见的深度学习任务,可迁移知识列举为:

  • 中间层特征:浅层特征注重纹理细节,深层特征注重抽象语义;
  • 任务相关知识:如分类概率分布,目标检测涉及的实例语义、位置回归信息等;
  • 表征相关知识:强调特征表征能力的迁移,相对通用、任务无关(Task-agnostic);表征间相关性,如相似度、Relation等;

另外,知识蒸馏的应用主要有哪些?粗略概况,可包含模型压缩、迁移学习与多教师信息融合等。

1、Distilling the Knowledge in a Neural Network

Hinton的文章"Distilling the Knowledge in a Neural Network"首次提出了知识蒸馏(暗知识提取)的概念,通过引入与教师网络(Teacher network:复杂、但预测精度优越)相关的软目标(Soft-target)作为Total loss的一部分,以诱导学生网络(Student network:精简、低复杂度,更适合推理部署)的训练,实现知识迁移(Knowledge transfer)。

如上图所示,教师网络(左侧)的预测输出除以温度参数(Temperature)之后、再做Softmax计算,可以获得软化的概率分布(软目标或软标签),数值介于0~1之间,取值分布较为缓和。Temperature数值越大,分布越缓和;而Temperature数值减小,容易放大错误分类的概率,引入不必要的噪声。针对较困难的分类或检测任务,Temperature通常取1,确保教师网络中正确预测的贡献。硬目标则是样本的真实标注,可以用One-hot矢量表示。Total loss设计为软目标与硬目标所对应的交叉熵的加权平均(表示为KD loss与CE loss),其中软目标交叉熵的加权系数越大,表明迁移诱导越依赖教师网络的贡献,这对训练初期阶段是很有必要的,有助于让学生网络更轻松的鉴别简单样本,但训练后期需要适当减小软目标的比重,让真实标注帮助鉴别困难样本。另外,教师网络的预测精度通常要优于学生网络,而模型容量则无具体限制,且教师网络推理精度越高,越有利于学生网络的学习。

教师网络与学生网络也可以联合训练,此时教师网络的暗知识及学习方式都会影响学生网络的学习,具体如下(式中三项分别为教师网络Softmax输出的交叉熵loss、学生网络Softmax输出的交叉熵loss、以及教师网络数值输出与学生网络Softmax输出的交叉熵loss):

联合训练的Paper地址:https://arxiv.org/abs/1711.05852

2、Exploring Knowledge Distillation of Deep Neural Networks for Efficient Hardware Solutions

GitHub地址:https://github.com/peterliht/knowledge-distillation-pytorch

这篇文章将Total loss重新定义如下:

Total loss的PyTorch代码如下,引入了精简网络输出与教师网络输出的KL散度,并在诱导训练期间,先将Teacher network的预测输出缓存到CPU内存中,可以减轻GPU显存的Overhead:

  1. def loss_fn_kd(outputs, labels, teacher_outputs, params):
  2. """
  3. Compute the knowledge-distillation (KD) loss given outputs, labels.
  4. "Hyperparameters": temperature and alpha
  5. NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher
  6. and student expects the input tensor to be log probabilities! See Issue #2
  7. """
  8. alpha = params.alpha
  9. T = params.temperature
  10. KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
  11. F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \
  12.                F.cross_entropy(outputs, labels) * (1. - alpha)
  13. return KD_loss

3、Ensemble of Multiple Teachers

Paper地址: Efficient Knowledge Distillation from an Ensemble of Teachers | Request PDF

第一种算法:多个教师网络输出的Soft label按加权组合,构成统一的Soft label,然后指导学生网络的训练:

第二种算法:由于加权平均方式会弱化、平滑多个教师网络的预测结果,因此可以随机选择某个教师网络的Soft label作为Guidance:

第三种算法:同样地,为避免加权平均带来的平滑效果,首先采用教师网络输出的Soft label重新标注样本、增广数据、再用于模型训练,该方法能够让模型学会从更多视角观察同一样本数据的不同功能:

4、Hint-

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号