These are notes I took while studying, kept to organize my thoughts and for later reference. If there are any errors, corrections are welcome!
code:https://github.com/Megvii-BaseDetection/YOLOX
paper:https://arxiv.org/abs/2107.08430
1、BACKBONE
By default the backbone is the Darknet53 of YOLOv3, from which three branches at different output resolutions, C3, C4, and C5, are taken as features.
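For orientation, a tiny shape check (not the actual backbone code; the channel widths below are Darknet53's, and a 640x640 input is assumed):

import torch

# Hypothetical illustration of the three backbone outputs for a 640x640 input.
# In YOLOX these come from the dark3/dark4/dark5 stages (strides 8, 16, 32).
c3 = torch.randn(1, 256, 80, 80)    # stride 8:  640 / 8  = 80
c4 = torch.randn(1, 512, 40, 40)    # stride 16: 640 / 16 = 40
c5 = torch.randn(1, 1024, 20, 20)   # stride 32: 640 / 32 = 20
for name, f in [("C3", c3), ("C4", c4), ("C5", c5)]:
    print(name, tuple(f.shape))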
2、NECK
YOLOX uses PAFPN as its neck. Its structure is slightly more complex than FPN's, but the overall idea is still the fusion of multi-scale features; a rough structural sketch follows.
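The real YOLOX PAFPN fuses features with concatenation and CSP blocks; the minimal module below is only a hypothetical sketch of the top-down plus bottom-up idea (channel widths assumed, normalization and activations omitted):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyPAFPN(nn.Module):
    """A minimal sketch of the PAFPN idea (not the real YOLOX module):
    top-down fusion as in FPN, followed by an extra bottom-up path."""
    def __init__(self, c3=256, c4=512, c5=1024, out=256):
        super().__init__()
        self.l5 = nn.Conv2d(c5, out, 1)  # lateral 1x1 convs align channels
        self.l4 = nn.Conv2d(c4, out, 1)
        self.l3 = nn.Conv2d(c3, out, 1)
        self.down4 = nn.Conv2d(out, out, 3, stride=2, padding=1)  # bottom-up downsampling
        self.down5 = nn.Conv2d(out, out, 3, stride=2, padding=1)

    def forward(self, c3, c4, c5):
        # top-down path (FPN): upsample and fuse
        p5 = self.l5(c5)
        p4 = self.l4(c4) + F.interpolate(p5, scale_factor=2, mode="nearest")
        p3 = self.l3(c3) + F.interpolate(p4, scale_factor=2, mode="nearest")
        # extra bottom-up path (the "PA" part): downsample and fuse again
        n3 = p3
        n4 = p4 + self.down4(n3)
        n5 = p5 + self.down5(n4)
        return n3, n4, n5

fpn = TinyPAFPN()
outs = fpn(torch.randn(1, 256, 80, 80),
           torch.randn(1, 512, 40, 40),
           torch.randn(1, 1024, 20, 20))
print([tuple(o.shape) for o in outs])  # [(1,256,80,80), (1,256,40,40), (1,256,20,20)]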
3、HEAD
YOLOX uses a decoupled head, separating the regression and classification branches; a minimal sketch below illustrates the structure.
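The actual YOLOX head stacks two 3x3 conv blocks per branch and repeats the module for every scale; this is only a sketch of the decoupling idea, with assumed layer sizes and with BN/activations omitted:

import torch
import torch.nn as nn

class TinyDecoupledHead(nn.Module):
    """A minimal sketch of a decoupled head for one scale."""
    def __init__(self, in_ch=256, num_classes=80):
        super().__init__()
        self.stem = nn.Conv2d(in_ch, 256, 1)                 # shared 1x1 stem
        self.cls_branch = nn.Conv2d(256, 256, 3, padding=1)  # classification path
        self.reg_branch = nn.Conv2d(256, 256, 3, padding=1)  # regression path
        self.cls_pred = nn.Conv2d(256, num_classes, 1)       # per-class logits
        self.reg_pred = nn.Conv2d(256, 4, 1)                 # box (x, y, w, h)
        self.obj_pred = nn.Conv2d(256, 1, 1)                 # objectness logit

    def forward(self, x):
        x = self.stem(x)
        cls_feat = self.cls_branch(x)
        reg_feat = self.reg_branch(x)
        return self.reg_pred(reg_feat), self.obj_pred(reg_feat), self.cls_pred(cls_feat)

head = TinyDecoupledHead()
reg, obj, cls = head(torch.randn(1, 256, 80, 80))
print(reg.shape, obj.shape, cls.shape)  # (1,4,80,80) (1,1,80,80) (1,80,80,80)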
4、DATA AUGMENTATION
4.1 Mosaic
The augmentation pipeline applies Mosaic followed by MixUp: one Mosaic composite (built from a group of images) is blended with one regular image in a single MixUp step.
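A rough sketch of this step; the blending coefficient below is drawn from a Beta distribution as in the original MixUp paper, while YOLOX's own implementation and its handling of the merged label sets differ in detail:

import torch

def mixup(mosaic_img, normal_img, alpha=0.5):
    """Blend a Mosaic-composited image with one regular image
    (the merging of the two label sets is omitted here)."""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    return lam * mosaic_img + (1.0 - lam) * normal_img

# Hypothetical mosaic result: four images tiled into one canvas.
quads = [torch.rand(3, 320, 320) for _ in range(4)]
mosaic = torch.cat([torch.cat(quads[:2], dim=2),
                    torch.cat(quads[2:], dim=2)], dim=1)  # 3x640x640
mixed = mixup(mosaic, torch.rand(3, 640, 640))
print(mixed.shape)  # torch.Size([3, 640, 640])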
Training code analysis
The model uses a decoupled head. For each of the three output resolutions, the feature map is split into a regression (box) branch, an objectness (confidence) branch, and a classification branch. The three branches are concatenated along the channel dimension for each scale, and the results of the three scales are then concatenated (see the shape walk-through below).
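A shape walk-through of this concatenation, assuming 80 classes and a 640x640 input:

import torch

# Per scale: reg (4) + obj (1) + cls (80) are concatenated on the channel dim,
# then flattened to (batch, h*w, 85); finally the three scales are concatenated.
num_classes = 80
outputs = []
for hsize in (80, 40, 20):  # strides 8, 16, 32
    reg = torch.randn(1, 4, hsize, hsize)
    obj = torch.randn(1, 1, hsize, hsize)
    cls = torch.randn(1, num_classes, hsize, hsize)
    out = torch.cat([reg, obj, cls], 1)              # (1, 85, h, w)
    outputs.append(out.flatten(2).permute(0, 2, 1))  # (1, h*w, 85)
outputs = torch.cat(outputs, 1)
print(outputs.shape)  # torch.Size([1, 8400, 85]) -- 6400 + 1600 + 400 anchors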
if self.training:
    output = torch.cat([reg_output, obj_output, cls_output], 1)
    output, grid = self.get_output_and_grid(
        output, k, stride_this_level, xin[0].type()
    )  # decode the raw output into image-space boxes

def get_output_and_grid(self, output, k, stride, dtype):
    grid = self.grids[k]  # placeholder; rebuilt below on first use
    batch_size = output.shape[0]
    n_ch = 5 + self.num_classes
    hsize, wsize = output.shape[-2:]  # h, w
    if grid.shape[2:4] != output.shape[2:4]:  # taken on the first call for this level
        yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])  # build the cell grid, similar to a positional encoding
        grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)  # 1*1*80*80*2
        self.grids[k] = grid  # cache: 1*1*80*80*2
    output = output.view(batch_size, 1, n_ch, hsize, wsize)
    output = output.permute(0, 1, 3, 4, 2).reshape(
        batch_size, hsize * wsize, -1
    )  # batch * 6400 * n_ch
    grid = grid.view(1, -1, 2)  # 1, 80*80, 2
    output[..., :2] = (output[..., :2] + grid) * stride  # columns 0-1: box center offsets; add the grid coordinates, then scale by stride
    output[..., 2:4] = torch.exp(output[..., 2:4]) * stride  # columns 2-3: box width/height; exponentiate, then scale by stride
    return output, grid
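A tiny numeric check of this decoding logic:

import torch

# A prediction in grid cell (x=3, y=4) at stride 8 with raw values
# (0.5, 0.5, 0.0, 0.0) decodes to center (28, 36) and size (8, 8)
# in input-image pixels.
raw = torch.tensor([0.5, 0.5, 0.0, 0.0])
grid_xy = torch.tensor([3.0, 4.0])
stride = 8.0
center = (raw[:2] + grid_xy) * stride  # (0.5+3)*8 = 28, (0.5+4)*8 = 36
size = torch.exp(raw[2:4]) * stride    # exp(0)*8 = 8
print(center.tolist(), size.tolist())  # [28.0, 36.0] [8.0, 8.0]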
Loss computation:
def get_losses(
    self,
    imgs,
    x_shifts,
    y_shifts,
    expanded_strides,
    labels,
    outputs,
    origin_preds,
    dtype,
):
    bbox_preds = outputs[:, :, :4]  # [batch, n_anchors_all, 4]
    obj_preds = outputs[:, :, 4:5]  # [batch, n_anchors_all, 1]
    cls_preds = outputs[:, :, 5:]   # [batch, n_anchors_all, n_cls]

    # calculate targets
    nlabel = (labels.sum(dim=2) > 0).sum(dim=1)  # number of objects
    total_num_anchors = outputs.shape[1]
    x_shifts = torch.cat(x_shifts, 1)  # [1, n_anchors_all]
    y_shifts = torch.cat(y_shifts, 1)  # [1, n_anchors_all]
    expanded_strides = torch.cat(expanded_strides, 1)
    if self.use_l1:
        origin_preds = torch.cat(origin_preds, 1)

    cls_targets = []
    reg_targets = []
    l1_targets = []
    obj_targets = []
    fg_masks = []

    num_fg = 0.0
    num_gts = 0.0

    for batch_idx in range(outputs.shape[0]):
        num_gt = int(nlabel[batch_idx])  # number of gts in the current image
        num_gts += num_gt
        if num_gt == 0:  # no gts in this image: create empty target tensors
            cls_target = outputs.new_zeros((0, self.num_classes))   # class targets
            reg_target = outputs.new_zeros((0, 4))                  # box targets
            l1_target = outputs.new_zeros((0, 4))
            obj_target = outputs.new_zeros((total_num_anchors, 1))  # objectness targets
            fg_mask = outputs.new_zeros(total_num_anchors).bool()   # mask of predictions matched to a gt; all False here
        else:
            gt_bboxes_per_image = labels[batch_idx, :num_gt, 1:5]  # gt box coordinates
            gt_classes = labels[batch_idx, :num_gt, 0]             # gt classes
            bboxes_preds_per_image = bbox_preds[batch_idx]         # predicted boxes
            try:
                (
                    gt_matched_classes,
                    fg_mask,
                    pred_ious_this_matching,
                    matched_gt_inds,
                    num_fg_img,
                ) = self.get_assignments(  # noqa
                    batch_idx,
                    num_gt,
                    gt_bboxes_per_image,
                    gt_classes,
                    bboxes_preds_per_image,
                    expanded_strides,
                    x_shifts,
                    y_shifts,
                    cls_preds,
                    obj_preds,
                )
            except RuntimeError as e:
                # TODO: the string might change, consider a better way
                if "CUDA out of memory. " not in str(e):
                    raise  # the RuntimeError might not be caused by CUDA OOM
                logger.error(
                    "OOM RuntimeError is raised due to the huge memory cost during label assignment. \
                       CPU mode is applied in this batch. If you want to avoid this issue, \
                       try to reduce the batch size or image size."
                )
                torch.cuda.empty_cache()  # free cached GPU memory; some docs note this line can be omitted: memory with no live references is released automatically, but the drop only shows up in nvidia-smi after this call
                (
                    gt_matched_classes,
                    fg_mask,
                    pred_ious_this_matching,
                    matched_gt_inds,
                    num_fg_img,
                ) = self.get_assignments(  # noqa
                    batch_idx,
                    num_gt,
                    gt_bboxes_per_image,
                    gt_classes,
                    bboxes_preds_per_image,
                    expanded_strides,
                    x_shifts,
                    y_shifts,
                    cls_preds,
                    obj_preds,
                    "cpu",
                )
                torch.cuda.empty_cache()

            num_fg += num_fg_img
            cls_target = F.one_hot(
                gt_matched_classes.to(torch.int64), self.num_classes
            ) * pred_ious_this_matching.unsqueeze(-1)
            obj_target = fg_mask.unsqueeze(-1)
            reg_target = gt_bboxes_per_image[matched_gt_inds]
            if self.use_l1:
                l1_target = self.get_l1_target(
                    outputs.new_zeros((num_fg_img, 4)),
                    gt_bboxes_per_image[matched_gt_inds],
                    expanded_strides[0][fg_mask],
                    x_shifts=x_shifts[0][fg_mask],
                    y_shifts=y_shifts[0][fg_mask],
                )

        cls_targets.append(cls_target)
        reg_targets.append(reg_target)
        obj_targets.append(obj_target.to(dtype))
        fg_masks.append(fg_mask)
        if self.use_l1:
            l1_targets.append(l1_target)

    cls_targets = torch.cat(cls_targets, 0)
    reg_targets = torch.cat(reg_targets, 0)
    obj_targets = torch.cat(obj_targets, 0)
    fg_masks = torch.cat(fg_masks, 0)
    if self.use_l1:
        l1_targets = torch.cat(l1_targets, 0)

    num_fg = max(num_fg, 1)
    loss_iou = (
        self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)
    ).sum() / num_fg
    loss_obj = (
        self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)
    ).sum() / num_fg
    # loss_obj = (
    #     self.focal_loss(obj_preds.sigmoid().view(-1, 1), obj_targets)
    # ).sum() / num_fg  # cheng: a focal-loss alternative that was tried here
    loss_cls = (
        self.bcewithlog_loss(
            cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets
        )
    ).sum() / num_fg  # cheng: originally self.bcewithlog_loss()
    if self.use_l1:
        loss_l1 = (
            self.l1_loss(origin_preds.view(-1, 4)[fg_masks], l1_targets)
        ).sum() / num_fg
    else:
        loss_l1 = 0.0

    reg_weight = 5.0
    loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1

    return (
        loss,
        reg_weight * loss_iou,
        loss_obj,
        loss_cls,
        loss_l1,
        num_fg / max(num_gts, 1),
    )
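For reference, a minimal sketch of the IoU loss used above (boxes in center/width/height format; the repo's IOUloss also offers a GIoU variant, omitted here):

import torch

def iou_loss(pred, target, eps=1e-7):
    """Plain IoU loss for boxes given as (cx, cy, w, h)."""
    tl = torch.max(pred[:, :2] - pred[:, 2:] / 2, target[:, :2] - target[:, 2:] / 2)
    br = torch.min(pred[:, :2] + pred[:, 2:] / 2, target[:, :2] + target[:, 2:] / 2)
    inter = (br - tl).clamp(min=0).prod(dim=1)
    union = pred[:, 2:].prod(dim=1) + target[:, 2:].prod(dim=1) - inter
    iou = inter / (union + eps)
    return 1 - iou ** 2  # the repo squares the IoU before subtracting from 1

p = torch.tensor([[10.0, 10.0, 8.0, 8.0]])
t = torch.tensor([[11.0, 10.0, 8.0, 8.0]])
print(iou_loss(p, t))  # small loss for a nearly perfect box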
Label assignment:
The concrete steps are as follows:
1、First filter anchors with a geometric constraint (see get_geometry_constraint): around each gt center, draw a square whose half-side is 1.5 strides (side length 3 strides in input-image pixels); anchors whose centers fall inside this square become the first-round positive candidates.
2、Refine the first-round candidates with the SimOTA algorithm (see get_assignments and simota_matching). Concretely, a cost matrix is computed over the candidates. For each gt, the candidates are sorted by IoU and the top ten are taken; these ten IoU values are summed and truncated to an integer, which becomes that gt's dynamic number k of positive candidates (values below 1 are clamped to 1; a small numeric example follows below). The case where a single anchor matches multiple gts is then resolved, yielding the final dynamic positive-sample assignment.
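A small numeric example of the dynamic-k computation described in step 2:

import torch

# For one gt, take the top-10 IoUs of its candidate anchors, sum them,
# and truncate to an integer (clamped to >= 1).
ious = torch.tensor([0.81, 0.75, 0.60, 0.55, 0.40, 0.32, 0.20, 0.11, 0.05, 0.02])
dynamic_k = torch.clamp(ious.sum().int(), min=1)
print(dynamic_k.item())  # sum = 3.81 -> k = 3 positives for this gt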
@torch.no_grad()
def get_assignments(
    self,
    batch_idx,
    num_gt,
    gt_bboxes_per_image,
    gt_classes,
    bboxes_preds_per_image,
    expanded_strides,
    x_shifts,
    y_shifts,
    cls_preds,
    obj_preds,
    mode="gpu",
):
    if mode == "cpu":
        print("-----------Using CPU for the Current Batch-------------")
        gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
        bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
        gt_classes = gt_classes.cpu().float()
        expanded_strides = expanded_strides.cpu().float()
        x_shifts = x_shifts.cpu()
        y_shifts = y_shifts.cpu()

    fg_mask, geometry_relation = self.get_geometry_constraint(
        gt_bboxes_per_image,
        expanded_strides,
        x_shifts,
        y_shifts,
    )  # geometric constraint: keep only anchors whose centers fall inside a square of side 3 strides centered on a gt

    # bboxes_preds_per_image.shape: [n_anchors_all, 4]
    bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
    cls_preds_ = cls_preds[batch_idx][fg_mask]
    obj_preds_ = obj_preds[batch_idx][fg_mask]
    num_in_boxes_anchor = bboxes_preds_per_image.shape[0]

    if mode == "cpu":
        gt_bboxes_per_image = gt_bboxes_per_image.cpu()
        bboxes_preds_per_image = bboxes_preds_per_image.cpu()

    pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)

    gt_cls_per_image = (
        F.one_hot(gt_classes.to(torch.int64), self.num_classes)
        .float()
    )
    pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)

    if mode == "cpu":
        cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.cpu()

    with torch.cuda.amp.autocast(enabled=False):
        cls_preds_ = (
            cls_preds_.float().sigmoid_() * obj_preds_.float().sigmoid_()
        ).sqrt()
        pair_wise_cls_loss = F.binary_cross_entropy(
            cls_preds_.unsqueeze(0).repeat(num_gt, 1, 1),
            gt_cls_per_image.unsqueeze(1).repeat(1, num_in_boxes_anchor, 1),
            reduction="none"
        ).sum(-1)
    del cls_preds_

    cost = (
        pair_wise_cls_loss
        + 3.0 * pair_wise_ious_loss
        + float(1e6) * (~geometry_relation)
    )  # cost over the first-round candidates: classification loss plus weighted IoU loss, with a huge penalty outside the center region

    (
        num_fg,
        gt_matched_classes,
        pred_ious_this_matching,
        matched_gt_inds,
    ) = self.simota_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
    del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss

    if mode == "cpu":
        gt_matched_classes = gt_matched_classes.cuda()
        fg_mask = fg_mask.cuda()
        pred_ious_this_matching = pred_ious_this_matching.cuda()
        matched_gt_inds = matched_gt_inds.cuda()

    return (
        gt_matched_classes,
        fg_mask,
        pred_ious_this_matching,
        matched_gt_inds,
        num_fg,
    )
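A tiny numeric illustration of how the cost matrix is assembled (all values made up, 2 gts and 3 candidate anchors):

import torch

# The 1e6 term makes any anchor outside a gt's center square effectively
# unselectable for that gt.
pair_wise_cls_loss = torch.tensor([[0.5, 1.2, 0.9],
                                   [1.1, 0.4, 0.8]])
pair_wise_ious_loss = torch.tensor([[0.2, 0.7, 1.5],
                                    [1.3, 0.3, 0.6]])
geometry_relation = torch.tensor([[True, True, False],
                                  [False, True, True]])
cost = pair_wise_cls_loss + 3.0 * pair_wise_ious_loss + 1e6 * (~geometry_relation)
print(cost)  # entries outside the center square are ~1e6 and never picked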
Geometric constraint:
def get_geometry_constraint(
    self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts,
):
    """
    Calculate whether the center of an object is located in a fixed range of
    an anchor. This is used to avert inappropriate matching. It can also
    reduce the number of candidate anchors so that the GPU memory is saved.
    """
    expanded_strides_per_image = expanded_strides[0]
    # map each anchor center back to input-image coordinates: add 0.5 to get
    # the cell center, then multiply by the level's stride; both operands have
    # shape [8400], and the result is [1, 8400]
    x_centers_per_image = ((x_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)
    y_centers_per_image = ((y_shifts[0] + 0.5) * expanded_strides_per_image).unsqueeze(0)

    # in fixed center
    center_radius = 1.5
    center_dist = expanded_strides_per_image.unsqueeze(0) * center_radius  # torch.Size([1, 8400])
    # gt_bboxes_per_image holds gt centers plus widths/heights, shape [num_gt, 4];
    # draw a square of side 3 strides around each gt center
    gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0:1]) - center_dist  # [num_gt, num_anchor]
    gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0:1]) + center_dist
    gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1:2]) - center_dist
    gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1:2]) + center_dist

    c_l = x_centers_per_image - gt_bboxes_per_image_l
    c_r = gt_bboxes_per_image_r - x_centers_per_image
    c_t = y_centers_per_image - gt_bboxes_per_image_t
    c_b = gt_bboxes_per_image_b - y_centers_per_image
    center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)     # [num_gt, num_anchor, 4]
    is_in_centers = center_deltas.min(dim=-1).values > 0.0   # [num_gt, num_anchor]
    anchor_filter = is_in_centers.sum(dim=0) > 0  # drop anchors that lie inside no gt's center square; shape [num_anchor]
    geometry_relation = is_in_centers[:, anchor_filter]

    return anchor_filter, geometry_relation
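A quick numeric check of the center mapping and the center_radius test, with made-up coordinates:

# The anchor at grid position (x=2, y=5) on the stride-16 level maps back to
# input-image coordinates ((2+0.5)*16, (5+0.5)*16) = (40, 88). A gt centered
# at (50, 90) keeps it, because both offsets are within 1.5 * 16 = 24 pixels.
x_shift, y_shift, stride = 2.0, 5.0, 16.0
cx, cy = (x_shift + 0.5) * stride, (y_shift + 0.5) * stride
gt_cx, gt_cy, radius = 50.0, 90.0, 1.5 * stride
inside = abs(cx - gt_cx) < radius and abs(cy - gt_cy) < radius
print((cx, cy), inside)  # (40.0, 88.0) True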
SimOTA algorithm:
def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
    # Dynamic K
    # ---------------------------------------------------------------
    matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)

    n_candidate_k = min(10, pair_wise_ious.size(1))
    topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)  # per gt, the (at most) ten anchors with the highest IoU; shape [num_gt, 10]
    dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)  # clamp limits every element to [min, max] and returns a new tensor;
    # summing the ten IoUs and truncating gives each gt's number of positives
    for gt_idx in range(num_gt):
        _, pos_idx = torch.topk(
            cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
        )
        matching_matrix[gt_idx][pos_idx] = 1

    del topk_ious, dynamic_ks, pos_idx

    anchor_matching_gt = matching_matrix.sum(0)
    # deal with the case that one anchor matches multiple ground-truths
    if anchor_matching_gt.max() > 1:
        multiple_match_mask = anchor_matching_gt > 1  # mask of anchors matched to more than one gt
        _, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0)  # for each such anchor, the gt with the minimum cost
        matching_matrix[:, multiple_match_mask] *= 0  # clear all matches for these anchors
        matching_matrix[cost_argmin, multiple_match_mask] = 1  # keep only the minimum-cost gt for each of them
    fg_mask_inboxes = anchor_matching_gt > 0
    num_fg = fg_mask_inboxes.sum().item()

    fg_mask[fg_mask.clone()] = fg_mask_inboxes  # refine the geometry-stage mask in place to the final positives

    matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
    gt_matched_classes = gt_classes[matched_gt_inds]

    pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
        fg_mask_inboxes
    ]
    return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
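A small worked example of the multiple-match resolution above (made-up costs):

import torch

# Anchor 1 is selected by both gts, so only the gt with the lower cost keeps it.
cost = torch.tensor([[1.0, 2.0, 9.0],
                     [9.0, 1.5, 1.0]])          # (num_gt=2, num_anchor=3)
matching_matrix = torch.tensor([[1, 1, 0],
                                [0, 1, 1]], dtype=torch.uint8)
anchor_matching_gt = matching_matrix.sum(0)      # [1, 2, 1]
multiple_match_mask = anchor_matching_gt > 1     # anchor 1 matches two gts
_, cost_argmin = torch.min(cost[:, multiple_match_mask], dim=0)
matching_matrix[:, multiple_match_mask] *= 0
matching_matrix[cost_argmin, multiple_match_mask] = 1
print(matching_matrix)  # anchor 1 now belongs only to gt 1 (cost 1.5 < 2.0)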