赞
踩
网络结构没什么好讲的,backbone、neck、head组成,backbone采用的cspdarknet,neck采用的pafpn,head是decoupled head结构。这里主要讲一下label assignment的具体实现,yolox中采用了simota,是ota的简化版本。
实现中标签分配以及计算损失部分是在yolo_head.py中,连带着head的网络层一起的,这里也顺带一起讲了。
首先forward函数的输入xin是neck的输出,当输入shape为(4,3,416,416)时,xin的shape为[(4,128,52,52),(4,256,26,26),(4,512,13,13)],对应8,16,32三种不同stride的输出特征图。
接下里的for循环是分别对三个特征图进行head部分网络层的forward,并计算对应的grids,grids具体是什么后面会讲。以stride=8对应的大小为(4,128,52,52)的特征图为例,self.stems[k]是一层1x1卷积,然后分类分支cls_conv和回归分支reg_conv都是2层3x3卷积,self.cls_preds[k]得到最终的分类输出shape为(b,num_classes,52,52),self.reg_preds[k]得到最终的回归输出shape为(b,4,52,52),self.obj_preds[k]得到最终的objectiveness输出shape为(b,1,52,52)。这里b=4,num_classes=16。
接下来将三个输出torch.cat得到输出shape为(4,21,52,52)。接下来函数self.get_output_and_grid()得到网格坐标grid和解码后的输出output。
代码如下
- def get_output_and_grid(self, output, k, stride, dtype):
- # (4,21,52,52)
- grid = self.grids[k]
-
- batch_size = output.shape[0]
- n_ch = 5 + self.num_classes
- hsize, wsize = output.shape[-2:]
- if grid.shape[2:4] != output.shape[2:4]:
- yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
- grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype) # (1,1,52,52,2), 先按行后按列,每个像素点的坐标
- self.grids[k] = grid
-
- output = output.view(batch_size, 1, n_ch, hsize, wsize) # (4,1,21,52,52)
- output = output.permute(0, 1, 3, 4, 2).reshape(
- batch_size, hsize * wsize, -1
- ) # (4,2704,21)
- grid = grid.view(1, -1, 2) # (1,2704,2)
- output[..., :2] = (output[..., :2] + grid) * stride
- output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
- return output, grid
self.grids是三个torch.Size([1])的列表,所以会进入到line8的if中。hsize和wsize分别是特征图的高和宽这里都是52,meshgrid返回的yv和xv分别是特征图每个像素点对应的y坐标和x坐标,如下所示
- tensor([[ 0, 0, 0, ..., 0, 0, 0],
- [ 1, 1, 1, ..., 1, 1, 1],
- [ 2, 2, 2, ..., 2, 2, 2],
- ...,
- [49, 49, 49, ..., 49, 49, 49],
- [50, 50, 50, ..., 50, 50, 50],
- [51, 51, 51, ..., 51, 51, 51]])
- tensor([[ 0, 1, 2, ..., 49, 50, 51],
- [ 0, 1, 2, ..., 49, 50, 51],
- [ 0, 1, 2, ..., 49, 50, 51],
- ...,
- [ 0, 1, 2, ..., 49, 50, 51],
- [ 0, 1, 2, ..., 49, 50, 51],
- [ 0, 1, 2, ..., 49, 50, 51]])
然后将xy坐标stack得到每个点的xy坐标,shape为(1,1,52,52,2),按先行后列的顺序,如下
- tensor([[[[[ 0., 0.],
- [ 1., 0.],
- [ 2., 0.],
- ...,
- [49., 0.],
- [50., 0.],
- [51., 0.]],
-
- [[ 0., 1.],
- [ 1., 1.],
- [ 2., 1.],
- ...,
- [49., 1.],
- [50., 1.],
- [51., 1.]],
-
- ...,
-
- [[ 0., 50.],
- [ 1., 50.],
- [ 2., 50.],
- ...,
- [49., 50.],
- [50., 50.],
- [51., 50.]],
-
- [[ 0., 51.],
- [ 1., 51.],
- [ 2., 51.],
- ...,
- [49., 51.],
- [50., 51.],
- [51., 51.]]]]], device='cuda:0', dtype=torch.float16)
然后将output和grid分别view调整维度,output中每个点对应一个预测框,output[..., :2]是预测框中心点相对于每个点的偏移,因此line8加上每个点的坐标grid并乘以stride还原回原图上得到原图上真实预测框的中心点坐标。line9则是通过
在得到原图上预测框的坐标以及类别和objectiveness后,接下来就是进行label assignment并计算loss,具体实现都在self.get_losses()中。其中输入outputs是将坐标还原到原图中的三个特征图的输出并concat得到的,shape为(b, 3549, 21),3549=52x52+26x26+13x13,21=4+1+16。
在函数get_losses()中,调用self.get_assignments进行标签分配,这里使用的方法是simota。关于simota和ota的原理可参考OTA: Optimal Transport Assignment for Object Detection 原理与代码解读-CSDN博客和https://blog.csdn.net/ooooocj/article/details/136569249。get_assignments()的完整实现如下
- @torch.no_grad()
- def get_assignments(
- self,
- batch_idx,
- num_gt,
- gt_bboxes_per_image, # (17,4)
- gt_classes, # (17)
- bboxes_preds_per_image, # (3549,4)
- expanded_strides, # (1,3549)
- x_shifts, # (1,3549)
- y_shifts, # (1,3549)
- cls_preds, # (4,3549,16)
- obj_preds, # (4,3549,1)
- mode="gpu",
- ):
-
- if mode == "cpu":
- print("-----------Using CPU for the Current Batch-------------")
- gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
- bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
- gt_classes = gt_classes.cpu().float()
- expanded_strides = expanded_strides.cpu().float()
- x_shifts = x_shifts.cpu()
- y_shifts = y_shifts.cpu()
-
- fg_mask, geometry_relation = self.get_geometry_constraint(
- gt_bboxes_per_image,
- expanded_strides,
- x_shifts,
- y_shifts,
- ) # (3549), (17,357)
- # fg_mask中True位置的anchor point至少在一个gt box的center area内,后续会用来进行label assignment。而不在fg_mask中False位置的anchor point
- # 不在任意一个gt box的center area内。
-
- bboxes_preds_per_image = bboxes_preds_per_image[fg_mask] # (357,4)
- cls_preds_ = cls_preds[batch_idx][fg_mask] # (357,16)
- obj_preds_ = obj_preds[batch_idx][fg_mask] # (357,1)
- num_in_boxes_anchor = bboxes_preds_per_image.shape[0] # 357
-
- if mode == "cpu":
- gt_bboxes_per_image = gt_bboxes_per_image.cpu()
- bboxes_preds_per_image = bboxes_preds_per_image.cpu()
-
- pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False) # (17,357)
-
- gt_cls_per_image = (
- F.one_hot(gt_classes.to(torch.int64), self.num_classes)
- .float()
- ) # (17,16)
- pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) # (17,357)
-
- if mode == "cpu":
- cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.cpu()
-
- with torch.cuda.amp.autocast(enabled=False):
- cls_preds_ = (
- cls_preds_.float().sigmoid_() * obj_preds_.float().sigmoid_()
- ).sqrt()
- pair_wise_cls_loss = F.binary_cross_entropy(
- cls_preds_.unsqueeze(0).repeat(num_gt, 1, 1), # (357,16)->(1,357,16)->(17,357,16)
- gt_cls_per_image.unsqueeze(1).repeat(1, num_in_boxes_anchor, 1), # (17,16)->(17,1,16)->(17,357,16)
- reduction="none"
- ).sum(-1) # (17,357), 共16个类别,每个类单独计算bce
- del cls_preds_
-
- cost = (
- pair_wise_cls_loss
- + 3.0 * pair_wise_ious_loss
- + float(1e6) * (~geometry_relation) # center area之外的anchor point对应的cost加上一个很大的值来过滤
- ) # (17,357)
-
- (
- num_fg, # 22
- gt_matched_classes, # (22)
- pred_ious_this_matching, # (22)
- matched_gt_inds, # (22)
- ) = self.simota_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
- del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
-
- if mode == "cpu":
- gt_matched_classes = gt_matched_classes.cuda()
- fg_mask = fg_mask.cuda()
- pred_ious_this_matching = pred_ious_this_matching.cuda()
- matched_gt_inds = matched_gt_inds.cuda()
-
- return (
- gt_matched_classes,
- fg_mask, # (3549)
- pred_ious_this_matching,
- matched_gt_inds,
- num_fg,
- )
在ota中使用了center prior,即只有gt box中心有限区域内的anchor point作为正样本的candidate,而不是整个gt box内所有的anchor point都作为正样本的候选。函数get_geometry_constraint就是实现这个的
- def get_geometry_constraint(
- self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts,
- ):
- """
- Calculate whether the center of an object is located in a fixed range of
- an anchor. This is used to avert inappropriate matching. It can also reduce
- the number of candidate anchors so that the GPU memory is saved.
- """
- expanded_strides_per_image = expanded_strides[0] # (3549)
- x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0) # (1,3549)
- y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0) # (1,3549)
-
- # in fixed center
- center_radius = 1.5 # 这里有可能center area区域比原目标还大
- center_dist = expanded_strides_per_image.unsqueeze(0) * center_radius # (1,3549)
- gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0:1]) - center_dist # (17,1) -> (17,3549)
- gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0:1]) + center_dist # (17,3549)
- gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1:2]) - center_dist # (17,3549)
- gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1:2]) + center_dist # (17,3549)
-
- c_l = x_centers_per_image - gt_bboxes_per_image_l # (1,3549)-(17,3549) -> (17,3549)
- c_r = gt_bboxes_per_image_r - x_centers_per_image
- c_t = y_centers_per_image - gt_bboxes_per_image_t
- c_b = gt_bboxes_per_image_b - y_centers_per_image
- center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2) # (17,3549,4)
- is_in_centers = center_deltas.min(dim=-1).values > 0.0 # (17,3549)
- anchor_filter = is_in_centers.sum(dim=0) > 0 # (3549), 一共3549个anchor point, 对应位置为False, 说明这个anchor point不在任意一个gt box的center area内
- geometry_relation = is_in_centers[:, anchor_filter] # (17,357), anchor_filter.sum()==357,表明某个anchor point至少在一个gt box的center area内
-
- return anchor_filter, geometry_relation
最终返回的anchor_filter是一个shape为(3549, )的tensor,值全为True或False。前面说过三个特征图一共3549个anchor point,值为False对应的anchor point不在任意一个gt box的center area内,后续进行标签分配时只从值为True的anchor point中挑选。当我用自己的数据调试时,另一个输出geometry_relation的shape为(17, 357),17是图中gt的数量,357是anchor_filter中值为True的anchor point的数量,geometry_relation表示每个gt的中心区域内对应的anchor point。
然后用fg_mask也就是anchor_filter挑选出候选的正样本,然后计算ota的cost matrix,cost矩阵包括分类损失以及回归损失,注意分类的预测要取sigmoid后并与obj预测相乘再与gt计算交叉熵损失,最后加上float(1e6) * (~geometry_relation)
是对每个gt中心区域外的anchor加上一个特别大的cost,从而过滤它们。
在得到cost矩阵后,就是通过simota进行标签分配的过程了,具体实现在函数simota_matching中
- def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
- # (17,357),(17,357),(17),17,(3549)
- matching_matrix = torch.zeros_like(cost, dtype=torch.uint8) # (17,357)
-
- n_candidate_k = min(10, pair_wise_ious.size(1)) # 这里10就是文章中dynamic_k中的q
- topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1) # (17,10)
- dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1) # (17)
- # tensor([3, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0', dtype=torch.int32)
- # 每个gt选择q个最大iou值,相加取整作为为该gt分配的anchor point的个数
- for gt_idx in range(num_gt):
- _, pos_idx = torch.topk(
- cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
- ) # 选择cost最小的dynamic k个anchor point作为分配的正样本,代替原始OTA中的sinkhorn算法
- matching_matrix[gt_idx][pos_idx] = 1
-
- del topk_ious, dynamic_ks, pos_idx
-
- anchor_matching_gt = matching_matrix.sum(0) # (357)
- # deal with the case that one anchor matches multiple ground-truths
- if anchor_matching_gt.max() > 1:
- multiple_match_mask = anchor_matching_gt > 1
- _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0) # 当一个anchor point匹配多个gt时,选择cost最小的gt作为匹配的结果
- matching_matrix[:, multiple_match_mask] *= 0
- matching_matrix[cost_argmin, multiple_match_mask] = 1
- fg_mask_inboxes = anchor_matching_gt > 0 # (357), pos anchor point的index
- num_fg = fg_mask_inboxes.sum().item()
- # num_fg==22, anchor_matching_gt.sum()==tensor(22, device='cuda:0')
- # 当if anchor_matching_gt.max() > 1成立时,num_fg > matching_matrix.sum().item()
-
- # fg_mask.sum().item() == 357
- fg_mask[fg_mask.clone()] = fg_mask_inboxes
- # fg_mask.sum().item() == 22
- # 更新fg_mask,本来fg_mask中有357个anchor point初步过滤后再gt center area内,然后经过simota第二次匹配找到pos anchor point
- # 注意这里[]内fg_mask.clone()的作用,是找到那357个的值,然后用fg_mask_inboxes替换
- # 这里fg_mask更新后,不用return,外面的fg_mask也更新了
- # 这里的fg_mask就是所有3549个anchor中哪几个anchor是正样本,正样本处的值为1
-
- matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
- # matched_gt_inds == tensor([7, 7, 8, 2, 1, 9, 4, 4, 10, 10, 3, 13, 5, 14, 12, 6, 15, 16, 11, 0, 0, 0], device='cuda:0')
- # 每个pos anchor匹配到了第几个gt的index
- # print(gt_classes) == tensor([6, 5, 12, 12, 12, 5, 12, 12, 5, 12, 12, 5, 5, 5, 5, 12, 12], device='cuda:0', dtype=torch.float16)
- gt_matched_classes = gt_classes[matched_gt_inds] # 每个pos anchor匹配到的gt的实际类别索引
- # print(gt_matched_classes) == tensor([12, 12, 5, 12, 5, 12, 12, 12, 12, 12, 12, 5, 5, 5, 5, 12, 12, 12, 5, 6, 6, 6.], device='cuda:0', dtype=torch.float16)
-
- pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
- fg_mask_inboxes
- ]
- # 这里sum(0)沿列求和,一列只有1个值大于0,因为上面处理完后,一个anchor只能匹配一个gt。但一行可以有多个大于0的值,即1个gt可以和多个anchor匹配
- return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
simota和原本的ota的区别是,在得到cost矩阵后,ota通过sinkhorn算法进行匹配,而simota则直接选择topk个cost最小的anchor作为正样本,和最早的faster rcnn中的topk相似,只不过那里是选择iou最小,这里是选择cost最小,这里的cost不仅考虑了iou还考虑了分类损失和center prior。另外这里的k不是认为设置的固定值,而是dynamic k,具体是根据先选择q个iou最大的anchor(这里q仍然是人工设定的代码中取10),然后这10个iou求和取整得到k值。
- n_candidate_k = min(10, pair_wise_ious.size(1)) # 这里10就是文章中dynamic_k中的q
- topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1) # (17,10)
- dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1) # (17)
一个gt可以匹配多个anchor,但一个anchor只能匹配一个gt,根据上面的规则选择cost最小的k个anchor后如果存在一个anchor匹配多个gt的情况,选择cost最小对应的gt作为匹配结果。
样本分配完后,就是计算损失了,这里没什么好讲的,回归损失采用的iou loss,分类损失和obj损失都是bce loss。yolox中作者在最后15个epoch关闭了mosaic数据增强,并添加了额外的L1 loss来增加回归的精度,这里L1 loss就是在特征图上计算的,预测就是特征图的原始输出,没有像iou loss一样加上grid并乘以stride映射会原图,这里target是将label反向映射到特征图上。
- def get_l1_target(self, l1_target, gt, stride, x_shifts, y_shifts, eps=1e-8):
- l1_target[:, 0] = gt[:, 0] / stride - x_shifts
- l1_target[:, 1] = gt[:, 1] / stride - y_shifts
- l1_target[:, 2] = torch.log(gt[:, 2] / stride + eps)
- l1_target[:, 3] = torch.log(gt[:, 3] / stride + eps)
- return l1_target
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。