当前位置:   article > 正文

优化改进YOLOv5算法之Wise-IOU损失函数_wise iou

wise iou

1 Wise-IOU损失函数

边界框回归(BBR)的损失函数对于目标检测至关重要。它的良好定义将为模型带来显著的性能改进。大多数现有的工作假设训练数据中的样本是高质量的,并侧重于增强BBR损失的拟合能力。如果盲目地加强低质量样本的BBR,这将危及本地化性能。Focal EIoU v1被提出来解决这个问题,但由于其静态聚焦机制(FM),非单调FM的潜力没有被充分利用。基于这一思想,作者提出了一种基于IoU的损失,该损失具有动态非单调FM,名为Wise IoU(WIoU)。当WIoU应用于最先进的实时检测器YOLOv7时,MS-COCO数据集上的AP75从53.03%提高到54.50%。

现有工作记锚框为 ,目标框为 

 

 IoU 用于度量目标检测任务中预测框与真实框的重叠程度,定义为:

同时,IoU 有一个致命的缺陷,可以在下面公式中观察到。当边界框之间没有重叠时 , 反向传播的梯度消失。这导致重叠区域的宽度  在训练时无法更新

现有的工作考虑了许多与包围盒相关的几何因素并构造了惩罚项  来解决这个问题,现有的边界框损失都是基于加法的损失,并遵循以下范式:

Distance-IoU
DIoU 将惩罚项定义为中心点连接的归一化长度:

同时为最小包围框的尺寸 提供了负梯度,这将使得 增大而阻碍预测框与目标框重叠:  

但不可否认的是,距离度量的确是一个极其有效的解决方案,成为高效边界框损失的必要因子。EIoU 在此基础上加大了对距离度量的惩罚力度,其惩罚项定义为:

Complete-IoU
的基础上,CIoU 增加了对纵横比一致性的考虑:

其中的描述了纵横比一致性:

其中反向传播的梯度满足 ,也就是不可能为预测框的宽高提供同号的梯度。在前文对 DIoU 的分析中可知 DIoU 会产生负梯度,当这个负梯度与正好抵消时,会导致预测框无法优化。而 CIoU 对纵横比一致性的考虑将打破这种僵局。

 Scylla-IoU
Zhora Gevorgyan 证明了中心对齐的边界框会具有更快的收敛速度,以 angle cost、distance cost、shape cost 构造了 SIoU。其中 angle cost 描述了边界框中心连线与 x-y 轴的最小夹角:

distance cost 描述了两边界框的中心点在x轴和y轴上的归一化距离,其惩罚力度与 angle cost 正相关。distance cost 被定义为:

shape cost 描述了两边界框的形状差异,当两边界框的尺寸不一致时不为 0。shape cost 被定义为:

类似,它们都由 distance cost 和 shape cost 组成:  

Wise IoU

Wise-IoU v1
因为训练数据中难以避免地包含低质量示例,所以如距离、纵横比之类的几何度量都会加剧对低质量示例的惩罚从而使模型的泛化性能下降。好的损失函数应该在锚框与目标框较好地重合时削弱几何度量的惩罚,不过多地干预训练将使模型有更好的泛化能力。在此基础上,我们根据距离度量构建了距离注意力,得到了具有两层注意力机制的 WIoU v1:

  • ,这将显著放大普通质量锚框的 

,这将显著降低高质量锚框的,并在锚框与目标框重合较好的情况下显著降低其对中心点距离的关注

为了防止产生阻碍收敛的梯度,将从计算图 (上标 * 表示此操作) 中分离。因为它有效地消除了阻碍收敛的因素,所以我们没有引入新的度量指标,如纵横比。 

Wise-IoU v2
Focal Loss 设计了一种针对交叉熵的单调聚焦机制,有效降低了简单示例对损失值的贡献。这使得模型能够聚焦于困难示例,获得分类性能的提升。类似地,我们构造了的单调聚焦系数:

 在模型训练过程中,梯度增益随着的减小而减小,导致训练后期收敛速度较慢。因此,引入的均值作为归一化因子:

 

其中的为动量为m的滑动平均值,动态更新归一化因子使梯度增益整体保持在较高水平,解决了训练后期收敛速度慢的问题  

Wise-IoU v3
定义离群度以描述锚框的质量,其定义为:

离群度小意味着锚框质量高,我们为其分配一个小的梯度增益,以便使边界框回归聚焦到普通质量的锚框上。对离群度较大的锚框分配较小的梯度增益,将有效防止低质量示例产生较大的有害梯度。我们利用  构造了一个非单调聚焦系数并将其应用于 WIoU v1:

其中,当时, 使得。当锚框的离群程度满足(为定值)时,锚框将获得最高的梯度增益。由于是动态的,锚框的质量划分标准也是动态的,这使得 WIoU v3 在每一时刻都能做出最符合当前情况的梯度增益分配策略      

为了防止低质量锚框在训练初期落后,我们初始化使得的锚框具有最高的梯度增益。为了在训练的早期阶段保持这样的策略,需要设置一个小的动量来延迟接近真实值  的时间。对于 batch size 为的训练,我们建议将动量设置为:

 这种设置使得经过t轮训练后有。在训练的中后期,WIoU v3 将小梯度增益分配给低质量的锚框以减少有害梯度。同时 WIoU v3 会聚焦于普通质量的锚框,提高模型的定位性能 

2 YOLOv5中添加Wise-IOU损失函数

yolov5-6.1版本中的iou损失函数是在utils/metrics.py文件定义的,在该文件添加以下关于Wise-IOU函数的代码,如下所示

  1. import numpy as np
  2. import torch, math
  3. class WIoU_Scale:
  4. ''' monotonous: {
  5. None: origin v1
  6. True: monotonic FM v2
  7. False: non-monotonic FM v3
  8. }
  9. momentum: The momentum of running mean'''
  10. iou_mean = 1.
  11. monotonous = False
  12. _momentum = 1 - 0.5 ** (1 / 7000)
  13. _is_train = True
  14. def __init__(self, iou):
  15. self.iou = iou
  16. self._update(self)
  17. @classmethod
  18. def _update(cls, self):
  19. if cls._is_train: cls.iou_mean = (1 - cls._momentum) * cls.iou_mean + \
  20. cls._momentum * self.iou.detach().mean().item()
  21. @classmethod
  22. def _scaled_loss(cls, self, gamma=1.9, delta=3):
  23. if isinstance(self.monotonous, bool):
  24. if self.monotonous:
  25. return (self.iou.detach() / self.iou_mean).sqrt()
  26. else:
  27. beta = self.iou.detach() / self.iou_mean
  28. alpha = delta * torch.pow(gamma, beta - delta)
  29. return beta / alpha
  30. return 1
  31. def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, SIoU=False, EIoU=False, WIoU=False, Focal=False, alpha=1, gamma=0.5, scale=False, eps=1e-7):
  32. # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
  33. # Get the coordinates of bounding boxes
  34. if xywh: # transform from xywh to xyxy
  35. (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
  36. w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
  37. b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
  38. b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
  39. else: # x1, y1, x2, y2 = box1
  40. b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
  41. b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
  42. w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
  43. w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)
  44. # Intersection area
  45. inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
  46. (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)
  47. # Union Area
  48. union = w1 * h1 + w2 * h2 - inter + eps
  49. if scale:
  50. self = WIoU_Scale(1 - (inter / union))
  51. # IoU
  52. # iou = inter / union # ori iou
  53. iou = torch.pow(inter/(union + eps), alpha) # alpha iou
  54. if CIoU or DIoU or GIoU or EIoU or SIoU or WIoU:
  55. cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width
  56. ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height
  57. if CIoU or DIoU or EIoU or SIoU or WIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
  58. c2 = (cw ** 2 + ch ** 2) ** alpha + eps # convex diagonal squared
  59. rho2 = (((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4) ** alpha # center dist ** 2
  60. if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
  61. v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
  62. with torch.no_grad():
  63. alpha_ciou = v / (v - iou + (1 + eps))
  64. if Focal:
  65. return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)), torch.pow(inter/(union + eps), gamma) # Focal_CIoU
  66. else:
  67. return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)) # CIoU
  68. elif EIoU:
  69. rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2
  70. rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2
  71. cw2 = torch.pow(cw ** 2 + eps, alpha)
  72. ch2 = torch.pow(ch ** 2 + eps, alpha)
  73. if Focal:
  74. return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2), torch.pow(inter/(union + eps), gamma) # Focal_EIou
  75. else:
  76. return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2) # EIou
  77. elif SIoU:
  78. # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
  79. s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps
  80. s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps
  81. sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
  82. sin_alpha_1 = torch.abs(s_cw) / sigma
  83. sin_alpha_2 = torch.abs(s_ch) / sigma
  84. threshold = pow(2, 0.5) / 2
  85. sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
  86. angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
  87. rho_x = (s_cw / cw) ** 2
  88. rho_y = (s_ch / ch) ** 2
  89. gamma = angle_cost - 2
  90. distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
  91. omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
  92. omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
  93. shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
  94. if Focal:
  95. return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha), torch.pow(inter/(union + eps), gamma) # Focal_SIou
  96. else:
  97. return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha) # SIou
  98. elif WIoU:
  99. if Focal:
  100. raise RuntimeError("WIoU do not support Focal.")
  101. elif scale:
  102. return getattr(WIoU_Scale, '_scaled_loss')(self), (1 - iou) * torch.exp((rho2 / c2)), iou # WIoU https://arxiv.org/abs/2301.10051
  103. else:
  104. return iou, torch.exp((rho2 / c2)) # WIoU v1
  105. if Focal:
  106. return iou - rho2 / c2, torch.pow(inter/(union + eps), gamma) # Focal_DIoU
  107. else:
  108. return iou - rho2 / c2 # DIoU
  109. c_area = cw * ch + eps # convex area
  110. if Focal:
  111. return iou - torch.pow((c_area - union) / c_area + eps, alpha), torch.pow(inter/(union + eps), gamma) # Focal_GIoU https://arxiv.org/pdf/1902.09630.pdf
  112. else:
  113. return iou - torch.pow((c_area - union) / c_area + eps, alpha) # GIoU https://arxiv.org/pdf/1902.09630.pdf
  114. if Focal:
  115. return iou, torch.pow(inter/(union + eps), gamma) # Focal_IoU
  116. else:
  117. return iou # IoU

然后在utils/loss.py文件中调用bbox_iou损失函数时,将WIoU设置为True即可。

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/396992
推荐阅读
相关标签
  

闽ICP备14008679号