当前位置:   article > 正文

YOLOv8知识蒸馏 | 目标检测的无损涨点

yolov8知识蒸馏

YOLOv8知识蒸馏

本文采用知识蒸馏的技术来训练模型。知识蒸馏是一种将复杂的模型知识传递给一个较简单模型的方法,从而提高简单模型的性能。而在知识蒸馏的过程中,主要有两种方式来传递知识:软标签和注意力图。软标签是一种将目标位置和类别信息以概率分布的形式传递给学生网络的方法,可以提供更丰富的信息。而注意力图则是一种将教师网络对目标的关注程度传递给学生网络的方式,可以帮助学生网络更好地学习目标的特征。
事前说明,知识蒸馏可用于自己魔改后的网络结构,但是需要保证教师网络比学生网络更大,且效果更好。本文将以私人睡岗数据集为例,选用yolov8n网络为学生网络,选用yosv8s网络为教师网络。

step 1.训练一版教师网络

step 2.损失函数修改

在上面代码下插入以下代码

#蒸馏关
#-------------------------------
# self.t_weights = False
# executed = True
#------------------------
# 蒸馏开
#-----------------------------------
executed = False
#------------------------------
if not executed:
    # ------------------蒸馏代码改进-------------------------------
    self.t_weights = "weights/best.pt"
    if self.t_weights:
        from ultralytics.nn.tasks import attempt_load_one_weight
        t_model = attempt_load_one_weight(self.t_weights, device=self.device)
        self.t_model = t_model[0]
        self.t_model.float()
        self.t_model.train()
    # -----------------------------------------------------------
    executed = True
if self.t_weights:
    with torch.no_grad():
        t_pred = self.t_model(batch['img'].to(torch.float32))
    t_loss=self.compute_distillation_output_loss(
        preds, t_pred, self.t_model, dist_loss="l2", T=20)
    return (loss.sum()+t_loss) * batch_size, loss.detach()
return loss.sum() * batch_size, loss.detach()  # loss(box, cls, dfl)

step 3.准备知识蒸馏训练

将蒸馏下面的代码打开、注释掉关闭的代码,并将以及训练好的教师网络放入self.t_weights

step 4 开始知识蒸馏训练

将model换成要蒸馏的模型结构,直接开训

step 5 训练结果

factor

P

R

Map

inference

parameters

GFLOPs

YOLOv8n(学生模型)

———

0.985

0.993

0.992

7.4ms

3006038

8.1

YOLOv8n(学生模型)

加蒸馏

0.995

0.999

0.995

7.4ms

3006038

28.1

YOLOv8s(教师模型)

0.994

0.996

0.995

 9ms

10892606

28.1

结果显示,yolov8n学生模型在教师模型软标签知识学习下效果最优

点个关注收藏,私信我发源码哦

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/416753
推荐阅读
相关标签
  

闽ICP备14008679号