当前位置:   article > 正文

【yolo系列:yolov7改进wise-iou】_改进yolo系列 | yolov7 更换训练策略之 siou / eiou / wiou / foc

改进yolo系列 | yolov7 更换训练策略之 siou / eiou / wiou / focal_xiou 最全汇

yolo系列文章目录

学习视频:
YOLOV7改进-Wise IoU_哔哩哔哩_bilibili

代码地址:
objectdetection_script/yolov7-iou.py at master · z1069614715/objectdetection_script (github.com)
Wise-IoU(WIoU)是一种用于目标检测的创新性损失函数,针对传统边界框损失函数中对训练数据质量要求较高的问题进行了改进。在目标检测中,边界框损失函数的设计对模型性能至关重要。以往的研究大多假定训练数据是高质量的,并试图通过强化边界框损失的拟合能力来提高模型性能。然而,在实际训练集中,通常包含了一些低质量的示例,如果盲目地加强对这些低质量示例的回归,可能会损害模型的检测性能。

为了解决这个问题,先前的研究提出了Focal-EIoU v1方法,但其聚焦机制是静态的,未充分挖掘非单调聚焦机制的潜力。基于这一观点,研究者们提出了一种新的动态非单调聚焦机制,即Wise-IoU(WIoU)。这种机制使用“离群度”替代传统的IoU(Intersection over Union)来评估锚框的质量,并引入了明智的梯度增益分配策略。这一策略在降低高质量锚框竞争性的同时,也减小了低质量示例产生的有害梯度。这使得WIoU能够更集中地处理普通质量的锚框,从而提高整体检测器的性能。

在实际应用中,将WIoU应用于最先进的单级检测器YOLOv7时,它在MS-COCO数据集上的AP-75(Average Precision with IoU threshold at 0.75)从53.03%提升到了54.50%。这种显著的性能提升表明Wise-IoU在处理目标检测任务中具有很高的实用性和效果。通过引入动态非单调聚焦机制,WIoU为目标检测领域带来了新的思路和方法,为提高模型的鲁棒性和准确性提供了有力支持。



一、在yolov7之上进行替换

utils/general.py替换bbiox

class WIoU_Scale:
    ''' monotonous: {
            None: origin v1
            True: monotonic FM v2
            False: non-monotonic FM v3
        }
        momentum: The momentum of running mean'''
    
    iou_mean = 1.
    monotonous = False
    _momentum = 1 - 0.5 ** (1 / 7000)
    _is_train = True

    def __init__(self, iou):
        self.iou = iou
        self._update(self)
    
    @classmethod
    def _update(cls, self):
        if cls._is_train: cls.iou_mean = (1 - cls._momentum) * cls.iou_mean + \
                                         cls._momentum * self.iou.detach().mean().item()
    
    @classmethod
    def _scaled_loss(cls, self, gamma=1.9, delta=3):
        if isinstance(self.monotonous, bool):
            if self.monotonous:
                return (self.iou.detach() / self.iou_mean).sqrt()
            else:
                beta = self.iou.detach() / self.iou_mean
                alpha = delta * torch.pow(gamma, beta - delta)
                return beta / alpha
        return 1
    

def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, SIoU=False, EIoU=False, WIoU=False, Focal=False, alpha=1, gamma=0.5, scale=False, eps=1e-7):
    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)

    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
        w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
            (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps
    if scale:
        self = WIoU_Scale(1 - (inter / union))

    # IoU
    # iou = inter / union # ori iou
    iou = torch.pow(inter/(union + eps), alpha) # alpha iou
    if CIoU or DIoU or GIoU or EIoU or SIoU or WIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
        if CIoU or DIoU or EIoU or SIoU or WIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = (cw ** 2 + ch ** 2) ** alpha + eps  # convex diagonal squared
            rho2 = (((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4) ** alpha  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
                with torch.no_grad():
                    alpha_ciou = v / (v - iou + (1 + eps))
                if Focal:
                    return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha)), torch.pow(inter/(union + eps), gamma)  # Focal_CIoU
                else:
                    return iou - (rho2 / c2 + torch.pow(v * alpha_ciou + eps, alpha))  # CIoU
            elif EIoU:
                rho_w2 = ((b2_x2 - b2_x1) - (b1_x2 - b1_x1)) ** 2
                rho_h2 = ((b2_y2 - b2_y1) - (b1_y2 - b1_y1)) ** 2
                cw2 = torch.pow(cw ** 2 + eps, alpha)
                ch2 = torch.pow(ch ** 2 + eps, alpha)
                if Focal:
                    return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2), torch.pow(inter/(union + eps), gamma) # Focal_EIou
                else:
                    return iou - (rho2 / c2 + rho_w2 / cw2 + rho_h2 / ch2) # EIou
            elif SIoU:
                # SIoU Loss https://arxiv.org/pdf/2205.12740.pdf
                s_cw = (b2_x1 + b2_x2 - b1_x1 - b1_x2) * 0.5 + eps
                s_ch = (b2_y1 + b2_y2 - b1_y1 - b1_y2) * 0.5 + eps
                sigma = torch.pow(s_cw ** 2 + s_ch ** 2, 0.5)
                sin_alpha_1 = torch.abs(s_cw) / sigma
                sin_alpha_2 = torch.abs(s_ch) / sigma
                threshold = pow(2, 0.5) / 2
                sin_alpha = torch.where(sin_alpha_1 > threshold, sin_alpha_2, sin_alpha_1)
                angle_cost = torch.cos(torch.arcsin(sin_alpha) * 2 - math.pi / 2)
                rho_x = (s_cw / cw) ** 2
                rho_y = (s_ch / ch) ** 2
                gamma = angle_cost - 2
                distance_cost = 2 - torch.exp(gamma * rho_x) - torch.exp(gamma * rho_y)
                omiga_w = torch.abs(w1 - w2) / torch.max(w1, w2)
                omiga_h = torch.abs(h1 - h2) / torch.max(h1, h2)
                shape_cost = torch.pow(1 - torch.exp(-1 * omiga_w), 4) + torch.pow(1 - torch.exp(-1 * omiga_h), 4)
                if Focal:
                    return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha), torch.pow(inter/(union + eps), gamma) # Focal_SIou
                else:
                    return iou - torch.pow(0.5 * (distance_cost + shape_cost) + eps, alpha) # SIou
            elif WIoU:
                if Focal:
                    raise Exception("WIoU do not support Focal.")
                elif scale:
                    return getattr(WIoU_Scale, '_scaled_loss')(self), (1 - iou) * torch.exp((rho2 / c2)), iou # WIoU https://arxiv.org/abs/2301.10051
                else:
                    return iou, torch.exp((rho2 / c2)) # WIoU v1
            if Focal:
                return iou - rho2 / c2, torch.pow(inter/(union + eps), gamma)  # Focal_DIoU
            else:
                return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        if Focal:
            return iou - torch.pow((c_area - union) / c_area + eps, alpha), torch.pow(inter/(union + eps), gamma)  # Focal_GIoU https://arxiv.org/pdf/1902.09630.pdf
        else:
            return iou - torch.pow((c_area - union) / c_area + eps, alpha)  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    if Focal:
        return iou, torch.pow(inter/(union + eps), gamma)  # Focal_IoU
    else:
        return iou  # IoU
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125

二、在loss.py的ComputeLoss,ComputeLossOTA修改如下

在这里插入图片描述

if type(iou) is tuple:
    if len(iou) == 2:
        lbox += (iou[1].detach() * (1 - iou[0].)).mean()
        iou = iou[0]
    else:
        lbox += (iou[0] * iou[1]).mean()
        iou = iou[-1]
else:
    lbox += (1.0 - iou.squeeze()).mean()  # iou loss
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

修改在这里插入图片描述

在这里插入图片描述

三、设置版本

monotonous = False
就是v3,truev2,none为v1版本,可以自行尝试效果。
在目标检测领域,选择合适的模型版本对性能提升至关重要。在这个实验中,我们探讨了三个不同版本的模型:v1、v2、v3,它们分别代表了monotonous参数设置为None、True和False的情况。这个参数的变化引入了不同的聚焦机制,从而影响了模型的性能表现。

首先,当monotonous参数为None(v1版本)时,模型的聚焦机制是静态的,无法适应数据中的复杂特征变化。这可能导致模型无法有效地捕捉到目标边界的微小变化,从而影响了检测的准确性。

其次,monotonous参数为True(v2版本)时,模型采用了单调递增的聚焦机制。这种机制对于某些场景可能更加适用,但在某些情况下可能会忽略掉一些关键的目标特征,导致性能提升有限。

最后,monotonous参数为False(v3版本)时,引入了非单调递增的聚焦机制。这种机制允许模型更加灵活地适应各种特定目标的形状和结构,从而在处理复杂场景时表现更为出色。

在实验过程中,我们可以根据不同版本的模型输出结果进行性能对比分析。通过比较各个版本在各种测试场景下的检测准确性、鲁棒性和处理速度,我们可以确定哪个版本在实验中取得了明显的提升。

总结而言,选择适当的聚焦机制非常关键,它直接影响了目标检测模型的性能。在实验中,我们可以通过调整monotonous参数来尝试不同版本的模型,从而找到最适合特定任务和数据集的版本。这种灵活性使得我们能够根据实际需求优化模型,取得更好的检测结果。在选择模型版本时,综合考虑准确性、鲁棒性和处理速度等因素,可以帮助我们做出明智的决策,提高目标检测系统的性能和可靠性。


总结

确定好训练配置后,即可进行性能对比分析,找出哪个版本在实验中取得了明显的提升。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/397018
推荐阅读
相关标签
  

闽ICP备14008679号