当前位置:   article > 正文

RetinaNet损失函数源码解读

retinanet损失函数

RetinaNet背景

 RetinaNet算法源自2018年Facebook AI Research的论文 Focal Loss for Dense Object Detection,作者包括了Ross大神、Kaiming大神和Piotr大神。该论文最大的贡献在于提出了Focal Loss用于解决类别不均衡问题,从而创造了RetinaNet(One Stage目标检测算法)这个精度超越经典Two Stage的Faster-RCNN的目标检测网络。

在这里插入图片描述

retinanet

在这里插入图片描述

数据流的细节

上图出自https://blog.csdn.net/weixin_48167570/article/details/120937167

  • 模型结构:
    • 最后输出三层不同尺度的特征层:C3,C4,C5
    • 经过RetinaNet的head,得到5个不同尺度的融合特征层,用于后面回归和分类
    • 将三个特征层分别输入到classification和regression回归网络中
    • 其中:A=9,4A代表每个锚点对应的九个不同锚框对应的x,y,w,h,KA代表每个锚框对应的K个种类
    • 最终输出为:
      	features = self.fpn([x2, x3, x4])
        # [bs, 9*h*w*5, 4]
        regression = torch.cat([self.regressionModel(feature) for feature in features], dim=1)
        # [bs, 9*h*w*5, classes]
        classification = torch.cat([self.classificationModel(feature) for feature in features], dim=1)
    
    • 1
    • 2
    • 3
    • 4
    • 5

如上,最终所有尺度计算出来的向量被连接在一起(concatenate on dim=1),送入下一阶段的损失计算部分中

        if self.training:
            return self.focalLoss(classification, regression, anchors, annotations)
  • 1
  • 2

IOU计算

这一步计算得到的IOU用作下一步的训练计算正负样本,保证训练时每个anchor都尽量对应着一个GT

def calc_iou(a, b):
	# a,b分别是两个框的[x1,y1,x2,y2]坐标
    area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    # 计算相交区域的宽和高
    # 此处经过squeeze,向量从[n]到[n,1],iw * ih = interaction_area, iw.shape:[n,n],iw是交叉区域的宽度,同理ih为高度
    # 这里的torch.min操作可以看作一种广播机制,将每个anchor与所有的annotations bbox进行iou的计算
    iw = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 0])
    ih = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 1])
    # 对宽高剪裁,防止为负数
    iw = torch.clamp(iw, min=0)
    ih = torch.clamp(ih, min=0)
    # Sa + Sg - Si = Su:并区域面积=anchor面积+annotation面积-交叉区域面积
    ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih
    # 限制最小的并区域面积,太小会导致交并比过大
    ua = torch.clamp(ua, min=1e-8)
    # 交叉的面积
    intersection = iw * ih
    # 最后的iou计算,IoU.shape:[9*h*w*5, n]其中shape[0]为anchor数量,shape[1]为anchor对应的所有annotations的交并比
    IoU = intersection / ua
    return IoU
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
'
运行

Focal Loss

  • 损失函数公式
    在这里插入图片描述
    上面是传统的二分类交叉熵损失,经过修改,简写为以下损失函数:
    在这里插入图片描述

  • 加入平衡因子平衡本身的不平均:
    在这里插入图片描述

  • 注:

    • 经过修改:pt意为预测值和gt值得接近程度,在log前乘以 ( 1 − p t ) γ (1-p_t)^{\gamma} (1pt)γ,使负样本(y=0)的损失乘以一个较小的权重,也就是对于真实值越接近预测值的样本损失占比越小难分类的样本损失占比较大
    • RetinaNet等一阶段法在训练时不使用NMS,分配给每个anchor的IOU max annotation作为GT,基于anchor的单阶段法本身基于密集anchor预测,如果训练时采用NMS,会造成anchor大量缺失影响单阶段性能
  • 下面着重于focal loss的代码实现:

class FocalLoss(nn.Module):

    def forward(self, classifications, regressions, anchors, annotations):
        alpha = 0.25
        gamma = 2.0
        batch_size = classifications.shape[0]
        classification_losses = []
        regression_losses = []
        # anchors:[0,5 * h*w, 9] 多个特征图的anchor数组
        anchor = anchors[0, :, :]

        anchor_widths  = anchor[:, 2] - anchor[:, 0]
        anchor_heights = anchor[:, 3] - anchor[:, 1]
        anchor_ctr_x   = anchor[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y   = anchor[:, 1] + 0.5 * anchor_heights

        for j in range(batch_size):
            # classification, regression = [bs, 5 * num_anchors * w * h, num_classes], [bs, 5 * num_anchors * w * h, 4]
            # 第j个batch的所有anchor
            classification = classifications[j, :, :]
            regression = regressions[j, :, :]
            # bbox_annotation = [bs, n, 5] 其中 5含有[x, y, x, y, label]
            bbox_annotation = annotations[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1] # 非填充目标
            # 将classification的predict限制在一定范围内
            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)
			'''
			分为样本含有gt和不含有gt两种情况讨论
			'''
            # 1. 如果一个GT都没有的情况
            if bbox_annotation.shape[0] == 0:
                if torch.cuda.is_available():
                	# 计算平衡因子
                    alpha_factor = torch.ones(classification.shape).cuda() * alpha
                    alpha_factor = 1. - alpha_factor
                    # 计算focal weight
                    focal_weight = classification
                    focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
					# 计算bce
                    bce = -(torch.log(1.0 - classification))
                    # 最终的损失计算,回归损失按0计算
                    cls_loss = focal_weight * bce
                    classification_losses.append(cls_loss.sum())
                    regression_losses.append(torch.tensor(0).float().cuda())
                
            # 2. batch的图片含有GT时
            IoU = calc_iou(anchors[0, :, :], bbox_annotation[:, :4])
            IoU_max, IoU_argmax = torch.max(IoU, dim=1) # num_anchors x 1,每个anchor对应的最大交并比和对应的IoU对应的annotation索引
            # targets作为GT与classification计算损失,初始化为-1
            targets = torch.ones(classification.shape) * -1  # shape:[9*w*h, K]
            # 比较 IoU_max,分配targets的正负样本分配
            targets[torch.lt(IoU_max, 0.4), :] = 0
            positive_indices = torch.ge(IoU_max, 0.5)
            # 正样本数数量
            num_positive_anchors = positive_indices.sum()
            # 给每个anchor分配的GT bbox, shape:[num_positive_anchors,5]
            assigned_annotations = bbox_annotation[IoU_argmax, :]
            targets[positive_indices, :] = 0
            # 对targets中每个满足IoU要求的对应种类赋值1
            targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1
            
            if torch.cuda.is_available():
                alpha_factor = torch.ones(targets.shape).cuda() * alpha
            else:
                alpha_factor = torch.ones(targets.shape) * alpha
                
            # 对于正样本,α=alpha_factor而负样本赋值1-alpha_factor
            alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
            focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
            # focal_weight
            focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
			# binary crossentropy loss---bce.shape:[num_anchors, num_classes]
            bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))

            # 最后的loss计算
            cls_loss = focal_weight * bce

            if torch.cuda.is_available():
                cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, torch.zeros(cls_loss.shape).cuda())
            else:
                cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, torch.zeros(cls_loss.shape))

            classification_losses.append(cls_loss.sum()/torch.clamp(num_positive_anchors.float(), min=1.0))

  			# 计算回归损失值,分为锚框的IoU满足阈值的正样本索引和负样本索引
            if positive_indices.sum() > 0:
                assigned_annotations = assigned_annotations[positive_indices, :]
                anchor_widths_pi = anchor_widths[positive_indices]
                anchor_heights_pi = anchor_heights[positive_indices]
                anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
                anchor_ctr_y_pi = anchor_ctr_y[positive_indices]
				# x,y,x1,y1 -> cx,cy,w,h
                gt_widths  = assigned_annotations[:, 2] - assigned_annotations[:, 0]
                gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
                gt_ctr_x   = assigned_annotations[:, 0] + 0.5 * gt_widths
                gt_ctr_y   = assigned_annotations[:, 1] + 0.5 * gt_heights

                # clip widths to 1
                gt_widths  = torch.clamp(gt_widths, min=1)
                gt_heights = torch.clamp(gt_heights, min=1)
				# targets是真实中心坐标与预测的anchor中心坐标偏移与anchor长宽的比值,也就是需要预测的值
				'''
				梳理一下:
				coco数据集标记的锚框是x,y,w,h格式
				通过coco的工具包处理,处理成了x,y,x1,y1形式
				预测的结果:cx,cy,logΔx,logΔy
				'''
                targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
                targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
                targets_dw = torch.log(gt_widths / anchor_widths_pi)
                targets_dh = torch.log(gt_heights / anchor_heights_pi)

                targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh))
                targets = targets.t()

                if torch.cuda.is_available():
                    targets = targets/torch.Tensor([[0.1, 0.1, 0.2, 0.2]]).cuda()
                else:
                    targets = targets/torch.Tensor([[0.1, 0.1, 0.2, 0.2]])

                negative_indices = 1 + (~positive_indices)

                regression_diff = torch.abs(targets - regression[positive_indices, :])

                regression_loss = torch.where(
                    torch.le(regression_diff, 1.0 / 9.0),
                    0.5 * 9.0 * torch.pow(regression_diff, 2),
                    regression_diff - 0.5 / 9.0
                )
                regression_losses.append(regression_loss.mean())
            else:
                # 全是预测的负样本,回归损失置零
                if torch.cuda.is_available():
                    regression_losses.append(torch.tensor(0).float().cuda())
                else:
                    regression_losses.append(torch.tensor(0).float())

        return torch.stack(classification_losses).mean(dim=0, keepdim=True), torch.stack(regression_losses).mean(dim=0, keepdim=True)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小惠珠哦/article/detail/848119
推荐阅读
相关标签
  

闽ICP备14008679号