当前位置:   article > 正文

VarifocalLoss 源码实现解读及其在 Yolov8 中的应用_varifocal loss公式

varifocal loss公式

文章日志:

  • 2023-09-12:文章发布
  • 2024-02-06:修正了文中 BCE 公式的错误
  • 2024-02-07:在 yolov8 中添加 varifocal loss 作为 cls loss 的实现

YoloV8 内的 varifocal loss 实现

class VarifocalLoss(nn.Module):
    """Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367."""

    def __init__(self):
        """Initialize the VarifocalLoss class."""
        super().__init__()

    def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
        """Computes varfocal loss."""
        weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
        with torch.cuda.amp.autocast(enabled=False):
            loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
                    weight).mean(1).sum()
        return loss
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

最近老想在计算 yolov8 的分类 loss 的时候重新启用 varifocal loss,但是感觉有点坑。自己对 varifocal loss 的理解也不是很深,所以还是乖乖地理解透彻,再进行改造。

从 yolov8 的源码上来看,就是整个 varifocal loss 的就是通过 bce loss 乘以一个权重 weight 计算得到的。

理论的公式推导

但是我一直有一个疑问,代码和公式是怎么对应上的。

下面是 varifocal loss 的公式:

VFL ( p , q ) = { − q ⋅ [ q log ⁡ ( p ) + ( 1 − q ) l o g ( 1 − p ) ] q > 0 − α p γ log ⁡ ( 1 − p ) q = 0 \text{VFL}(p, q) = \left\{

q[qlog(p)+(1q)log(1p)]q>0αpγlog(1p)q=0
\right. VFL(p,q)={q[qlog(p)+(1q)log(1p)]αpγlog(1p)q>0q=0

  • p p p: 预测的 IACS 值
  • q q q:候选框与真实框的 IOU,对于负样本来说,也就是不负责真实框的候选框,其 q q q 值为 0 。

我们知道 bce loss 的公式如下: BCE ( p , q ) = { − [ q log ⁡ ( p ) + ( 1 − q ) log ⁡ ( 1 − p ) ] q > 0 − log ⁡ ( 1 − p ) q = 0 \text{BCE}(p, q) = \left\{

[qlog(p)+(1q)log(1p)]q>0log(1p)q=0
\right. BCE(p,q)={[qlog(p)+(1q)log(1p)]log(1p)q>0q=0

将 bce loss 的公式代入到 varifocal loss 中,就可以得到: VFL ( p , q ) = { q ⋅ BCE ( p , q ) q > 0 α p γ ⋅ BCE ( p , q ) q = 0 \text{VFL}(p, q) = \left\{

qBCE(p,q)q>0αpγBCE(p,q)q=0
\right. VFL(p,q)={qBCE(p,q)αpγBCE(p,q)q>0q=0

如果存在一个 label 的掩码矩阵,可以有效的标识 q > 0 q >0 q>0 或者 q = 0 q=0 q=0的情况,我们可以进一步的将公式表示为:
VFL ( p , q ) = α p γ ⋅ ( 1 − label ) ⋅ BCE ( p , q ) + q ⋅ label ⋅ BCE ( p , q ) \text{VFL}(p, q) = \alpha p^{\gamma}\cdot (1- \text{label}) \cdot \text{BCE}(p, q) + q \cdot \text{label} \cdot \text{BCE}(p, q) VFL(p,q)=αpγ(1label)BCE(p,q)+qlabelBCE(p,q)

我们要怎么选择这个 label 掩码矩阵? q q q 为 0,意味着,候选框和真实框是没有交集的,也就是所谓不负责,这么看感觉和类别没啥关系。感觉只要求得一个 gt box 和 对应的 anchor box 之间的关系就行了。也就是知道哪个 anchor box 是正样本就行了。

代码的公式实现

回到代码中 pred_scores 就是 p p pgt_score 就是 q q q,这样的话,代码中的权重计算可以表示为:
w = α p γ ⋅ ( 1 − label ) + q ⋅ label w = \alpha p^{\gamma}\cdot (1-\text{label}) + q\cdot \text{label} w=αpγ(1label)+qlabel
也就是说代码中的 VFL 公式可以表示为:
VFL ( p , q ) = w ⋅ BCE ( p , q ) = α p γ ⋅ ( 1 − label ) ⋅ BCE ( p , q ) + q ⋅ label ⋅ BCE ( p , q ) \text{VFL}(p, q) = w\cdot \text{BCE}(p, q) = \alpha p^{\gamma}\cdot (1-\text{label}) \cdot \text{BCE}(p, q) + q \cdot \text{label} \cdot \text{BCE}(p, q) VFL(p,q)=wBCE(p,q)=αpγ(1label)BCE(p,q)+qlabelBCE(p,q)

可以看出来,我推导出来的公式和 yolov8 源码中代码的实现好像差了一个负号,我也不知道问题出来哪里,是代码的实现错了吗?但是官方好像之前确实用过 varifocal loss 来进行分类 loss 的计算。(注:经过评论区抓 Bug, 发现少了括号,所以符号错了,现在已经改正了,公式和代码中的是完全一样了。)


mmdet 内的 varifocal loss 实现

代码地址:mmdetection 的 varifocal loss 实现

def varifocal_loss(pred: Tensor,
                   target: Tensor,
                   weight: Optional[Tensor] = None,
                   alpha: float = 0.75,
                   gamma: float = 2.0,
                   iou_weighted: bool = True,
                   reduction: str = 'mean',
                   avg_factor: Optional[int] = None) -> Tensor:

    # pred and target should be of the same size
    assert pred.size() == target.size()
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    if iou_weighted:
        focal_weight = target * (target > 0.0).float() + \
            alpha * (pred_sigmoid - target).abs().pow(gamma) * \
            (target <= 0.0).float()
    else:
        focal_weight = (target > 0.0).float() + \
            alpha * (pred_sigmoid - target).abs().pow(gamma) * \
            (target <= 0.0).float()
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

代码的公式实现

看了一下上面的代码,因为是单一分类的 loss,考虑 iou_weight 的话,focal weight 的计算公式如下:
w = q ⋅ label + α ⋅ ( 1 − label ) ∣ p − q ∣ γ w = q \cdot \text{label} + \alpha \cdot (1-\text{label}) |p-q|^{\gamma} w=qlabel+α(1label)pqγ
最后 varifocal loss 的计算公式如下:
VFL ( p , q ) = α ⋅ ∣ p − q ∣ γ ⋅ ( 1 − label ) ⋅ BCE ( p , q ) + q ⋅ label ⋅ BCE ( p , q ) \text{VFL}(p,q) =\alpha \cdot |p-q|^{\gamma} \cdot (1-\text{label}) \cdot \text{BCE}(p,q) + q \cdot \text{label} \cdot \text{BCE}(p,q ) VFL(p,q)=αpqγ(1label)BCE(p,q)+qlabelBCE(p,q)

先前没有搞懂为啥有个 iou_weight,感觉和我在 yolov8 源码中看到的不一样,但是粗略看了一下论文,发现原来 varifocal loss 计算的 p p p 就是 IACS (iou aware classification score),就是 iou 与 cls score 组合起来的一个参数。如果说之前的 cls score 是使用 one-hot 来表示的话,现在就是使用 iou 来代替。

所以在 mmedetection 的源码中,如果不是用 IACS 来计算 loss 的话,就会将 iou_weight 设置成 False,这样就不需要考虑 target 中对应的具体 iou 值了。


想到这里,我在想 yolov8 是否计算了 IACS 值。

Varifocal loss 在 Yolov8 中的应用

问题具象化

其实在官方的 issue 列表里面有一个 issue 就提到了这个问题:Using VFL loss in Yolov8 but the dimensions of the anchor class and gt class not match.

因为 Yolov8 的源码里是有 varifocal loss 的实现的,但是如果只是简单的将 bce loss 替换成 varifocal loss 的话,就会出现维度不匹配的问题,也就是 issue 标题的问题。所以我在

loss.py 这个文件里,我们可以找到 Varifocal
loss 和 focal loss 的实现,如果是用 yolov8 进行检测任务的话, 也可以看到 v8DetectionLoss,也就是,如果想要应用 Varifocal loss,在 v8DetectionLoss 进行相应的修改就好了。

Yolov8 的损失函数,分开来看由 bbox loss 和 cls loss 组成,bbox loss 又是通过 ciou loss 和 dfl loss 获得的,但是在这里我们就不展开了,官方的 cls loss 是使用 bce loss 实现的,但是上面有一行 varifocal loss 是如下注释掉的:

    # Cls loss
    # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
    loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE
  • 1
  • 2
  • 3

但是如果你直接关掉注释,使用它的这句代码,就会出现 issue 中提到的维度不匹配问题,这是不能直接使用的,需要经过一定的数据处理才能将数据的维度对齐
也就是我们需要写一段维度对齐的代码来保证 varifocal loss 函数的输入是正确的。

target_labels 是什么,需要转换成什么形式?

从前面的公式解释中,我们就可以看到,target_labels 在 varifocal loss 里就是 label

glenn-jocher 对于 label 的解释如下:
在这里插入图片描述
大体来说,就是一个 bounding box 和 ground truth box 的关系矩阵,只要两者是有联系的,在 varifocal loss 里,通常要编码成 one-hot 矩阵。

issue 1448 也有一位大佬提到虽然已经不维护 varifocal loss 了,但是还记得 target_labels 是要编码成 one-hot 形式然后再传进去。
在这里插入图片描述
初始的 target_labels 是通过 self.assigner 获得的,将第一个参数显式的写出来就好了。

target_labels, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

但是这个初始的 target_labels 的维度是 [batchsize, anchor],我的代码里是 [16, 8400], 里面的数字是类别信息。

tensor([[11, 11, 11,  ..., 11, 11, 11],
        [11, 11, 11,  ..., 11, 11, 11],
        [ 0,  0,  0,  ...,  0,  0,  0],
        ...,
        [ 9,  9,  9,  ...,  9,  9,  9],
        [ 8,  8,  8,  ...,  8,  8,  8],
        [ 2,  2,  2,  ...,  2,  2,  2]], device='cuda:0')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

所以通过下面的转换,就可以将 target_labels 转换成 one-hot 的编码,进而计算 varifocal loss。

target_labels = target_labels.unsqueeze(-1).expand(-1, -1, self.nc)  # self.nc: class num
one_hot = torch.zeros(target_labels.size(), device=self.device)
one_hot.scatter_(-1, target_labels, 1)
target_labels = one_hot
  • 1
  • 2
  • 3
  • 4

在这里插入图片描述

完整代码

class v8DetectionLoss:
    def __init__(self, model):  # model must be de-paralleled
        ...
        self.vfl = VarifocalLoss()
        ...
    
    def __call__(self, preds, batch):
        ...
        target_labels, target_bboxes, target_scores, fg_mask, _ = self.assigner(
                    pred_scores.detach().sigmoid(),
                    (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
                    anchor_points * stride_tensor,
                    gt_labels,
                    gt_bboxes,
                    mask_gt,
        )
        
        target_labels = target_labels.unsqueeze(-1).expand(-1, -1, self.nc)  # self.nc: class num
        one_hot = torch.zeros(target_labels.size(), device=self.device)
        one_hot.scatter_(-1, target_labels, 1)
        target_labels = one_hot
        
        loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        ...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

但是我个人感觉效果不太好,不知道是否哪里搞错了。不过 glenn-jocher 说 one-hot 编码 label 部分是正确的,效果不如 bce loss 可能和数据集还别的因素造成的。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/402640
推荐阅读
相关标签
  

闽ICP备14008679号