赞
踩
class VarifocalLoss(nn.Module):
"""Varifocal loss by Zhang et al. https://arxiv.org/abs/2008.13367."""
def __init__(self):
"""Initialize the VarifocalLoss class."""
super().__init__()
def forward(self, pred_score, gt_score, label, alpha=0.75, gamma=2.0):
"""Computes varfocal loss."""
weight = alpha * pred_score.sigmoid().pow(gamma) * (1 - label) + gt_score * label
with torch.cuda.amp.autocast(enabled=False):
loss = (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction='none') *
weight).mean(1).sum()
return loss
最近老想在计算 yolov8 的分类 loss 的时候重新启用 varifocal loss,但是感觉有点坑。自己对 varifocal loss 的理解也不是很深,所以还是乖乖地理解透彻,再进行改造。
从 yolov8 的源码上来看,就是整个 varifocal loss 的就是通过 bce loss 乘以一个权重 weight 计算得到的。
但是我一直有一个疑问,代码和公式是怎么对应上的。
下面是 varifocal loss 的公式:
VFL
(
p
,
q
)
=
{
−
q
⋅
[
q
log
(
p
)
+
(
1
−
q
)
l
o
g
(
1
−
p
)
]
q
>
0
−
α
p
γ
log
(
1
−
p
)
q
=
0
\text{VFL}(p, q) = \left\{
我们知道 bce loss 的公式如下:
BCE
(
p
,
q
)
=
{
−
[
q
log
(
p
)
+
(
1
−
q
)
log
(
1
−
p
)
]
q
>
0
−
log
(
1
−
p
)
q
=
0
\text{BCE}(p, q) = \left\{
将 bce loss 的公式代入到 varifocal loss 中,就可以得到:
VFL
(
p
,
q
)
=
{
q
⋅
BCE
(
p
,
q
)
q
>
0
α
p
γ
⋅
BCE
(
p
,
q
)
q
=
0
\text{VFL}(p, q) = \left\{
如果存在一个 label 的掩码矩阵,可以有效的标识
q
>
0
q >0
q>0 或者
q
=
0
q=0
q=0的情况,我们可以进一步的将公式表示为:
VFL
(
p
,
q
)
=
α
p
γ
⋅
(
1
−
label
)
⋅
BCE
(
p
,
q
)
+
q
⋅
label
⋅
BCE
(
p
,
q
)
\text{VFL}(p, q) = \alpha p^{\gamma}\cdot (1- \text{label}) \cdot \text{BCE}(p, q) + q \cdot \text{label} \cdot \text{BCE}(p, q)
VFL(p,q)=αpγ⋅(1−label)⋅BCE(p,q)+q⋅label⋅BCE(p,q)
我们要怎么选择这个 label 掩码矩阵? q q q 为 0,意味着,候选框和真实框是没有交集的,也就是所谓不负责,这么看感觉和类别没啥关系。感觉只要求得一个 gt box 和 对应的 anchor box 之间的关系就行了。也就是知道哪个 anchor box 是正样本就行了。
回到代码中 pred_scores
就是
p
p
p, gt_score
就是
q
q
q,这样的话,代码中的权重计算可以表示为:
w
=
α
p
γ
⋅
(
1
−
label
)
+
q
⋅
label
w = \alpha p^{\gamma}\cdot (1-\text{label}) + q\cdot \text{label}
w=αpγ⋅(1−label)+q⋅label
也就是说代码中的 VFL 公式可以表示为:
VFL
(
p
,
q
)
=
w
⋅
BCE
(
p
,
q
)
=
α
p
γ
⋅
(
1
−
label
)
⋅
BCE
(
p
,
q
)
+
q
⋅
label
⋅
BCE
(
p
,
q
)
\text{VFL}(p, q) = w\cdot \text{BCE}(p, q) = \alpha p^{\gamma}\cdot (1-\text{label}) \cdot \text{BCE}(p, q) + q \cdot \text{label} \cdot \text{BCE}(p, q)
VFL(p,q)=w⋅BCE(p,q)=αpγ⋅(1−label)⋅BCE(p,q)+q⋅label⋅BCE(p,q)
可以看出来,我推导出来的公式和 yolov8 源码中代码的实现好像差了一个负号,我也不知道问题出来哪里,是代码的实现错了吗?但是官方好像之前确实用过 varifocal loss 来进行分类 loss 的计算。(注:经过评论区抓 Bug, 发现少了括号,所以符号错了,现在已经改正了,公式和代码中的是完全一样了。)
代码地址:mmdetection 的 varifocal loss 实现
def varifocal_loss(pred: Tensor, target: Tensor, weight: Optional[Tensor] = None, alpha: float = 0.75, gamma: float = 2.0, iou_weighted: bool = True, reduction: str = 'mean', avg_factor: Optional[int] = None) -> Tensor: # pred and target should be of the same size assert pred.size() == target.size() pred_sigmoid = pred.sigmoid() target = target.type_as(pred) if iou_weighted: focal_weight = target * (target > 0.0).float() + \ alpha * (pred_sigmoid - target).abs().pow(gamma) * \ (target <= 0.0).float() else: focal_weight = (target > 0.0).float() + \ alpha * (pred_sigmoid - target).abs().pow(gamma) * \ (target <= 0.0).float() loss = F.binary_cross_entropy_with_logits( pred, target, reduction='none') * focal_weight loss = weight_reduce_loss(loss, weight, reduction, avg_factor) return loss
看了一下上面的代码,因为是单一分类的 loss,考虑 iou_weight 的话,focal weight 的计算公式如下:
w
=
q
⋅
label
+
α
⋅
(
1
−
label
)
∣
p
−
q
∣
γ
w = q \cdot \text{label} + \alpha \cdot (1-\text{label}) |p-q|^{\gamma}
w=q⋅label+α⋅(1−label)∣p−q∣γ
最后 varifocal loss 的计算公式如下:
VFL
(
p
,
q
)
=
α
⋅
∣
p
−
q
∣
γ
⋅
(
1
−
label
)
⋅
BCE
(
p
,
q
)
+
q
⋅
label
⋅
BCE
(
p
,
q
)
\text{VFL}(p,q) =\alpha \cdot |p-q|^{\gamma} \cdot (1-\text{label}) \cdot \text{BCE}(p,q) + q \cdot \text{label} \cdot \text{BCE}(p,q )
VFL(p,q)=α⋅∣p−q∣γ⋅(1−label)⋅BCE(p,q)+q⋅label⋅BCE(p,q)
先前没有搞懂为啥有个 iou_weight,感觉和我在 yolov8 源码中看到的不一样,但是粗略看了一下论文,发现原来 varifocal loss 计算的 p p p 就是 IACS (iou aware classification score),就是 iou 与 cls score 组合起来的一个参数。如果说之前的 cls score 是使用 one-hot 来表示的话,现在就是使用 iou 来代替。
所以在 mmedetection 的源码中,如果不是用 IACS 来计算 loss 的话,就会将 iou_weight 设置成 False,这样就不需要考虑 target
中对应的具体 iou 值了。
想到这里,我在想 yolov8 是否计算了 IACS 值。
其实在官方的 issue 列表里面有一个 issue 就提到了这个问题:Using VFL loss in Yolov8 but the dimensions of the anchor class and gt class not match.
因为 Yolov8 的源码里是有 varifocal loss 的实现的,但是如果只是简单的将 bce loss 替换成 varifocal loss 的话,就会出现维度不匹配的问题,也就是 issue 标题的问题。所以我在
在 loss.py 这个文件里,我们可以找到 Varifocal
loss 和 focal loss 的实现,如果是用 yolov8 进行检测任务的话, 也可以看到 v8DetectionLoss,也就是,如果想要应用 Varifocal loss,在 v8DetectionLoss 进行相应的修改就好了。
Yolov8 的损失函数,分开来看由 bbox loss 和 cls loss 组成,bbox loss 又是通过 ciou loss 和 dfl loss 获得的,但是在这里我们就不展开了,官方的 cls loss 是使用 bce loss 实现的,但是上面有一行 varifocal loss 是如下注释掉的:
# Cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
但是如果你直接关掉注释,使用它的这句代码,就会出现 issue 中提到的维度不匹配问题,这是不能直接使用的,需要经过一定的数据处理才能将数据的维度对齐
也就是我们需要写一段维度对齐的代码来保证 varifocal loss 函数的输入是正确的。
从前面的公式解释中,我们就可以看到,target_labels
在 varifocal loss 里就是 label
。
glenn-jocher 对于 label 的解释如下:
大体来说,就是一个 bounding box 和 ground truth box 的关系矩阵,只要两者是有联系的,在 varifocal loss 里,通常要编码成 one-hot 矩阵。
在 issue 1448 也有一位大佬提到虽然已经不维护 varifocal loss 了,但是还记得 target_labels
是要编码成 one-hot 形式然后再传进去。
初始的 target_labels
是通过 self.assigner
获得的,将第一个参数显式的写出来就好了。
target_labels, target_bboxes, target_scores, fg_mask, _ = self.assigner(
pred_scores.detach().sigmoid(),
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
但是这个初始的 target_labels
的维度是 [batchsize, anchor]
,我的代码里是 [16, 8400], 里面的数字是类别信息。
tensor([[11, 11, 11, ..., 11, 11, 11],
[11, 11, 11, ..., 11, 11, 11],
[ 0, 0, 0, ..., 0, 0, 0],
...,
[ 9, 9, 9, ..., 9, 9, 9],
[ 8, 8, 8, ..., 8, 8, 8],
[ 2, 2, 2, ..., 2, 2, 2]], device='cuda:0')
所以通过下面的转换,就可以将 target_labels
转换成 one-hot 的编码,进而计算 varifocal loss。
target_labels = target_labels.unsqueeze(-1).expand(-1, -1, self.nc) # self.nc: class num
one_hot = torch.zeros(target_labels.size(), device=self.device)
one_hot.scatter_(-1, target_labels, 1)
target_labels = one_hot
class v8DetectionLoss: def __init__(self, model): # model must be de-paralleled ... self.vfl = VarifocalLoss() ... def __call__(self, preds, batch): ... target_labels, target_bboxes, target_scores, fg_mask, _ = self.assigner( pred_scores.detach().sigmoid(), (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), anchor_points * stride_tensor, gt_labels, gt_bboxes, mask_gt, ) target_labels = target_labels.unsqueeze(-1).expand(-1, -1, self.nc) # self.nc: class num one_hot = torch.zeros(target_labels.size(), device=self.device) one_hot.scatter_(-1, target_labels, 1) target_labels = one_hot loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way ...
但是我个人感觉效果不太好,不知道是否哪里搞错了。不过 glenn-jocher 说 one-hot 编码 label 部分是正确的,效果不如 bce loss 可能和数据集还别的因素造成的。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。