当前位置:   article > 正文

yolov8(目标检测、图像分割、关键点检测)知识蒸馏:logit和feature-based蒸馏方法的实现_feature base知识蒸馏

feature base知识蒸馏

1.知识蒸馏的原理

目标检测中,知识蒸馏的原理主要是利用教师模型(通常是大型的深度神经网络)的丰富知识来指导学生模型(轻量级的神经网络)的学习过程。通过蒸馏,学生模型能够在保持较高性能的同时,减小模型的复杂度和计算成本。

知识蒸馏实现的方式有多种,但核心目标是将教师模型学习到的知识迁移到学生中去(通常是通过各种损失函数进行实现)。

本项目支持yolov8检测、分割、关键点任务的知识蒸馏,并对蒸馏代码进行详解,比较容易上手。蒸馏方式多种,支持 logit和 feature-based蒸馏以及在线蒸馏。:

2.logit 蒸馏原理

Logit蒸馏原理主要基于深度学习中的知识迁移技术,特别是在模型压缩和加速领域。其核心思想是利用大型、复杂的教师模型(Teacher Model)的logits(逻辑层的原始输出得分)来指导小型、轻量的学生模型(Student Model)的学习。

Logits是教师模型在做出最终决策之前的原始得分,这些得分在数值上表示了模型对每个类别的预测置信度。相较于最终的类别概率分布,logits包含了更丰富的信息,尤其是当不同类别之间存在细微差别时。

在Logit蒸馏过程中,教师模型的logits被用作额外的监督信号来训练学生模型。通过最小化教师模型和学生模型在logits层面上的差异(通常使用均方误差MSE或KL散度等损失函数),可以使学生模型学习到教师模型在决策边界附近的细致区分能力。这种蒸馏方式有助于提升学生模型在保持较高性能的同时,减小模型的复杂度和计算成本。

逻辑蒸馏损失定义的代码在:ultralytics/utils/distill_loss.py

  1. class Distill_LogitLoss:
  2. def __init__(self,p, t_p, alpha =0.25):
  3. t_ft = torch.cuda.FloatTensor if t_p[0].is_cuda else torch.Tensor
  4. self.p =p
  5. self.t_p = t_p
  6. self.logit_loss = t_ft([0])
  7. self.DLogitLoss = nn.MSELoss(reduction="none")
  8. self.bs = p[0].shape[0]
  9. self.alpha = alpha
  10. def __call__(self):
  11. # per output
  12. assert len(self.p) == len(self.t_p)
  13. for i, (pi,t_pi) in enumerate(zip(self.p,self.t_p)): # layer index, layer predictions
  14. assert pi.shape == t_pi.shape
  15. self.logit_loss += torch.mean(self.DLogitLoss(pi, t_pi))
  16. return self.logit_loss[0]*self.alpha

3.feature-base蒸馏原理

Feature-based蒸馏原理是知识蒸馏中的一种重要方法,其关键在于利用教师模型的隐藏层特征来指导学生模型的学习过程。这种蒸馏方式旨在使学生模型能够学习到教师模型在特征提取和表示方面的能力,从而提升其性能。

具体来说,Feature-based蒸馏通过比较教师模型和学生模型在某一或多个隐藏层的特征表示来实现知识的迁移。在训练过程中,教师模型的隐藏层特征被提取出来,并作为监督信号来指导学生模型相应层的特征学习。通过优化两者在特征层面的差异(如使用均方误差、余弦相似度等作为损失函数),可以使学生模型逐渐逼近教师模型的特征表示能力。

这种蒸馏方式有几个显著的优势。首先,它充分利用了教师模型在特征提取方面的优势,帮助学生模型学习到更具判别性的特征表示。其次,通过比较特征层面的差异,可以更加细致地指导学生模型的学习过程,使其在保持较高性能的同时减小模型复杂度。最后,Feature-based蒸馏可以与其他蒸馏方式相结合,形成更为复杂的蒸馏策略,以进一步提升模型性能。

需要注意的是,在选择进行Feature-based蒸馏的隐藏层时,需要谨慎考虑。不同层的特征具有不同的语义信息和抽象程度,因此选择合适的层进行蒸馏对于最终效果至关重要。此外,蒸馏过程中的损失函数和权重设置也需要根据具体任务和数据集进行调整。

综上所述,Feature-based蒸馏原理是通过利用教师模型的隐藏层特征来指导学生模型的学习过程,从而实现知识的迁移和模型性能的提升。这种方法在深度学习领域具有广泛的应用前景,尤其在需要提高模型特征提取能力的场景中表现出色。

本文将给出3种feature-base的蒸馏损失方法,代码分别如下

  • MimicLoss
  1. class MimicLoss(nn.Module):
  2. def __init__(self, channels_s, channels_t):
  3. super(MimicLoss, self).__init__()
  4. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  5. self.mse = nn.MSELoss()
  6. def forward(self, y_s, y_t):
  7. """Forward computation.
  8. Args:
  9. y_s (list): The student model prediction with
  10. shape (N, C, H, W) in list.
  11. y_t (list): The teacher model prediction with
  12. shape (N, C, H, W) in list.
  13. Return:
  14. torch.Tensor: The calculated loss value of all stages.
  15. """
  16. assert len(y_s) == len(y_t)
  17. losses = []
  18. for idx, (s, t) in enumerate(zip(y_s, y_t)):
  19. assert s.shape == t.shape
  20. losses.append(self.mse(s, t))
  21. loss = sum(losses)
  22. return loss
  • CWDLoss
  1. class CWDLoss(nn.Module):
  2. """PyTorch version of `Channel-wise Distillation for Semantic Segmentation.
  3. <https://arxiv.org/abs/2011.13256>`_.
  4. """
  5. def __init__(self, channels_s, channels_t,tau=1.0):
  6. super(CWDLoss, self).__init__()
  7. self.tau = tau
  8. def forward(self, y_s, y_t):
  9. """Forward computation.
  10. Args:
  11. y_s (list): The student model prediction with
  12. shape (N, C, H, W) in list.
  13. y_t (list): The teacher model prediction with
  14. shape (N, C, H, W) in list.
  15. Return:
  16. torch.Tensor: The calculated loss value of all stages.
  17. """
  18. assert len(y_s) == len(y_t)
  19. losses = []
  20. for idx, (s, t) in enumerate(zip(y_s, y_t)):
  21. assert s.shape == t.shape
  22. N, C, H, W = s.shape
  23. # normalize in channel diemension
  24. softmax_pred_T = F.softmax(t.view(-1, W * H) / self.tau, dim=1) # [N*C, H*W]
  25. logsoftmax = torch.nn.LogSoftmax(dim=1)
  26. cost = torch.sum(
  27. softmax_pred_T * logsoftmax(t.view(-1, W * H) / self.tau) -
  28. softmax_pred_T * logsoftmax(s.view(-1, W * H) / self.tau)) * (self.tau ** 2)
  29. losses.append(cost / (C * N))
  30. loss = sum(losses)
  31. return loss
  • MGDLoss
  1. class MGDLoss(nn.Module):
  2. def __init__(self, channels_s, channels_t, alpha_mgd=0.00002, lambda_mgd=0.65):
  3. super(MGDLoss, self).__init__()
  4. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  5. self.alpha_mgd = alpha_mgd
  6. self.lambda_mgd = lambda_mgd
  7. self.generation = [
  8. nn.Sequential(
  9. nn.Conv2d(channel, channel, kernel_size=3, padding=1),
  10. nn.ReLU(inplace=True),
  11. nn.Conv2d(channel, channel, kernel_size=3, padding=1)).to(device) for channel in channels_t
  12. ]
  13. def forward(self, y_s, y_t):
  14. """Forward computation.
  15. Args:
  16. y_s (list): The student model prediction with
  17. shape (N, C, H, W) in list.
  18. y_t (list): The teacher model prediction with
  19. shape (N, C, H, W) in list.
  20. Return:
  21. torch.Tensor: The calculated loss value of all stages.
  22. """
  23. assert len(y_s) == len(y_t)
  24. losses = []
  25. for idx, (s, t) in enumerate(zip(y_s, y_t)):
  26. assert s.shape == t.shape
  27. losses.append(self.get_dis_loss(s, t, idx) * self.alpha_mgd)
  28. loss = sum(losses)
  29. return loss
  30. def get_dis_loss(self, preds_S, preds_T, idx):
  31. loss_mse = nn.MSELoss(reduction='sum')
  32. N, C, H, W = preds_T.shape
  33. device = preds_S.device
  34. mat = torch.rand((N, 1, H, W)).to(device)
  35. mat = torch.where(mat > 1 - self.lambda_mgd, 0, 1).to(device)
  36. masked_fea = torch.mul(preds_S, mat)
  37. new_fea = self.generation[idx](masked_fea)
  38. dis_loss = loss_mse(new_fea, preds_T) / N
  39. return dis_loss

以上三种feature-based的蒸馏损失,其中MimicLoss是最常见的特征蒸馏损失,而MGDCWD是当前的SOTA特征蒸馏方案。

4.yolov8 蒸馏代码实现

(1)蒸馏参数的设置

将以下代码放置在ultralytics\engine\trainer.py文件种142行位置处

  1. self.dfea_loss = 0 # feature distill loss
  2. self.dlogit_loss = 0 # logit distill loss
  3. self.loss_t = 0 # teacher model distill online loss
  4. self.distill_loss =None
  5. self.model_t = overrides.get("model_t",None)
  6. self.distill_feat_type = "cwd" # "cwd","mgd","mimic"
  7. self.distill_online = True # False or True
  8. self.logit_loss = True # False or True
  9. #self.distill_layers = [6,8,12,15,18,21] # distill layers
  10. self.distill_layers = [2,4,6,8,12,15,18,21]
  11. # self.distill_layers = [15,18,21]
  12. # self.model_t: 获取蒸馏训练的教师模型,如果在训练模型时,没传入model_t, 则不会进行蒸馏训练,只进行一般的模型训练
  13. # self.distill_feat_type: 设置feature - based蒸馏的类型,支持"cwd", "mgd", "mimic", 任意一种
  14. # self.distill_online: 设置是否使用在线蒸馏, 默认为False即离线蒸馏,你也可以设置为True
  15. # self.logit_loss: 设置是否使用logit蒸馏
  16. # self.distill_layers: 设置特征蒸馏的层数,可根据需要选择需要蒸馏的特征层

   (2)蒸馏损失代码实现

新建ultralytics/utils/distill_loss.py文件,并将以上有关蒸馏损失放置在其中(完整代码可关注博主加私信获取)

(3) 优化器optimizer修改

(完整代码可关注博主加私信获取。获取后直接替换trainer.py即可)代码在ultralytics/engine/trainer.pybuild_optimizer函数中

将如下的300行左右build_optimizer,按下图进行修改

build_optimizer函数内容如下

  1. def build_optimizer(self, model, model_t,distill_loss,distill_online=False,name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
  2. """
  3. Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
  4. weight decay, and number of iterations.
  5. Args:
  6. model (torch.nn.Module): The model for which to build an optimizer.
  7. name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
  8. based on the number of iterations. Default: 'auto'.
  9. lr (float, optional): The learning rate for the optimizer. Default: 0.001.
  10. momentum (float, optional): The momentum factor for the optimizer. Default: 0.9.
  11. decay (float, optional): The weight decay for the optimizer. Default: 1e-5.
  12. iterations (float, optional): The number of iterations, which determines the optimizer if
  13. name is 'auto'. Default: 1e5.
  14. Returns:
  15. (torch.optim.Optimizer): The constructed optimizer.
  16. """
  17. g = [], [], [] # optimizer parameter groups
  18. bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
  19. if name == 'auto':
  20. LOGGER.info(f"{colorstr('optimizer:')} 'optimizer=auto' found, "
  21. f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
  22. f"determining best 'optimizer', 'lr0' and 'momentum' automatically... ")
  23. nc = getattr(model, 'nc', 10) # number of classes
  24. lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places
  25. name, lr, momentum = ('SGD', 0.01, 0.9) if iterations > 10000 else ('AdamW', lr_fit, 0.9)
  26. self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam
  27. for v in model.modules():
  28. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias (no decay)
  29. g[2].append(v.bias)
  30. if isinstance(v, bn): # weight (no decay)
  31. g[1].append(v.weight)
  32. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
  33. g[0].append(v.weight)
  34. if model_t is not None and distill_online:
  35. for v in model_t.modules():
  36. # print(v)
  37. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias (no decay)
  38. g[2].append(v.bias)
  39. if isinstance(v, bn): # weight (no decay)
  40. g[1].append(v.weight)
  41. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
  42. g[0].append(v.weight)
  43. if model_t is not None and distill_loss is not None:
  44. for k, v in distill_loss.named_modules():
  45. # print(v)
  46. if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter): # bias (no decay)
  47. g[2].append(v.bias)
  48. if isinstance(v, bn) or 'bn' in k: # weight (no decay)
  49. g[1].append(v.weight)
  50. elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter): # weight (with decay)
  51. g[0].append(v.weight)
  52. if name in ('Adam', 'Adamax', 'AdamW', 'NAdam', 'RAdam'):
  53. optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
  54. elif name == 'RMSProp':
  55. optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
  56. elif name == 'SGD':
  57. optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
  58. else:
  59. raise NotImplementedError(
  60. f"Optimizer '{name}' not found in list of available optimizers "
  61. f'[Adam, AdamW, NAdam, RAdam, RMSProp, SGD, auto].'
  62. 'To request support for addition optimizers please visit https://github.com/ultralytics/ultralytics.')
  63. optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
  64. optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
  65. LOGGER.info(
  66. f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
  67. f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)')
  68. return optimizer

5.yolov8 蒸馏训练步骤

在项目中,教师模型model_t选择是yolov8l, 学生模型model_s,选择的是yolov8n

(1) 训练教师模型

  1. from ultralytics import YOLO
  2. data = r"ultralytics\datasets\coco128.yaml"
  3. model_t = YOLO(r'weights\yolov8l.pt')
  4. model_t.train(data=data, epochs=300, imgsz=640)

(2) 训练学生模型baseline

  1. from ultralytics import YOLO
  2. data = r"ultralytics\datasets\coco128.yaml"
  3. model_s = YOLO(r'weights\yolov8n.pt')
  4. model_s.train(data=data, epochs=300, imgsz=640)

(3) 蒸馏训练

将已经训练好的教师模型model_t的知识通过logit与feature-base知识蒸馏的方式迁移到学生模型model_s上,从而提升学生模型的性能。

  1. import torch
  2. from ultralytics import YOLO
  3. data = r"/home/xxx/project/public/yolov8-ultralytics-main/yolov8-ultralytics-main/ultralytics/cfg/datasets/coco128.yaml"
  4. model_t = YOLO(r'/home/xxx/project/public/yolov8-ultralytics-main/yolov8-ultralytics-main/weights/yolov8l.pt')
  5. model_t.model.model[-1].set_Distillation = True
  6. model_s = YOLO(r'/home/yuanwushui/project/public/yolov8-ultralytics-main/yolov8-ultralytics-main/yolov8n.pt')
  7. model_s.train(data=data, epochs=300, imgsz=640, model_t= model_t.model)

如果传入了model_t,则会进行蒸馏训练,否则为普通训练

注:feature-based蒸馏的类型设置(支持"cwd","mgd","mimic", 任意一种);设置是否使用在线蒸馏, (默认为False即离线蒸馏,你也可以设置为True);设置是否使用logit蒸馏;设置特征蒸馏的层数,(可根据需要选择需要蒸馏的特征层)。均在ultralytics/engine/trainer.py中的BaseTrainer类的初始化函数中__init__.py中进行设置。如下图

6.训练成功

注,以上全部代码均可关注博主,私信后获取,仅需在某些位置进行代码与文件的替换即可,基于你的代码改写后并不影响你的原始代码使用,是否开启蒸馏、开启什么样的蒸馏取决于你的参数设置

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/621372
推荐阅读
相关标签
  

闽ICP备14008679号