当前位置:   article > 正文

【目标检测实验系列】YOLOv5改进实验:结合VariFocal Loss损失函数,减少小目标漏检问题,高效提升模型检测的召回率(超详细改进代码流程)

varifocal loss

1. 文章主要内容

       本篇博客主要涉及两个主体内容。第一个:简单介绍VariFocal Loss的原理。第二个:基于YOLOv5 6.0版本,将损失函数替换为VariFocal Loss的详细调试步骤(通读本篇博客大概需要5分钟左右的时间)。

2. VariFocal Loss损失函数(原理:简单介绍,可自行详细研究)

2.1 VariFocal Loss损失函数

       VariFocal Loss是从Focal Loss而来,所以我们要首先了解Focal Loss。Focal Loss提出来要解决的问题是训练数据中,正负样本不均衡的问题。 何为正负样本不均衡?比如说,我们训练的图片样本,尤其是包含很多小目标的图片样本,其实要检测的目标(也就是我们说的正样本)只占图片区域的少部分(综合来看),大部分的区域则为背景区域(也就是我们说的负样本);这就会导致训练数据负样本占多,而正样本相对来说占少数,模型的训练效果会变差,Focal Loss给难分、易分样本加上权重因子,提升难分样本的权重,降低易分样本的权重,从而控制正负样本均衡的问题,其中背景类一般为易分样本,而目标类为难分样本。 同时,Focal Loss适合检测密集型目标的图片样本,这个对小尺寸、拥挤、遮挡等特点的数据集会有不错的效果。

       VariFocal Loss是在Focal Loss的基础上提出的,因为Focal Loss对正负样本的处理是均衡的,而varifocal loss仅减少了负样本的损失贡献,而不以同样的方式降低正样本的权重。具体的公式、原理还请查看原论文或者网上的文章解析。
       原论文地址Focal Loss 论文VariFocal Loss 论文

2.2 博主数据集实验效果

       博主所训练的数据集特点:小尺寸目标居多,密集且目标尺寸不一,实验数据如下所示:
       原YOLOv5s框架实验数据:P(查准率):0.935、R(召回率):0.927、mAP@0.5(平均检测精度):0.942
       YOLOv5s+VFLoss:P(查准率):0.974、R(召回率):0.95、mAP@0.5(平均检测精度):0.962
       由实验数据对比,YOLOv5s+VariFocal Loss能够极大的提升R,同时P、mAP也有所提升,提升R指标,就表明能够检测更多的目标,可以减少模型漏检的问题,且FLOPs和原YOLOv5模型一样。

3. 代码详细改进流程(重要)

3.1 新建varifocalLoss.py文件

       ( 注意:博主使用的是Pycharm集成开发工具)首先在data->tricks目录下新建一个叫varifocal loss的py文件( 注意:tricks文件夹是自定义创建的,没有自己创建一个即可),将如下代码复制到varifocal loss的py文件中:

import torch
from torch import nn

class VFLoss(nn.Module):
    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
        super(VFLoss, self).__init__()
        # 传递 nn.BCEWithLogitsLoss() 损失函数  must be nn.BCEWithLogitsLoss()
        self.loss_fcn = loss_fcn  #
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = loss_fcn.reduction
        self.loss_fcn.reduction = 'none'  # required to apply VFL to each element

    def forward(self, pred, true):

        loss = self.loss_fcn(pred, true)

        pred_prob = torch.sigmoid(pred)  # prob from logits

        focal_weight = true * (true > 0.0).float() + self.alpha * (pred_prob - true).abs().pow(self.gamma) * (
                    true <= 0.0).float()
        loss *= focal_weight

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

       另外项目的data-tricks目录结构如下所示:( 之所以要新建文件,是为了方便,清晰的分辨哪些创新点,而不是一股脑都放在一个文件中
在这里插入图片描述

3.2 修改hyp.scratch-low.yaml文件

       在data/hyps文件夹下面,找到hyp.scratch-low.yaml(注意:hyps下面有多个yaml文件,博主这里修改hyp.scratch-low.yaml,是因为在train.py文件中调用了此文件,如果你调用了是另外的yaml文件,则需要在你调用的那个yaml文件中做修改),修改fl_gamma的这一行值,原本是0.0,这里修改为1.5即可,如下图所示:
在这里插入图片描述

3.3 修改loss.py文件

       loss.py文件在utils文件夹下面,打开并定位到如下代码部分(大概是111行),修改成如下的代码所示:

BCEcls, BCEobj = VFLoss(BCEcls, g), VFLoss(BCEobj, g)
  • 1

在这里插入图片描述
       这里用VFLoss替换了YOLOv5的分类、置信度损失,回归框损失没有替换。同时,我们注意到,g = h[‘fl_gamma’] 这行代码就是hyp.scratch-low.yaml的fl_gamma值,设置为1.5(1.5为经验值)就可以进入到if g > 0 的条件当中,另外需要导入VFLoss的引用,不然会报错,只需要在文件首部添加from data.tricks.varifocalLoss import VFLoss即可。

4. 本篇小结

       本篇博客主要介绍了VF+YOLOv5的修改详细流程,助力模型对小目标、多尺寸、密集目标的检测涨点。另外,在修改过程中,要是有任何问题,评论区交流;如果博客对您有帮助,请帮忙点个赞,收藏一下;后续会持续更新本人实验当中觉得有用的点子,如果很感兴趣的话,可以关注一下,谢谢大家啦!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/329502
推荐阅读
相关标签
  

闽ICP备14008679号