当前位置:   article > 正文

yolov5增加iou loss(SIoU,EIoU,WIoU),无痛涨点trick

wiou

        yolo无痛涨点trick,简单实用

        先贴一张最近一篇论文的结果

从论文的消融实验结果可以看出，后来提出的几种IoU变体在一定程度上优于CIoU。

        本文将在yolov5的基础上增加SIoU、EIoU、WIoU、Focal-XIoU（X为C、D、G、E、S等）以及AlphaXIoU。

        在yolov5的utils文件夹下新增iou.py文件，内容如下：

  1. import math
  2. import torch
  3. def bbox_iou(box1,
  4. box2,
  5. xywh=True,
  6. GIoU=False,
  7. DIoU=False,
  8. CIoU=False,
  9. SIoU=False,
  10. EIoU=False,
  11. WIoU=False,
  12. Focal=False,
  13. alpha=1,
  14. gamma=0.5,
  15. scale=False,
  16. monotonous=False,
  17. eps=1e-7):
  18. """
  19. 计算bboxes iou
  20. Args:
  21. box1: predict bboxes
  22. box2: target bboxes
  23. xywh: 将bboxes转换为xyxy的形式
  24. GIoU: 为True时计算GIoU LOSS (yolov5自带)
  25. DIoU: 为True时计算DIoU LOSS (yolov5自带)
  26. CIoU: 为True时计算CIoU LOSS (yolov5自带,默认使用)
  27. SIoU: 为True时计算SIoU LOSS (新增)
  28. EIoU: 为True时计算EIoU LOSS (新增)
  29. WIoU: 为True时计算WIoU LOSS (新增)
  30. Focal: 为True时,可结合其他的XIoU生成对应的IoU变体,如CIoU=True,Focal=True时为Focal-CIoU
  31. alpha: AlphaIoU中的alpha参数,默认为1,为1时则为普通的IoU,如果想采用AlphaIoU,论文alpha默认值为3,此时设置CIoU=True则为AlphaCIoU
  32. gamma: Focal_XIoU中的gamma参数,默认为0.5
  33. scale: scale为True时,WIoU会乘以一个系数
  34. monotonous: 3个输入分别代表WIoU的3个版本,None: origin v1, True: monotonic FM v2, False: non-monotonic FM v3
  35. eps: 防止除0
  36. Returns:
  37. iou
  38. """
  39. # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
  40. # Get the coordinates of bounding boxes
  41. if xywh: # transform from xywh to xyxy
  42. (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
  43. w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
  44. b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
  45. b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
  46. else: # x1, y1, x2, y2 = box1
  47. b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
  48. b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
  49. w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
  50. w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)
  51. # Intersection area
  52. inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
  53. (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)
  54. # Union Area
  55. union = w1 * h1 + w2 * h2 - inter + eps
  56. if scale:
  57. wise_scale = WIoU_Scale(1 - (inter / union), monotonous=monotonous)
  58. # IoU
  59. # iou = inter / union # ori iou
  60. iou = torch.pow(inter / (union + eps), alpha) # alpha iou
  61. if CIoU or DIoU or GIoU or EIoU or SIoU or WIoU:
  62. cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width
  63. ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height
  64. if CIoU or DIoU or EIoU or SIoU or WIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
  65. c2 = (cw ** 2 + ch ** 2) ** alpha + eps # convex diagonal squared
  66. rho2 = (((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (
  67. b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4) ** alpha # center dist ** 2
  68. if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
  69. v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
  70. with torch.no_grad():
  71. alpha_ciou = v / (v - iou + (1 + eps))
  72. if Focal:
  73. return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)), torch.pow(inter / (union + eps),
  74. gamma) # Focal_CIoU
  75. return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)) # CIoU
  76. elif EIoU:
  77. rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2
  78. rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2
  79. cw2 = torch.pow(cw ** 2 + eps, alpha)
  80. ch2 = torch.pow(ch ** 2 + eps, alpha)
  81. if Focal:
  82. return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2), torch.pow(inter / (union + eps), gamma) # Focal_EIou
  83. return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2) # EIou
  84. elif SIoU:
  85. # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
  86. s_cw, s_ch = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps, (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps
  87. sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
  88. sin_alpha_1, sin_alpha_2 = torch.abs(s_cw) / sigma, torch.abs(s_ch) / sigma
  89. threshold = pow(2, 0.5) / 2
  90. sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
  91. angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
  92. rho_x, rho_y = (s_cw / cw) ** 2, (s_ch / ch) ** 2
  93. gamma = angle_cost - 2
  94. distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
  95. omiga_w, omiga_h = torch.abs(w1 - w2) / torch.max(w1, w2), torch.abs(h1 - h2) / torch.max(h1, h2)
  96. shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
  97. if Focal:
  98. return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha), torch.pow(
  99. inter / (union + eps), gamma) # Focal_SIou
  100. return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha) # SIou
  101. elif WIoU:
  102. if scale:
  103. return getattr(WIoU_Scale, '_scaled_loss')(wise_scale), (1 - iou) * torch.exp((rho2 / c2)), iou # WIoU v3 https://arxiv.org/abs/2301.10051
  104. return iou, torch.exp((rho2 / c2)) # WIoU v1
  105. if Focal:
  106. return iou - rho2 / c2, torch.pow(inter / (union + eps), gamma) # Focal_DIoU
  107. return iou - rho2 / c2 # DIoU
  108. c_area = cw * ch + eps # convex area
  109. if Focal:
  110. return iou - torch.pow((c_area - union) / c_area + eps, alpha), torch.pow(inter / (union + eps), gamma) # Focal_GIoU https://arxiv.org/pdf/1902.09630.pdf
  111. return iou - torch.pow((c_area - union) / c_area + eps, alpha) # GIoU https://arxiv.org/pdf/1902.09630.pdf
  112. if Focal:
  113. return iou, torch.pow(inter / (union + eps), gamma) # Focal_IoU
  114. return iou # IoU
  115. class WIoU_Scale:
  116. """
  117. monotonous: {
  118. None: origin v1
  119. True: monotonic FM v2
  120. False: non-monotonic FM v3
  121. }
  122. momentum: The momentum of running mean
  123. """
  124. iou_mean = 1.
  125. _momentum = 1 - pow(0.5, exp=1 / 7000)
  126. _is_train = True
  127. def __init__(self, iou, monotonous=False):
  128. self.iou = iou
  129. self.monotonous = monotonous
  130. self._update(self)
  131. @classmethod
  132. def _update(cls, self):
  133. if cls._is_train: cls.iou_mean = (1 - cls._momentum) * cls.iou_mean + \
  134. cls._momentum * self.iou.detach().mean().item()
  135. @classmethod
  136. def _scaled_loss(cls, self, gamma=1.9, delta=3):
  137. if isinstance(self.monotonous, bool):
  138. if self.monotonous:
  139. return (self.iou.detach() / self.iou_mean).sqrt()
  140. else:
  141. beta = self.iou.detach() / self.iou_mean
  142. alpha = delta * torch.pow(gamma, beta - delta)
  143. return beta / alpha
  144. return 1

在utils/loss.py中调用bbox_iou函数的地方做如下修改（主要是ComputeLoss的__call__方法）：

  1. class ComputeLoss:
  2. sort_obj_iou = False
  3. # Compute losses
  4. def __init__(self, model, autobalance=False):
  5. device = next(model.parameters()).device # get model device
  6. h = model.hyp # hyperparameters
  7. # Define criteria
  8. BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
  9. BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
  10. # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
  11. self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0)) # positive, negative BCE targets
  12. # Focal loss
  13. g = h['fl_gamma'] # focal loss gamma
  14. if g > 0:
  15. BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
  16. m = de_parallel(model).model[-1] # Detect() module
  17. self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
  18. self.ssi = list(m.stride).index(16) if autobalance else 0 # stride 16 index
  19. self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
  20. self.na = m.na # number of anchors
  21. self.nc = m.nc # number of classes
  22. self.nl = m.nl # number of layers
  23. self.anchors = m.anchors
  24. self.device = device
  25. def __call__(self, p, targets): # predictions, targets
  26. lcls = torch.zeros(1, device=self.device) # class loss
  27. lbox = torch.zeros(1, device=self.device) # box loss
  28. lobj = torch.zeros(1, device=self.device) # object loss
  29. tcls, tbox, indices, anchors = self.build_targets(p, targets) # targets
  30. # Losses
  31. for i, pi in enumerate(p): # layer index, layer predictions
  32. b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
  33. tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device) # target obj
  34. n = b.shape[0] # number of targets
  35. if n:
  36. # pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1) # faster, requires torch 1.8.0
  37. pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, self.nc), 1) # target-subset of predictions
  38. # Regression
  39. pxy = pxy.sigmoid() * 2 - 0.5
  40. pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]
  41. pbox = torch.cat((pxy, pwh), 1) # predicted box
  42. # iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze() # iou(prediction, target)
  43. # lbox += (1.0 - iou).mean() # iou loss
  44. # //
  45. iou = bbox_iou(pbox, tbox[i], WIoU=True, scale=True)
  46. if isinstance(iou, tuple):
  47. if len(iou) == 2:
  48. lbox += (iou[1].detach().squeeze() * (1 - iou[0].squeeze())).mean()
  49. iou = iou[0].squeeze()
  50. else:
  51. lbox += (iou[0] * iou[1]).mean()
  52. iou = iou[2].squeeze()
  53. else:
  54. lbox += (1.0 - iou.squeeze()).mean() # iou loss
  55. iou = iou.squeeze()
  56. # /
  57. # Objectness
  58. iou = iou.detach().clamp(0).type(tobj.dtype)
  59. if self.sort_obj_iou:
  60. j = iou.argsort()
  61. b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j]
  62. if self.gr < 1:
  63. iou = (1.0 - self.gr) + self.gr * iou
  64. tobj[b, a, gj, gi] = iou # iou ratio
  65. # Classification
  66. if self.nc > 1: # cls loss (only if multiple classes)
  67. t = torch.full_like(pcls, self.cn, device=self.device) # targets
  68. t[range(n), tcls[i]] = self.cp
  69. lcls += self.BCEcls(pcls, t) # BCE
  70. # Append targets to text file
  71. # with open('targets.txt', 'a') as file:
  72. # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
  73. obji = self.BCEobj(pi[..., 4], tobj)
  74. lobj += obji * self.balance[i] # obj loss
  75. if self.autobalance:
  76. self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
  77. if self.autobalance:
  78. self.balance = [x / self.balance[self.ssi] for x in self.balance]
  79. lbox *= self.hyp['box']
  80. lobj *= self.hyp['obj']
  81. lcls *= self.hyp['cls']
  82. bs = tobj.shape[0] # batch size
  83. return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach()
  84. def build_targets(self, p, targets):
  85. # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
  86. na, nt = self.na, targets.shape[0] # number of anchors, targets
  87. tcls, tbox, indices, anch = [], [], [], []
  88. gain = torch.ones(7, device=self.device) # normalized to gridspace gain
  89. ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
  90. targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None]), 2) # append anchor indices
  91. g = 0.5 # bias
  92. off = torch.tensor(
  93. [
  94. [0, 0],
  95. [1, 0],
  96. [0, 1],
  97. [-1, 0],
  98. [0, -1], # j,k,l,m
  99. # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
  100. ],
  101. device=self.device).float() * g # offsets
  102. for i in range(self.nl):
  103. anchors, shape = self.anchors[i], p[i].shape
  104. gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]] # xyxy gain
  105. # Match targets to anchors
  106. t = targets * gain # shape(3,n,7)
  107. if nt:
  108. # Matches
  109. r = t[..., 4:6] / anchors[:, None] # wh ratio
  110. j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t'] # compare
  111. # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
  112. t = t[j] # filter
  113. # Offsets
  114. gxy = t[:, 2:4] # grid xy
  115. gxi = gain[[2, 3]] - gxy # inverse
  116. j, k = ((gxy % 1 < g) & (gxy > 1)).T
  117. l, m = ((gxi % 1 < g) & (gxi > 1)).T
  118. j = torch.stack((torch.ones_like(j), j, k, l, m))
  119. t = t.repeat((5, 1, 1))[j]
  120. offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
  121. else:
  122. t = targets[0]
  123. offsets = 0
  124. # Define
  125. bc, gxy, gwh, a = t.chunk(4, 1) # (image, class), grid xy, grid wh, anchors
  126. a, (b, c) = a.long().view(-1), bc.long().T # anchors, image, class
  127. gij = (gxy - offsets).long()
  128. gi, gj = gij.T # grid indices
  129. # Append
  130. indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1))) # image, anchor, grid
  131. tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
  132. anch.append(anchors[a]) # anchors
  133. tcls.append(c) # class
  134. return tcls, tbox, indices, anch
'
运行

        注意：需要在loss.py中改为从新增的utils/iou.py导入bbox_iou，并注释掉原来从utils.metrics导入bbox_iou的语句：

# from utils.metrics import bbox_iou
from utils.iou import bbox_iou

         如果需要使用Focal版本的IoU loss变体，将Focal设置为True，并将对应的IoU也设置为True即可，如CIoU=True、Focal=True时为Focal-CIoU；此时可以调整gamma（默认为0.5）。

        如果想要使用AlphaXIoU，将alpha设置为3，同时将对应的IoU也设置为True即可；alpha默认为1（即普通的IoU）。

        更新WIoU：monotonous的3种取值分别对应WIoU的3个版本——None：原始v1，True：单调聚焦机制v2（monotonic FM），False：非单调聚焦机制v3（non-monotonic FM）。同时需要设置scale：scale为True时，WIoU会乘以一个聚焦系数，与monotonous结合即对应WIoU的3个版本。上面ComputeLoss中的调用为WIoU=True、scale=True，monotonous取默认值False，即WIoU v3。

        yolov7的代码结构与yolov5相同，也可以按同样方式替换到yolov7中，但__call__中bbox_iou的调用要改成yolov5的方式（pbox不需要转置，即去掉.T）。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/800627
推荐阅读
相关标签
  

闽ICP备14008679号