当前位置:   article > 正文

YOLOv7损失函数修改_yolov7修改损失函数

yolov7修改损失函数

第一步:在utils/general.py文件中,修改bbox_iou函数

  1. def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, EIoU=False, SIoU=False, eps=1e-7):
  2. # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
  3. box2 = box2.T
  4. # Get the coordinates of bounding boxes
  5. if x1y1x2y2: # x1, y1, x2, y2 = box1
  6. b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
  7. b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
  8. else: # transform from xywh to xyxy
  9. b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
  10. b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
  11. b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
  12. b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
  13. # Intersection area
  14. inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
  15. (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
  16. # Union Area
  17. w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
  18. w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
  19. union = w1 * h1 + w2 * h2 - inter + eps
  20. iou = inter / union
  21. if GIoU or DIoU or CIoU or EIoU or SIoU:
  22. cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width
  23. ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height
  24. if SIoU: # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
  25. s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5
  26. s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5
  27. sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
  28. sin_alpha_1 = torch.abs(s_cw) / sigma
  29. sin_alpha_2 = torch.abs(s_ch) / sigma
  30. threshold = pow(2, 0.5) / 2
  31. sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
  32. # angle_cost = 1 - 2 * torch.pow( torch.sin(torch.arcsin(sin_alpha) - np.pi/4), 2)
  33. angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - np.pi / 2)
  34. rho_x = (s_cw / cw) ** 2
  35. rho_y = (s_ch / ch) ** 2
  36. gamma = angle_cost - 2
  37. distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
  38. omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
  39. omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
  40. shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
  41. return iou - 0.5 * (distance_cost + shape_cost)
  42. if CIoU or DIoU or EIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
  43. c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared
  44. rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
  45. (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center distance squared
  46. if DIoU:
  47. return iou - rho2 / c2 # DIoU
  48. elif CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
  49. v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / (h2 + eps)) - torch.atan(w1 / (h1 + eps)), 2)
  50. with torch.no_grad():
  51. alpha = v / (v - iou + (1 + eps))
  52. return iou - (rho2 / c2 + v * alpha) # CIoU
  53. elif EIoU:
  54. rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2
  55. rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2
  56. cw2 = cw ** 2 + eps
  57. ch2 = ch ** 2 + eps
  58. return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2)
  59. else: # GIoU https://arxiv.org/pdf/1902.09630.pdf
  60. c_area = cw * ch + eps # convex area
  61. return iou - (c_area - union) / c_area # GIoU
  62. else:
  63. return iou # IoU

第二步:在utils/loss.py文件中,修改ComputeLoss类中iou变量

  1. # 原始代码 第468行
  2. iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target)
  3. ----------------------------------分隔符---------------------------------
  4. # 修改为SIoU
  5. iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, SIoU=True) # iou(prediction, target)

第三步:在utils/loss.py文件中,修改ComputeLossOTA类中iou变量

  1. # 原始代码 第608行
  2. iou = bbox_iou(pbox.T, selected_tbox, x1y1x2y2=False, CIoU=True) # iou(prediction, target)
  3. ----------------------------------分隔符---------------------------------
  4. # 修改为SIoU
  5. iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, SIoU=True) # iou(prediction, target)
原因是因为yolov7中的yaml配置文件有一个loss_ota的参数会选择采用哪一个Loss(ComputeLoss,ComputeLossOTA),为了避免有一个不记得修改,就两个都一起修改即可。

最后:

如果你想要进一步了解更多的相关知识,可以关注下面公众号联系~会不定期发布相关设计内容包括但不限于如下内容:信号处理、通信仿真、算法设计、matlab appdesigner,gui设计、simulink仿真......希望能帮到你!

5a8015ddde1e41418a38e958eb12ecbd.png

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/718491
推荐阅读
相关标签
  

闽ICP备14008679号