当前位置:   article > 正文

目标检测算法——YOLOv5/v7/v8改进结合涨点Trick之Wise-IoU(超越CIoU/SIoU)

wise-iou


超越CIoU/SIoU | Wise-IoU助力YOLO强势涨点!!!


论文题目:Wise-IoU: Bounding Box Regression Loss with Dynamic Focusing Mechanism

论文链接:https://arxiv.org/abs/2301.10051

近年来的研究大多假设训练数据中的示例有较高的质量,致力于强化边界框损失的拟合能力。但注意到目标检测训练集中含有低质量示例,如果一味地强化边界框对低质量示例的回归,显然会危害模型检测性能的提升。

Focal-EIoU v1 被提出以解决这个问题,但由于其聚焦机制是静态的,并未充分挖掘非单调聚焦机制的潜能。基于这个观点,作者提出了动态非单调的聚焦机制,设计了 Wise-IoU (WIoU)。动态非单调聚焦机制使用“离群度”替代 IoU 对锚框进行质量评估,并提供了明智的梯度增益分配策略。该策略在降低高质量锚框的竞争力的同时,也减小了低质量示例产生的有害梯度。这使得 WIoU 可以聚焦于普通质量的锚框,并提高检测器的整体性能。将WIoU应用于最先进的单级检测器 YOLOv7 时,在 MS-COCO 数据集上的 AP-75 从 53.03% 提升到 54.50%。

一、 Wise-IoU相关代码

  1. import math
  2. import torch
  3. class IoU_Cal:
  4. ''' pred, target: x0,y0,x1,y1
  5. monotonous: {
  6. None: origin
  7. True: monotonic FM
  8. False: non-monotonic FM
  9. }
  10. momentum: The momentum of running mean (This can be set by the function <momentum_estimation>)'''
  11. iou_mean = 1.
  12. monotonous = False
  13. momentum = 1 - pow(0.05, 1 / (890 * 34))
  14. _is_train = True
  15. @classmethod
  16. def momentum_estimation(cls, n, t):
  17. ''' n: Number of batches per training epoch
  18. t: The epoch when mAP's ascension slowed significantly'''
  19. time_to_real = n * t
  20. cls.momentum = 1 - pow(0.05, 1 / time_to_real)
  21. return cls.momentum
  22. def __init__(self, pred, target):
  23. self.pred, self.target = pred, target
  24. self._fget = {
  25. # x,y,w,h
  26. 'pred_xy': lambda: (self.pred[..., :2] + self.pred[..., 2: 4]) / 2,
  27. 'pred_wh': lambda: self.pred[..., 2: 4] - self.pred[..., :2],
  28. 'target_xy': lambda: (self.target[..., :2] + self.target[..., 2: 4]) / 2,
  29. 'target_wh': lambda: self.target[..., 2: 4] - self.target[..., :2],
  30. # x0,y0,x1,y1
  31. 'min_coord': lambda: torch.minimum(self.pred[..., :4], self.target[..., :4]),
  32. 'max_coord': lambda: torch.maximum(self.pred[..., :4], self.target[..., :4]),
  33. # The overlapping region
  34. 'wh_inter': lambda: torch.relu(self.min_coord[..., 2: 4] - self.max_coord[..., :2]),
  35. 's_inter': lambda: torch.prod(self.wh_inter, dim=-1),
  36. # The area covered
  37. 's_union': lambda: torch.prod(self.pred_wh, dim=-1) +
  38. torch.prod(self.target_wh, dim=-1) - self.s_inter,
  39. # The smallest enclosing box
  40. 'wh_box': lambda: self.max_coord[..., 2: 4] - self.min_coord[..., :2],
  41. 's_box': lambda: torch.prod(self.wh_box, dim=-1),
  42. 'l2_box': lambda: torch.square(self.wh_box).sum(dim=-1),
  43. # The central points' connection of the bounding boxes
  44. 'd_center': lambda: self.pred_xy - self.target_xy,
  45. 'l2_center': lambda: torch.square(self.d_center).sum(dim=-1),
  46. # IoU
  47. 'iou': lambda: 1 - self.s_inter / self.s_union
  48. }
  49. self._update(self)
  50. def __setitem__(self, key, value):
  51. self._fget[key] = value
  52. def __getattr__(self, item):
  53. if callable(self._fget[item]):
  54. self._fget[item] = self._fget[item]()
  55. return self._fget[item]
  56. @classmethod
  57. def train(cls):
  58. cls._is_train = True
  59. @classmethod
  60. def eval(cls):
  61. cls._is_train = False
  62. @classmethod
  63. def _update(cls, self):
  64. if cls._is_train: cls.iou_mean = (1 - cls.momentum) * cls.iou_mean + \
  65. cls.momentum * self.iou.detach().mean().item()
  66. def _scaled_loss(self, loss, alpha=1.9, delta=3):
  67. if isinstance(self.monotonous, bool):
  68. beta = self.iou.detach() / self.iou_mean
  69. if self.monotonous:
  70. loss *= beta.sqrt()
  71. else:
  72. divisor = delta * torch.pow(alpha, beta - delta)
  73. loss *= beta / divisor
  74. return loss
  75. @classmethod
  76. def IoU(cls, pred, target, self=None):
  77. self = self if self else cls(pred, target)
  78. return self.iou
  79. @classmethod
  80. def WIoU(cls, pred, target, self=None):
  81. self = self if self else cls(pred, target)
  82. dist = torch.exp(self.l2_center / self.l2_box.detach())
  83. return self._scaled_loss(dist * self.iou)
  84. @classmethod
  85. def EIoU(cls, pred, target, self=None):
  86. self = self if self else cls(pred, target)
  87. penalty = self.l2_center / self.l2_box.detach() \
  88. + torch.square(self.d_center / self.wh_box).sum(dim=-1)
  89. return self._scaled_loss(self.iou + penalty)
  90. @classmethod
  91. def GIoU(cls, pred, target, self=None):
  92. self = self if self else cls(pred, target)
  93. return self._scaled_loss(self.iou + (self.s_box - self.s_union) / self.s_box)
  94. @classmethod
  95. def DIoU(cls, pred, target, self=None):
  96. self = self if self else cls(pred, target)
  97. return self._scaled_loss(self.iou + self.l2_center / self.l2_box)
  98. @classmethod
  99. def CIoU(cls, pred, target, eps=1e-4, self=None):
  100. self = self if self else cls(pred, target)
  101. v = 4 / math.pi ** 2 * \
  102. (torch.atan(self.pred_wh[..., 0] / (self.pred_wh[..., 1] + eps)) -
  103. torch.atan(self.target_wh[..., 0] / (self.target_wh[..., 1] + eps))) ** 2
  104. alpha = v / (self.iou + v)
  105. return self._scaled_loss(self.iou + self.l2_center / self.l2_box + alpha.detach() * v)
  106. @classmethod
  107. def SIoU(cls, pred, target, theta=4, self=None):
  108. self = self if self else cls(pred, target)
  109. # Angle Cost
  110. angle = torch.arcsin(torch.abs(self.d_center).min(dim=-1)[0] / (self.l2_center.sqrt() + 1e-4))
  111. angle = torch.sin(2 * angle) - 2
  112. # Dist Cost
  113. dist = angle[..., None] * torch.square(self.d_center / self.wh_box)
  114. dist = 2 - torch.exp(dist[..., 0]) - torch.exp(dist[..., 1])
  115. # Shape Cost
  116. d_shape = torch.abs(self.pred_wh - self.target_wh)
  117. big_shape = torch.maximum(self.pred_wh, self.target_wh)
  118. w_shape = 1 - torch.exp(- d_shape[..., 0] / big_shape[..., 0])
  119. h_shape = 1 - torch.exp(- d_shape[..., 1] / big_shape[..., 1])
  120. shape = w_shape ** theta + h_shape ** theta
  121. return self._scaled_loss(self.iou + (dist + shape) / 2)

二、实验对比结果


声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/271257

推荐阅读
相关标签