YOLOX is concise and efficient. This post walks through its concrete implementation; much of the code can be reused elsewhere and is well worth studying.
Inference is fairly simple, so start with demo.py.
Running it requires three arguments:
--path: path to the test image
--exp_file: the model config file to use, e.g. default/yolox_m.py
--ckpt: the pretrained weights, e.g. yolox_m.pth
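For reference, a typical invocation might look like the following (the paths here are illustrative):

python tools/demo.py image --path assets/dog.jpg --exp_file exps/default/yolox_m.py --ckpt yolox_m.pth --save_result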
outputs, img_info = predictor.inference(image_name)  # output (14, 7): x1, y1, x2, y2, obj_conf, class_conf, class_pred
result_image = predictor.visual(outputs[0], img_info, predictor.confthre)
img = cv2.imread(img)
ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
# proportionally resize the original image to 640*640 (a sketch of this preprocessing follows the code block below)
img, _ = self.preproc(img, None, self.test_size)  # converts to (3, 640, 640)
with torch.no_grad():
outputs = self.model(img) # ([1, 8400, 85]):8400 = 80*80 +40*40 +20*20; 85 = 80+4+1
outputs = postprocess(
outputs, self.num_classes, self.confthre,
self.nmsthre, class_agnostic=True
)
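As a side note, here is a minimal sketch of that preprocessing step, assuming YOLOX's usual constant padding value of 114 (the function name is mine, not from the repo):

import cv2
import numpy as np

def preproc_sketch(img, input_size=(640, 640)):
    # Pad a grey canvas, scale the image so the longer side fits, copy it in.
    padded = np.full((input_size[0], input_size[1], 3), 114, dtype=np.uint8)
    r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
    resized = cv2.resize(
        img, (int(img.shape[1] * r), int(img.shape[0] * r)),
        interpolation=cv2.INTER_LINEAR,
    )
    padded[: resized.shape[0], : resized.shape[1]] = resized
    padded = padded.transpose(2, 0, 1)  # HWC -> CHW: (3, 640, 640)
    return np.ascontiguousarray(padded, dtype=np.float32), r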
fpn_outs = self.backbone(x)
# (128, 80, 80) (256, 40, 40) (512, 20, 20): the three downsampled feature maps
outputs = self.head(fpn_outs)
for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
    zip(self.cls_convs, self.reg_convs, self.strides, xin)
):  # loop 3 times, classifying and regressing one feature map per iteration
    x = self.stems[k](x)  # project the feature map to 128 channels, e.g. feature 1: (1, 128, 80, 80)
    cls_x = x
    reg_x = x

    cls_feat = cls_conv(cls_x)  # decoupled head: two consecutive conv(128, 128, 3, 1) + BN + SiLU
    cls_output = self.cls_preds[k](cls_feat)  # Conv2d(128, 20), classification
    reg_feat = reg_conv(reg_x)  # decoupled head, same as above
    reg_output = self.reg_preds[k](reg_feat)  # Conv2d(128, 4), box regression
    obj_output = self.obj_preds[k](reg_feat)  # Conv2d(128, 1), objectness

    output = torch.cat(
        [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1
    )  # (1, 25, 80, 80): 25 = 4 + 1 + 20
    outputs.append(output)  # (1, 25, 80, 80) (1, 25, 40, 40) (1, 25, 20, 20)

self.hw = [x.shape[-2:] for x in outputs]  # torch.Size([80, 80]), [40, 40], [20, 20]
outputs = torch.cat(
    [x.flatten(start_dim=2) for x in outputs], dim=2
).permute(0, 2, 1)  # (1, 8400, 25)
if self.decode_in_inference:  # True
    return self.decode_outputs(outputs, dtype=xin[0].type())
else:
    return outputs
def decode_outputs(self, outputs, dtype):
    grids = []
    strides = []
    for (hsize, wsize), stride in zip(self.hw, self.strides):  # 80/40/20, matching downsampling strides [8, 16, 32]
        yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])  # e.g. for the (80, 80) map, two (80, 80) coordinate grids
        grid = torch.stack((xv, yv), 2).view(1, -1, 2)  # (1, 6400, 2)
        grids.append(grid)
        shape = grid.shape[:2]  # (1, 6400)
        strides.append(torch.full((*shape, 1), stride))  # (1, 6400, 1)*[8] (1, 1600, 1)*[16] (1, 400, 1)*[32]

    grids = torch.cat(grids, dim=1).type(dtype)
    strides = torch.cat(strides, dim=1).type(dtype)

    outputs[..., :2] = (outputs[..., :2] + grids) * strides  # (predicted x, y + grid cell coordinates) * stride
    outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * strides  # exp(predicted w, h) * stride
    return outputs  # (1, 8400, 85): 8400 = 80*80 + 40*40 + 20*20; 85 = 80 + 4 + 1
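To make the decode step concrete, here is a tiny standalone example, using a hypothetical 2x2 stride-32 map with all-zero raw predictions:

import torch

# Toy decode on a 2x2 stride-32 "feature map", mirroring the formulas above.
hsize, wsize, stride = 2, 2, 32
yv, xv = torch.meshgrid([torch.arange(hsize), torch.arange(wsize)])
grid = torch.stack((xv, yv), 2).view(1, -1, 2).float()  # (1, 4, 2)

preds = torch.zeros(1, 4, 4)                # raw (x, y, w, h) head outputs
xy = (preds[..., :2] + grid) * stride       # e.g. cell (1, 1) -> center (32, 32)
wh = torch.exp(preds[..., 2:4]) * stride    # exp(0) * 32 -> 32x32 boxes
print(xy.squeeze(0), wh.squeeze(0))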
outputs = postprocess(outputs, self.num_classes, self.confthre, self.nmsthre, class_agnostic=True)

def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
    box_corner = prediction.new(prediction.shape)
    # convert to top-left / bottom-right corners: x1 y1 x2 y2
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]

    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):  # image_pred: (8400, 85)
        # If none are remaining => process next image
        if not image_pred.size(0):
            continue
        # Get score and class with highest confidence
        class_conf, class_pred = torch.max(image_pred[:, 5: 5 + num_classes], 1, keepdim=True)
        # filter by class score * objectness, with a 0.3 threshold in this trace
        conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= conf_thre).squeeze()
        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)  # (8400, 7)
        detections = detections[conf_mask]  # (93, 7) after the 0.3 confidence filter
        if class_agnostic:
            nms_out_index = torchvision.ops.nms(
                detections[:, :4],
                detections[:, 4] * detections[:, 5],
                nms_thre,
            )  # NMS (by score and position): returns the indices of the surviving boxes
        else:
            nms_out_index = torchvision.ops.batched_nms(
                detections[:, :4],
                detections[:, 4] * detections[:, 5],
                detections[:, 6],
                nms_thre,
            )  # not taken in this trace

        detections = detections[nms_out_index]  # (14, 7)
        if output[i] is None:
            output[i] = detections
        else:
            output[i] = torch.cat((output[i], detections))

    return output
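The class-agnostic NMS call is easy to check in isolation. A toy example where the score is obj_conf * class_conf, as in postprocess above:

import torch
import torchvision

# Three boxes; box 1 heavily overlaps box 0 and gets suppressed.
boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [50., 50., 60., 60.]])
obj_conf = torch.tensor([0.9, 0.8, 0.7])
cls_conf = torch.tensor([0.9, 0.9, 0.9])
keep = torchvision.ops.nms(boxes, obj_conf * cls_conf, iou_threshold=0.45)
print(keep)  # tensor([0, 2])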
outputs, img_info = predictor.inference(image_name)
result_image = predictor.visual(outputs[0], img_info, predictor.confthre)
if save_result:
save_folder = os.path.join(
vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
)
os.makedirs(save_folder, exist_ok=True)
save_file_name = os.path.join(save_folder, os.path.basename(image_name))
logger.info("Saving detection result in {}".format(save_file_name))
cv2.imwrite(save_file_name, result_image)
ch = cv2.waitKey(0)
def visual(self, output, img_info, cls_conf=0.35):
    ratio = img_info["ratio"]  # resize ratio: 0.45
    img = img_info["raw_img"]  # (1050, 1400, 3)
    if output is None:
        return img
    output = output.cpu()

    bboxes = output[:, 0:4]
    # preprocessing: resize
    bboxes /= ratio  # map the boxes back to the original image scale
    cls = output[:, 6]
    scores = output[:, 4] * output[:, 5]

    vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
    return vis_res
def vis(img, boxes, scores, cls_ids, conf=0.5, class_names=None):
    for i in range(len(boxes)):
        box = boxes[i]
        cls_id = int(cls_ids[i])
        score = scores[i]
        if score < conf:
            continue
        x0 = int(box[0])
        y0 = int(box[1])
        x1 = int(box[2])
        y1 = int(box[3])

        color = (_COLORS[cls_id] * 255).astype(np.uint8).tolist()
        text = '{}:{:.1f}%'.format(class_names[cls_id], score * 100)
        txt_color = (0, 0, 0) if np.mean(_COLORS[cls_id]) > 0.5 else (255, 255, 255)
        font = cv2.FONT_HERSHEY_SIMPLEX

        txt_size = cv2.getTextSize(text, font, 0.4, 1)[0]
        cv2.rectangle(img, (x0, y0), (x1, y1), color, 2)

        txt_bk_color = (_COLORS[cls_id] * 255 * 0.7).astype(np.uint8).tolist()
        cv2.rectangle(
            img,
            (x0, y0 + 1),
            (x0 + txt_size[0] + 1, y0 + int(1.5 * txt_size[1])),
            txt_bk_color,
            -1
        )
        cv2.putText(img, text, (x0, y0 + txt_size[1]), font, 0.4, txt_color, thickness=1)

    return img
Training data format: the datasets/VOCdevkit/VOC2007/ folder holds three subfolders: JPEGImages (the jpg images), Annotations (the matching xml labels), and ImageSets.
Training enters trainer.train() at line 110 of train.py.
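A typical layout would be the following (ImageSets/Main holding the train/val split lists is the standard VOC convention):

datasets/VOCdevkit/VOC2007/
├── JPEGImages/      # the .jpg images
├── Annotations/     # the matching .xml labels
└── ImageSets/Main/  # train.txt / val.txt split files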
yolox.py line 30:
fpn_outs = self.backbone(x)

if self.training:
    assert targets is not None
    loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(
        fpn_outs, targets, x
    )  # IoU, class, and objectness (and optional L1) losses
    outputs = {
        "total_loss": loss,
        "iou_loss": iou_loss,
        "l1_loss": l1_loss,
        "conf_loss": conf_loss,
        "cls_loss": cls_loss,
        "num_fg": num_fg,
    }
else:
    outputs = self.head(fpn_outs)

return outputs
The key function is self.get_assignments, which assigns the positive labels; it is analyzed in detail below,
along with its helper self.dynamic_k_matching, which dynamically picks k positive samples per gt.
class YOLOXHead(nn.Module):
    def get_losses(self, imgs, x_shifts, y_shifts, expanded_strides, labels,
                   outputs, origin_preds, dtype):
        bbox_preds = outputs[:, :, :4]  # [bs, n_anchors, 4]: (8, 8400, 4)
        obj_preds = outputs[:, :, 4].unsqueeze(-1)  # (8, 8400, 1)
        cls_preds = outputs[:, :, 5:]  # (8, 8400, 20)

        # calculate targets
        nlabel = (labels.sum(dim=2) > 0).sum(dim=1)  # gt count per image: [5, 6, 21, 2, 5, 2, 2, 6]
        total_num_anchors = outputs.shape[1]  # 8400
        x_shifts = torch.cat(x_shifts, 1)  # [1, n_anchors_all]; x_shifts[0]: (1, 6400), [1]: (1, 1600), [2]: (1, 400): [0, 1, 2, ..., 19, 0, 1, 2, ...]
        y_shifts = torch.cat(y_shifts, 1)  # [1, n_anchors_all]: (1, 8400)
        expanded_strides = torch.cat(expanded_strides, 1)  # (1, 8400): 6400*[8, 8, ...] 1600*[16, 16, ...] 400*[32, 32, ...]
        if self.use_l1:
            origin_preds = torch.cat(origin_preds, 1)

        cls_targets = []
        reg_targets = []
        l1_targets = []
        obj_targets = []
        fg_masks = []

        num_fg = 0.0
        num_gts = 0.0

        for batch_idx in range(outputs.shape[0]):  # batch size
            num_gt = int(nlabel[batch_idx])
            num_gts += num_gt  # 5
            if num_gt == 0:
                cls_target = outputs.new_zeros((0, self.num_classes))
                reg_target = outputs.new_zeros((0, 4))
                l1_target = outputs.new_zeros((0, 4))
                obj_target = outputs.new_zeros((total_num_anchors, 1))
                fg_mask = outputs.new_zeros(total_num_anchors).bool()
            else:
                gt_bboxes_per_image = labels[batch_idx, :num_gt, 1:5]  # (num_gt, 4)
                gt_classes = labels[batch_idx, :num_gt, 0]  # (num_gt,)
                bboxes_preds_per_image = bbox_preds[batch_idx]  # (8400, 4)

                # the original wraps this call in try/except to fall back to
                # CPU mode on CUDA OOM; that except branch is omitted here
                (
                    gt_matched_classes,
                    fg_mask,
                    pred_ious_this_matching,
                    matched_gt_inds,
                    num_fg_img,
                ) = self.get_assignments(
                    batch_idx, num_gt, total_num_anchors, gt_bboxes_per_image,
                    gt_classes, bboxes_preds_per_image, expanded_strides,
                    x_shifts, y_shifts, cls_preds, bbox_preds, obj_preds,
                    labels, imgs,
                )
                # the call above assigns positive/negative samples; see the end
                # of section 3.1 (self.get_assignments) for its return values

                torch.cuda.empty_cache()
                num_fg += num_fg_img  # 34

                cls_target = F.one_hot(
                    gt_matched_classes.to(torch.int64), self.num_classes
                ) * pred_ious_this_matching.unsqueeze(-1)  # (34,) -> (34, 20) * iou_score
                obj_target = fg_mask.unsqueeze(-1)  # (8400, 1): 34 * True
                reg_target = gt_bboxes_per_image[matched_gt_inds]  # (34, 4)

            cls_targets.append(cls_target)
            reg_targets.append(reg_target)
            obj_targets.append(obj_target.to(dtype))
            fg_masks.append(fg_mask)
            if self.use_l1:  # False
                l1_targets.append(l1_target)

        cls_targets = torch.cat(cls_targets, 0)  # (385, 20)
        reg_targets = torch.cat(reg_targets, 0)  # (385, 4)
        obj_targets = torch.cat(obj_targets, 0)  # (67200, 1): 8400 * 8 = 67200
        fg_masks = torch.cat(fg_masks, 0)  # (67200,)
        if self.use_l1:
            l1_targets = torch.cat(l1_targets, 0)

        num_fg = max(num_fg, 1)
        loss_iou = (
            self.iou_loss(bbox_preds.view(-1, 4)[fg_masks], reg_targets)
        ).sum() / num_fg
        loss_obj = (
            self.bcewithlog_loss(obj_preds.view(-1, 1), obj_targets)
        ).sum() / num_fg
        loss_cls = (
            self.bcewithlog_loss(
                cls_preds.view(-1, self.num_classes)[fg_masks], cls_targets
            )
        ).sum() / num_fg
        if self.use_l1:
            loss_l1 = (
                self.l1_loss(origin_preds.view(-1, 4)[fg_masks], l1_targets)
            ).sum() / num_fg
        else:
            loss_l1 = 0.0

        reg_weight = 5.0
        loss = reg_weight * loss_iou + loss_obj + loss_cls + loss_l1

        return (
            loss,
            reg_weight * loss_iou,
            loss_obj,
            loss_cls,
            loss_l1,
            num_fg / max(num_gts, 1),
        )
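The cls_target construction above (a one-hot class vector scaled by the matched IoU) is worth seeing in isolation. A toy version with 3 classes instead of 20 and made-up IoUs:

import torch
import torch.nn.functional as F

gt_matched_classes = torch.tensor([2, 0])
pred_ious_this_matching = torch.tensor([0.85, 0.60])
cls_target = F.one_hot(gt_matched_classes, 3) * pred_ious_this_matching.unsqueeze(-1)
print(cls_target)
# tensor([[0.0000, 0.0000, 0.8500],
#         [0.6000, 0.0000, 0.0000]])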
Here the ground-truth boxes are assigned across the three feature maps (8400 grid points in total), splitting the anchors into positive and negative samples.
def get_assignments(
    self, batch_idx, num_gt, total_num_anchors, gt_bboxes_per_image,
    gt_classes, bboxes_preds_per_image, expanded_strides, x_shifts,
    y_shifts, cls_preds, bbox_preds, obj_preds, labels, imgs, mode="gpu",
):
    fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
        gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts,
        total_num_anchors, num_gt,
    )  # (8400,): 3473 * [True]; (5, 3473): 325 * [True]

    bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]  # (8400, 4) -> (3473, 4)
    cls_preds_ = cls_preds[batch_idx][fg_mask]  # (3473, 20)
    obj_preds_ = obj_preds[batch_idx][fg_mask]  # (3473, 1)
    num_in_boxes_anchor = bboxes_preds_per_image.shape[0]  # 3473

    pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)  # (5, 4) & (3473, 4) -> (5, 3473)

    gt_cls_per_image = (
        F.one_hot(gt_classes.to(torch.int64), self.num_classes)
        .float()
        .unsqueeze(1)
        .repeat(1, num_in_boxes_anchor, 1)
    )  # (5,) -> (5, 20) -> (5, 3473, 20)
    pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)  # (5, 3473)

    with torch.cuda.amp.autocast(enabled=False):
        cls_preds_ = (
            cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
            * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
        )  # (3473, 20) -> sigmoid -> (5, 3473, 20)
        pair_wise_cls_loss = F.binary_cross_entropy(
            cls_preds_.sqrt_(), gt_cls_per_image, reduction="none"
        ).sum(-1)  # (5, 3473, 20) & (5, 3473, 20) -> (5, 3473)
    del cls_preds_

    cost = (
        pair_wise_cls_loss
        + 3.0 * pair_wise_ious_loss
        + 100000.0 * (~is_in_boxes_and_center)
    )  # (5, 3473)

    (
        num_fg,
        gt_matched_classes,
        pred_ious_this_matching,
        matched_gt_inds,
    ) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
    del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss

    return (
        gt_matched_classes,  # (34,) classes of the 34 positive samples
        fg_mask,  # 34 True values out of 8400
        pred_ious_this_matching,  # (34,) IoUs of the 34 positive samples
        matched_gt_inds,  # (34,) which gt each positive sample matches best
        num_fg,
    )
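A toy version of the cost matrix (2 gts x 3 candidate anchors, made-up values) shows how anchors outside both the gt box and the center prior are priced out of the matching:

import torch

pair_wise_cls_loss = torch.tensor([[0.2, 0.5, 0.9],
                                   [0.7, 0.3, 0.4]])
pair_wise_ious = torch.tensor([[0.8, 0.5, 0.1],
                               [0.2, 0.7, 0.6]])
is_in_boxes_and_center = torch.tensor([[True, True, False],
                                       [False, True, True]])
cost = (pair_wise_cls_loss
        - 3.0 * torch.log(pair_wise_ious + 1e-8)   # = 3.0 * pair_wise_ious_loss
        + 100000.0 * (~is_in_boxes_and_center))
print(cost)  # masked-out entries are ~1e5, so dynamic-k essentially never picks them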
First, a coarse filter over the 8400 predicted anchor points:
Using the offsets between each anchor center and the gt box's top-left/bottom-right corners, keep the anchors whose offsets are all greater than 0 (computed via the b_l, b_t, b_r, b_b positions; c_l, c_t, c_r, c_b work the same way for the fixed center region).
def get_in_boxes_info(self, gt_bboxes_per_image, expanded_strides, x_shifts,
                      y_shifts, total_num_anchors, num_gt):
    expanded_strides_per_image = expanded_strides[0]  # (8400,)
    x_shifts_per_image = x_shifts[0] * expanded_strides_per_image  # (8400,): [0, 1, 2, ..., 79, ..., 0, 1, 2, ..., 39, 0, 1, 2, ..., 19] * stride
    y_shifts_per_image = y_shifts[0] * expanded_strides_per_image
    x_centers_per_image = (
        (x_shifts_per_image + 0.5 * expanded_strides_per_image)
        .unsqueeze(0)
        .repeat(num_gt, 1)
    )  # [n_anchor] -> [n_gt, n_anchor]: (5, 8400), 8400 center coordinates (absolute, on the 640*640 image)
    y_centers_per_image = (
        (y_shifts_per_image + 0.5 * expanded_strides_per_image)
        .unsqueeze(0)
        .repeat(num_gt, 1)
    )

    gt_bboxes_per_image_l = (
        (gt_bboxes_per_image[:, 0] - 0.5 * gt_bboxes_per_image[:, 2])
        .unsqueeze(1)
        .repeat(1, total_num_anchors)
    )  # (5, 8400) x1
    gt_bboxes_per_image_r = (
        (gt_bboxes_per_image[:, 0] + 0.5 * gt_bboxes_per_image[:, 2])
        .unsqueeze(1)
        .repeat(1, total_num_anchors)
    )  # (5, 8400) x2
    gt_bboxes_per_image_t = (
        (gt_bboxes_per_image[:, 1] - 0.5 * gt_bboxes_per_image[:, 3])
        .unsqueeze(1)
        .repeat(1, total_num_anchors)
    )  # (5, 8400) y1
    gt_bboxes_per_image_b = (
        (gt_bboxes_per_image[:, 1] + 0.5 * gt_bboxes_per_image[:, 3])
        .unsqueeze(1)
        .repeat(1, total_num_anchors)
    )  # (5, 8400) y2

    b_l = x_centers_per_image - gt_bboxes_per_image_l  # (5, 8400)
    b_r = gt_bboxes_per_image_r - x_centers_per_image
    b_t = y_centers_per_image - gt_bboxes_per_image_t
    b_b = gt_bboxes_per_image_b - y_centers_per_image
    bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)  # (5, 8400, 4): offsets between anchor centers and gt box edges

    is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0  # (5, 8400)
    is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
    # in fixed center
    center_radius = 2.5

    gt_bboxes_per_image_l = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
        1, total_num_anchors  # (5, 1) -> (5, 8400)
    ) - center_radius * expanded_strides_per_image.unsqueeze(0)
    gt_bboxes_per_image_r = (gt_bboxes_per_image[:, 0]).unsqueeze(1).repeat(
        1, total_num_anchors
    ) + center_radius * expanded_strides_per_image.unsqueeze(0)
    gt_bboxes_per_image_t = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
        1, total_num_anchors
    ) - center_radius * expanded_strides_per_image.unsqueeze(0)
    gt_bboxes_per_image_b = (gt_bboxes_per_image[:, 1]).unsqueeze(1).repeat(
        1, total_num_anchors
    ) + center_radius * expanded_strides_per_image.unsqueeze(0)

    c_l = x_centers_per_image - gt_bboxes_per_image_l
    c_r = gt_bboxes_per_image_r - x_centers_per_image
    c_t = y_centers_per_image - gt_bboxes_per_image_t
    c_b = gt_bboxes_per_image_b - y_centers_per_image
    center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)  # (5, 8400, 4)
    is_in_centers = center_deltas.min(dim=-1).values > 0.0
    is_in_centers_all = is_in_centers.sum(dim=0) > 0

    # in boxes and in centers
    is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all  # (8400,): 3473 * [True]

    is_in_boxes_and_center = (
        is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
    )  # (5, 3473): 325 * [True]
    return is_in_boxes_anchor, is_in_boxes_and_center
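A minimal check of the "anchor center inside the gt box" test, with one gt box and three hand-picked anchor centers:

import torch

centers_x = torch.tensor([4., 20., 36.])
centers_y = torch.tensor([4., 20., 36.])
gt = torch.tensor([20., 20., 16., 16.])  # cx, cy, w, h -> box [12, 12, 28, 28]

b_l = centers_x - (gt[0] - gt[2] / 2)
b_r = (gt[0] + gt[2] / 2) - centers_x
b_t = centers_y - (gt[1] - gt[3] / 2)
b_b = (gt[1] + gt[3] / 2) - centers_y
deltas = torch.stack([b_l, b_t, b_r, b_b], 1)
print(deltas.min(dim=1).values > 0.0)  # tensor([False,  True, False]): only the middle center lies inside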
Dynamically select k samples per gt based on IoU.
For example: 34 samples are assigned to the 5 gts, and their best IoU scores are returned (pred_ious_this_matching).
def dynamic_k_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
    # Dynamic K
    # ---------------------------------------------------------------
    matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)  # (5, 3473)

    ious_in_boxes_matrix = pair_wise_ious
    n_candidate_k = min(10, ious_in_boxes_matrix.size(1))  # 10
    topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)  # (5, 10)
    dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
    dynamic_ks = dynamic_ks.tolist()  # [3, 7, 9, 9, 6]
    for gt_idx in range(num_gt):
        _, pos_idx = torch.topk(
            cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
        )  # lowest-cost k anchors out of 3473, e.g. pos_idx: [3236, 3235, 3237]
        matching_matrix[gt_idx][pos_idx] = 1
        # each row (gt) of the zero matrix matching_matrix (5, 3473) now has [3, 7, 9, 9, 6] ones

    del topk_ious, dynamic_ks, pos_idx

    anchor_matching_gt = matching_matrix.sum(0)  # (3473,)
    if (anchor_matching_gt > 1).sum() > 0:
        # an anchor claimed by several gts is given to the lowest-cost gt only
        _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
        matching_matrix[:, anchor_matching_gt > 1] *= 0
        matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
    fg_mask_inboxes = matching_matrix.sum(0) > 0  # (3473,): 34 * [True]
    num_fg = fg_mask_inboxes.sum().item()  # 34

    fg_mask[fg_mask.clone()] = fg_mask_inboxes  # updates fg_mask in place: 34 True out of 8400

    matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
    # (5, 3473) -> (5, 34).argmax -> (34,)
    # [4, 4, 2, 4, 4, 4, 3, 3, 3, 3, 1, 1, 1, 1, 0, 0, 0, 2, 2, 2, 2, 2, 3, 2, 4, 3, 3, 2, 2, 1, 3, 1, 3, 1]
    gt_matched_classes = gt_classes[matched_gt_inds]
    # (34,): [14., 14., 14., 14., 14., 14., 14., 14., 14., 14., 8., 8., 8., 8., 11., 11., 11., 14., 14., 14., 14., 14., 14., 14., 14., 14., 14., 14., 14., 8., 14., 8., 14., 8.]
    pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
        fg_mask_inboxes
    ]  # (34,) scores

    return (
        num_fg,  # 34
        gt_matched_classes,  # (34,) classes of the 34 positive samples
        pred_ious_this_matching,  # (34,) IoUs of the 34 positive samples
        matched_gt_inds,  # (34,) which gt each positive sample matches best
    )
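The dynamic-k rule itself is tiny: each gt's k is the clamped integer sum of its top-10 candidate IoUs. With made-up IoUs for 2 gts and 5 candidates:

import torch

pair_wise_ious = torch.tensor([
    [0.8, 0.7, 0.6, 0.1, 0.1],   # sums to 2.3  -> k = 2
    [0.9, 0.9, 0.9, 0.9, 0.05],  # sums to 3.65 -> k = 3
])
n_candidate_k = min(10, pair_wise_ious.size(1))
topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
print(dynamic_ks.tolist())  # [2, 3]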
outputs = self.model(inps, targets)  # in training mode the forward pass returns the loss dict
loss = outputs["total_loss"]
self.optimizer.zero_grad()
self.scaler.scale(loss).backward()  # AMP: scale the loss to avoid fp16 gradient underflow
self.scaler.step(self.optimizer)    # unscale the gradients and step the optimizer
self.scaler.update()                # adjust the scale factor for the next iteration