当前位置:   article > 正文

YOLOX源码之 Label Assignment

YOLOX源码之 Label Assignment

网络结构没什么好讲的,backbone、neck、head组成,backbone采用的cspdarknet,neck采用的pafpn,head是decoupled head结构。这里主要讲一下label assignment的具体实现,yolox中采用了simota,是ota的简化版本。

实现中标签分配以及计算损失部分是在yolo_head.py中,连带着head的网络层一起的,这里也顺带一起讲了。

首先forward函数的输入xin是neck的输出,当输入shape为(4,3,416,416)时,xin的shape为[(4,128,52,52),(4,256,26,26),(4,512,13,13)],对应8,16,32三种不同stride的输出特征图。

接下里的for循环是分别对三个特征图进行head部分网络层的forward,并计算对应的grids,grids具体是什么后面会讲。以stride=8对应的大小为(4,128,52,52)的特征图为例,self.stems[k]是一层1x1卷积,然后分类分支cls_conv和回归分支reg_conv都是2层3x3卷积,self.cls_preds[k]得到最终的分类输出shape为(b,num_classes,52,52),self.reg_preds[k]得到最终的回归输出shape为(b,4,52,52),self.obj_preds[k]得到最终的objectiveness输出shape为(b,1,52,52)。这里b=4,num_classes=16。

接下来将三个输出torch.cat得到输出shape为(4,21,52,52)。接下来函数self.get_output_and_grid()得到网格坐标grid和解码后的输出output。

代码如下

  1. def get_output_and_grid(self, output, k, stride, dtype):
  2. # (4,21,52,52)
  3. grid = self.grids[k]
  4. batch_size = output.shape[0]
  5. n_ch = 5 + self.num_classes
  6. hsize, wsize = output.shape[-2:]
  7. if grid.shape[2:4] != output.shape[2:4]:
  8. yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
  9. grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype) # (1,1,52,52,2), 先按行后按列,每个像素点的坐标
  10. self.grids[k] = grid
  11. output = output.view(batch_size, 1, n_ch, hsize, wsize) # (4,1,21,52,52)
  12. output = output.permute(0, 1, 3, 4, 2).reshape(
  13. batch_size, hsize * wsize, -1
  14. ) # (4,2704,21)
  15. grid = grid.view(1, -1, 2) # (1,2704,2)
  16. output[..., :2] = (output[..., :2] + grid) * stride
  17. output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
  18. return output, grid

self.grids是三个torch.Size([1])的列表,所以会进入到line8的if中。hsize和wsize分别是特征图的高和宽这里都是52,meshgrid返回的yv和xv分别是特征图每个像素点对应的y坐标和x坐标,如下所示

  1. tensor([[ 0, 0, 0, ..., 0, 0, 0],
  2. [ 1, 1, 1, ..., 1, 1, 1],
  3. [ 2, 2, 2, ..., 2, 2, 2],
  4. ...,
  5. [49, 49, 49, ..., 49, 49, 49],
  6. [50, 50, 50, ..., 50, 50, 50],
  7. [51, 51, 51, ..., 51, 51, 51]])
  8. tensor([[ 0, 1, 2, ..., 49, 50, 51],
  9. [ 0, 1, 2, ..., 49, 50, 51],
  10. [ 0, 1, 2, ..., 49, 50, 51],
  11. ...,
  12. [ 0, 1, 2, ..., 49, 50, 51],
  13. [ 0, 1, 2, ..., 49, 50, 51],
  14. [ 0, 1, 2, ..., 49, 50, 51]])

然后将xy坐标stack得到每个点的xy坐标,shape为(1,1,52,52,2),按先行后列的顺序,如下

  1. tensor([[[[[ 0., 0.],
  2. [ 1., 0.],
  3. [ 2., 0.],
  4. ...,
  5. [49., 0.],
  6. [50., 0.],
  7. [51., 0.]],
  8. [[ 0., 1.],
  9. [ 1., 1.],
  10. [ 2., 1.],
  11. ...,
  12. [49., 1.],
  13. [50., 1.],
  14. [51., 1.]],
  15. ...,
  16. [[ 0., 50.],
  17. [ 1., 50.],
  18. [ 2., 50.],
  19. ...,
  20. [49., 50.],
  21. [50., 50.],
  22. [51., 50.]],
  23. [[ 0., 51.],
  24. [ 1., 51.],
  25. [ 2., 51.],
  26. ...,
  27. [49., 51.],
  28. [50., 51.],
  29. [51., 51.]]]]], device='cuda:0', dtype=torch.float16)

然后将output和grid分别view调整维度,output中每个点对应一个预测框,output[..., :2]是预测框中心点相对于每个点的偏移,因此line8加上每个点的坐标grid并乘以stride还原回原图上得到原图上真实预测框的中心点坐标。line9则是通过 et 并乘以stride得到原图上真实预测框的宽高。 

在得到原图上预测框的坐标以及类别和objectiveness后,接下来就是进行label assignment并计算loss,具体实现都在self.get_losses()中。其中输入outputs是将坐标还原到原图中的三个特征图的输出并concat得到的,shape为(b, 3549, 21),3549=52x52+26x26+13x13,21=4+1+16。

在函数get_losses()中,调用self.get_assignments进行标签分配,这里使用的方法是simota。关于simota和ota的原理可参考OTA: Optimal Transport Assignment for Object Detection 原理与代码解读-CSDN博客https://blog.csdn.net/ooooocj/article/details/136569249。get_assignments()的完整实现如下

  1. @torch.no_grad()
  2. def get_assignments(
  3. self,
  4. batch_idx,
  5. num_gt,
  6. gt_bboxes_per_image, # (17,4)
  7. gt_classes, # (17)
  8. bboxes_preds_per_image, # (3549,4)
  9. expanded_strides, # (1,3549)
  10. x_shifts, # (1,3549)
  11. y_shifts, # (1,3549)
  12. cls_preds, # (4,3549,16)
  13. obj_preds, # (4,3549,1)
  14. mode="gpu",
  15. ):
  16. if mode == "cpu":
  17. print("-----------Using CPU for the Current Batch-------------")
  18. gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
  19. bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
  20. gt_classes = gt_classes.cpu().float()
  21. expanded_strides = expanded_strides.cpu().float()
  22. x_shifts = x_shifts.cpu()
  23. y_shifts = y_shifts.cpu()
  24. fg_mask, geometry_relation = self.get_geometry_constraint(
  25. gt_bboxes_per_image,
  26. expanded_strides,
  27. x_shifts,
  28. y_shifts,
  29. ) # (3549), (17,357)
  30. # fg_mask中True位置的anchor point至少在一个gt box的center area内,后续会用来进行label assignment。而不在fg_mask中False位置的anchor point
  31. # 不在任意一个gt box的center area内。
  32. bboxes_preds_per_image = bboxes_preds_per_image[fg_mask] # (357,4)
  33. cls_preds_ = cls_preds[batch_idx][fg_mask] # (357,16)
  34. obj_preds_ = obj_preds[batch_idx][fg_mask] # (357,1)
  35. num_in_boxes_anchor = bboxes_preds_per_image.shape[0] # 357
  36. if mode == "cpu":
  37. gt_bboxes_per_image = gt_bboxes_per_image.cpu()
  38. bboxes_preds_per_image = bboxes_preds_per_image.cpu()
  39. pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False) # (17,357)
  40. gt_cls_per_image = (
  41. F.one_hot(gt_classes.to(torch.int64), self.num_classes)
  42. .float()
  43. ) # (17,16)
  44. pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) # (17,357)
  45. if mode == "cpu":
  46. cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.cpu()
  47. with torch.cuda.amp.autocast(enabled=False):
  48. cls_preds_ = (
  49. cls_preds_.float().sigmoid_() * obj_preds_.float().sigmoid_()
  50. ).sqrt()
  51. pair_wise_cls_loss = F.binary_cross_entropy(
  52. cls_preds_.unsqueeze(0).repeat(num_gt, 1, 1), # (357,16)->(1,357,16)->(17,357,16)
  53. gt_cls_per_image.unsqueeze(1).repeat(1, num_in_boxes_anchor, 1), # (17,16)->(17,1,16)->(17,357,16)
  54. reduction="none"
  55. ).sum(-1) # (17,357), 共16个类别,每个类单独计算bce
  56. del cls_preds_
  57. cost = (
  58. pair_wise_cls_loss
  59. + 3.0 * pair_wise_ious_loss
  60. + float(1e6) * (~geometry_relation) # center area之外的anchor point对应的cost加上一个很大的值来过滤
  61. ) # (17,357)
  62. (
  63. num_fg, # 22
  64. gt_matched_classes, # (22)
  65. pred_ious_this_matching, # (22)
  66. matched_gt_inds, # (22)
  67. ) = self.simota_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
  68. del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
  69. if mode == "cpu":
  70. gt_matched_classes = gt_matched_classes.cuda()
  71. fg_mask = fg_mask.cuda()
  72. pred_ious_this_matching = pred_ious_this_matching.cuda()
  73. matched_gt_inds = matched_gt_inds.cuda()
  74. return (
  75. gt_matched_classes,
  76. fg_mask, # (3549)
  77. pred_ious_this_matching,
  78. matched_gt_inds,
  79. num_fg,
  80. )

在ota中使用了center prior,即只有gt box中心有限区域内的anchor point作为正样本的candidate,而不是整个gt box内所有的anchor point都作为正样本的候选。函数get_geometry_constraint就是实现这个的

  1. def get_geometry_constraint(
  2. self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts,
  3. ):
  4. """
  5. Calculate whether the center of an object is located in a fixed range of
  6. an anchor. This is used to avert inappropriate matching. It can also reduce
  7. the number of candidate anchors so that the GPU memory is saved.
  8. """
  9. expanded_strides_per_image = expanded_strides[0] # (3549)
  10. x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0) # (1,3549)
  11. y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0) # (1,3549)
  12. # in fixed center
  13. center_radius = 1.5 # 这里有可能center area区域比原目标还大
  14. center_dist = expanded_strides_per_image.unsqueeze(0) * center_radius # (1,3549)
  15. gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0:1]) - center_dist # (17,1) -> (17,3549)
  16. gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0:1]) + center_dist # (17,3549)
  17. gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1:2]) - center_dist # (17,3549)
  18. gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1:2]) + center_dist # (17,3549)
  19. c_l = x_centers_per_image - gt_bboxes_per_image_l # (1,3549)-(17,3549) -> (17,3549)
  20. c_r = gt_bboxes_per_image_r - x_centers_per_image
  21. c_t = y_centers_per_image - gt_bboxes_per_image_t
  22. c_b = gt_bboxes_per_image_b - y_centers_per_image
  23. center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2) # (17,3549,4)
  24. is_in_centers = center_deltas.min(dim=-1).values > 0.0 # (17,3549)
  25. anchor_filter = is_in_centers.sum(dim=0) > 0 # (3549), 一共3549个anchor point, 对应位置为False, 说明这个anchor point不在任意一个gt box的center area内
  26. geometry_relation = is_in_centers[:, anchor_filter] # (17,357), anchor_filter.sum()==357,表明某个anchor point至少在一个gt box的center area内
  27. return anchor_filter, geometry_relation

最终返回的anchor_filter是一个shape为(3549, )的tensor,值全为True或False。前面说过三个特征图一共3549个anchor point,值为False对应的anchor point不在任意一个gt box的center area内,后续进行标签分配时只从值为True的anchor point中挑选。当我用自己的数据调试时,另一个输出geometry_relation的shape为(17, 357),17是图中gt的数量,357是anchor_filter中值为True的anchor point的数量,geometry_relation表示每个gt的中心区域内对应的anchor point。

然后用fg_mask也就是anchor_filter挑选出候选的正样本,然后计算ota的cost matrix,cost矩阵包括分类损失以及回归损失,注意分类的预测要取sigmoid后并与obj预测相乘再与gt计算交叉熵损失,最后加上float(1e6) * (~geometry_relation)是对每个gt中心区域外的anchor加上一个特别大的cost,从而过滤它们。

在得到cost矩阵后,就是通过simota进行标签分配的过程了,具体实现在函数simota_matching中

  1. def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
  2. # (17,357),(17,357),(17),17,(3549)
  3. matching_matrix = torch.zeros_like(cost, dtype=torch.uint8) # (17,357)
  4. n_candidate_k = min(10, pair_wise_ious.size(1)) # 这里10就是文章中dynamic_k中的q
  5. topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1) # (17,10)
  6. dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1) # (17)
  7. # tensor([3, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0', dtype=torch.int32)
  8. # 每个gt选择q个最大iou值,相加取整作为为该gt分配的anchor point的个数
  9. for gt_idx in range(num_gt):
  10. _, pos_idx = torch.topk(
  11. cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
  12. ) # 选择cost最小的dynamic k个anchor point作为分配的正样本,代替原始OTA中的sinkhorn算法
  13. matching_matrix[gt_idx][pos_idx] = 1
  14. del topk_ious, dynamic_ks, pos_idx
  15. anchor_matching_gt = matching_matrix.sum(0) # (357)
  16. # deal with the case that one anchor matches multiple ground-truths
  17. if anchor_matching_gt.max() > 1:
  18. multiple_match_mask = anchor_matching_gt > 1
  19. _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0) # 当一个anchor point匹配多个gt时,选择cost最小的gt作为匹配的结果
  20. matching_matrix[:, multiple_match_mask] *= 0
  21. matching_matrix[cost_argmin, multiple_match_mask] = 1
  22. fg_mask_inboxes = anchor_matching_gt > 0 # (357), pos anchor point的index
  23. num_fg = fg_mask_inboxes.sum().item()
  24. # num_fg==22, anchor_matching_gt.sum()==tensor(22, device='cuda:0')
  25. # 当if anchor_matching_gt.max() > 1成立时,num_fg > matching_matrix.sum().item()
  26. # fg_mask.sum().item() == 357
  27. fg_mask[fg_mask.clone()] = fg_mask_inboxes
  28. # fg_mask.sum().item() == 22
  29. # 更新fg_mask,本来fg_mask中有357个anchor point初步过滤后再gt center area内,然后经过simota第二次匹配找到pos anchor point
  30. # 注意这里[]内fg_mask.clone()的作用,是找到那357个的值,然后用fg_mask_inboxes替换
  31. # 这里fg_mask更新后,不用return,外面的fg_mask也更新了
  32. # 这里的fg_mask就是所有3549个anchor中哪几个anchor是正样本,正样本处的值为1
  33. matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
  34. # matched_gt_inds == tensor([7, 7, 8, 2, 1, 9, 4, 4, 10, 10, 3, 13, 5, 14, 12, 6, 15, 16, 11, 0, 0, 0], device='cuda:0')
  35. # 每个pos anchor匹配到了第几个gt的index
  36. # print(gt_classes) == tensor([6, 5, 12, 12, 12, 5, 12, 12, 5, 12, 12, 5, 5, 5, 5, 12, 12], device='cuda:0', dtype=torch.float16)
  37. gt_matched_classes = gt_classes[matched_gt_inds] # 每个pos anchor匹配到的gt的实际类别索引
  38. # print(gt_matched_classes) == tensor([12, 12, 5, 12, 5, 12, 12, 12, 12, 12, 12, 5, 5, 5, 5, 12, 12, 12, 5, 6, 6, 6.], device='cuda:0', dtype=torch.float16)
  39. pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
  40. fg_mask_inboxes
  41. ]
  42. # 这里sum(0)沿列求和,一列只有1个值大于0,因为上面处理完后,一个anchor只能匹配一个gt。但一行可以有多个大于0的值,即1个gt可以和多个anchor匹配
  43. return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds

simota和原本的ota的区别是,在得到cost矩阵后,ota通过sinkhorn算法进行匹配,而simota则直接选择topk个cost最小的anchor作为正样本,和最早的faster rcnn中的topk相似,只不过那里是选择iou最小,这里是选择cost最小,这里的cost不仅考虑了iou还考虑了分类损失和center prior。另外这里的k不是认为设置的固定值,而是dynamic k,具体是根据先选择q个iou最大的anchor(这里q仍然是人工设定的代码中取10),然后这10个iou求和取整得到k值。

  1. n_candidate_k = min(10, pair_wise_ious.size(1)) # 这里10就是文章中dynamic_k中的q
  2. topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1) # (17,10)
  3. dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1) # (17)

一个gt可以匹配多个anchor,但一个anchor只能匹配一个gt,根据上面的规则选择cost最小的k个anchor后如果存在一个anchor匹配多个gt的情况,选择cost最小对应的gt作为匹配结果。

样本分配完后,就是计算损失了,这里没什么好讲的,回归损失采用的iou loss,分类损失和obj损失都是bce loss。yolox中作者在最后15个epoch关闭了mosaic数据增强,并添加了额外的L1 loss来增加回归的精度,这里L1 loss就是在特征图上计算的,预测就是特征图的原始输出,没有像iou loss一样加上grid并乘以stride映射会原图,这里target是将label反向映射到特征图上。

  1. def get_l1_target(self, l1_target, gt, stride, x_shifts, y_shifts, eps=1e-8):
  2. l1_target[:, 0] = gt[:, 0] / stride - x_shifts
  3. l1_target[:, 1] = gt[:, 1] / stride - y_shifts
  4. l1_target[:, 2] = torch.log(gt[:, 2] / stride + eps)
  5. l1_target[:, 3] = torch.log(gt[:, 3] / stride + eps)
  6. return l1_target

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/698041
推荐阅读
相关标签
  

闽ICP备14008679号