当前位置:   article > 正文

YOLOV6-训练流程

yolov6

目录

一. 网络的输入与输出

1. 网络输入

2.网络输出

二、损失函数介绍

1. 前期准备(数据处理)

2. 正样本初步筛选(粗筛)

 三、损失函数计算

四、网络的输出解析


 该部分为本人结合代码对YOLOV6的损失函数以及后处理过程进行整理,并结合一些前辈的博客而得到的,只是为了更为清晰的解释YOLOV6的处理流程。

一. 网络的输入与输出

1. 网络输入

输入img: shape为[b,3,640,640]

输入targets是List格式,其中每个img的label以一个Tensor形式保存,shape为[N,5],N为该图像的GT-Box数目,5为5个维度[xmin,ymin,xmax,ymax,cls_index]。

2.网络输出

网络输出为三个分支各自整合的结果,存储于一个List中,shape分别为[b,11,80,80],[b,11,40,40],[b,11,20,20],其中11表示特征图每个像素点对应的anchor个数为1个,11=(4+1+6)x 1,即11=((中心点横坐标偏移,中心点纵坐标偏移,宽度缩放,高度缩放)+前景背景obj+类别数目)* anchor 个数(每个像素点一个anchor)

二、损失函数介绍

计算损失函数的前提是需要有目标targets,和预测值Preds,而对于预测值Preds的box、cls等的损失计算是需要提取出一定个数的正样本的,故计算损失函数之前的一个重要工作就是正样本的筛选

1. 前期准备(数据处理)

  1. def forward(self,outputs,targets_list):
  2. '''
  3. outputs: 三个分支的网络输出,shape分别为:[b,11,80,80],[b,11,40,40],[b,11,20,20]
  4. targets: 数据的标注值,batch中每张图像为一组[cx,cy,w,h,cls],其中[:4]为归一化后的值
  5. '''
  6. self.dtype = outputs[0].type()
  7. self.device = outputs[0].device
  8. # ------- 对输出数据进行解析 -------- #
  9. outputs, outputs_origin, gt_boxes_scale, xy_shifts, expanded_strides =
  10. self.__get_outputs_and_grids(outputs)

分别对三个支路的输出结果进行预处理,网络预测值为特征图像素中心点的偏移值(相对于像素左上角坐标)、宽高的缩放值Obj(前景背景)的预测值Cls(如6个类别的预测值),这里有个前提即:每个像素点只有一个anchor(与yolov3\v4\v5的3个anchor不同),这里对于中心点、宽高的值是需要还原到输入图像大小的,故下面要做的就是这件事。

  1. def __get_outputs_and_grids(self,outputs):
  2. '''对网络输出数据进行解析
  3. 三个支路的输出值shape分别为:[b,11,h,w]
  4. '''
  5. # 存放输出特征图每个像素点的横纵坐标(像素点左上角),用来与预测偏移量结合得到最终的预测结果
  6. # shape:[b,80*80+40*40+20*20,2]->[b,8400,2]
  7. xy_shifts = []
  8. # 存放输出特征图每个像素点相对于原图的缩放值,用来将预测结果映射回原图像
  9. # shape:[b,8400,1]
  10. expanded_strides = []
  11. # 存放输出特征图每个像素点预测结果映射回原图像的结果
  12. # shape:[b,8400*anchor_num,11]->[b,8400,11]
  13. outputs_new = []
  14. # 存放输出特征图每个像素点的预测结果
  15. # shape:[b,8400*anchor_num,11]->[b,8400,11]
  16. outputs_origin = []
  17. for k,output in enumerate(outputs):
  18. '''
  19. 依次对每个输出支路的结果进行处理
  20. output: shape-[b,cls+5,h,w],其为网络预测值,
  21. 前2个即output[:,:2,:,:]与对应像素点坐标相加后再通过output[:,2:4,:,:]
  22. 对宽高进行缩放后,并*stride后映射回原图
  23. output_origin: shape-(b,cls+5,w*h),网络预测值,没有与xy_shifts相加放缩并*stride,即网络原始输出值
  24. grid: 特征图每个像素点的坐标值(x,y),shape-(b,w*h,2),与xy_shifts一致
  25. feat_w,feat_h:特征图的宽高(80,80),(40,40)(20,20)
  26. '''
  27. output,output_origin,grid,fh,fw = self.__decode_output(output,k)
  28. # 该特征图的像素点坐标值shape-(b,fw*fh,2)
  29. xy_shift = grid
  30. # 该特征图的每个像素位置相对于输入图像的缩放值-shape-[1,fw*fh,1]
  31. expanded_stride = torch.full((1,grid.shape[1],1),self.strides[k],dtype=grid.dtype,device=grid.device)
  32. # 记录每个特征图像素点偏移值
  33. xy_shifts.append(xy_shift)
  34. # 记录每个特征图像素点缩放值
  35. expanded_strides.append(expanded_stride)
  36. # 记录每个输入特征图的更新后的预测值
  37. outputs_new.append(output)
  38. # 记录每个输入特征图的原始预测值
  39. outputs_origin.append(output_origin)
  40. # 三个特征图像素点坐标进行合并 shape:[b,8400,2]
  41. xy_shifts = torch.cat(xy_shifts,dim=1)
  42. # 三个特征图像素点缩放值进行合并 shape:[b,8400,1]
  43. expanded_strides = torch.cat(expanded_strides,dim=1)
  44. # 三个特征图最终的预测值进行合并(更新后) shape-[b,8400,11]
  45. outputs = torch.cat(outputs_new,dim=1)
  46. # 三个特征图输出预测值进行合并(更新前) shape-[b,8400,11]
  47. outputs_origin = torch.cat(outputs_origin,dim=1)
  48. # 输入图像尺寸,特征图尺寸*缩放值
  49. fh *= self.strides[-1]
  50. fw *= self.strides[-1]
  51. # shape-[1,4],用于与标注值映射回原图像,标注值[cx,cy,w,h,cls]前四个值为归一化之后的结 果
  52. gt_boxes_scale = torch.Tensor([fw,fh,fw,fh]).type_as(outputs)
  53. return outputs,outputs_origin,gt_boxes_scale,xy_shifts,expanded_strides
  1. def __decode_output(self,output,k):
  2. '''
  3. output解码
  4. '''
  5. bs = output.shape[0] # batch_size
  6. c = output.shape[1] # (cls+5)*anchor_num
  7. fh,fw = output.shape[2:4] # 特征图的高和宽
  8. # shape-[b,(5+cls)*anchor_num,h,w]->[b,anchor_num,5+cls,h,w]->[b,anchor_num,h,w,cls+5]
  9. # 此处:anchor_num = 1 ,cls=6, shape-[b,1,h,w,11]
  10. output = output.view(bs,self.n_anchors,c//self.n_anchors,fh,fw).permute(0,1,3,4,2).contiguous()
  11. # 获取特征图每个像素点对应的横纵坐标值,yv,xv的shape均为(h,w)
  12. yv,xv = torch.meshgrid([torch.arange(fh),torch.arange(fw)])
  13. # 横纵坐标进行组合,扩展一个新维度dim=2,并进行合并
  14. # shape:(h,w,2)->(1,1,h,w,2)
  15. grid = torch.stack((xv,yv),2).view(1,1,fh,fw,2).type(self.dtype).to(self.device)
  16. # 网络预测值reshape
  17. # output reshape - [b,1,h,w,11]->[b,1*h*w,11]
  18. output = output.view(bs,self.n_anchors*fh*fw,-1)
  19. # 网络预测值备份
  20. # [b,1*h*w,11]
  21. output_origin = output.clone()
  22. # grid reshape-[1,1,h,w,2] -> [1,1*h*w,2]
  23. grid = grid.view(1,-1,2)
  24. # 将每个像素点的网络预测(中心点)与对应的grid坐标值相加
  25. # [b,1*h*w,11]
  26. output[...,:2] = (output[...,:2] + grid)
  27. # 同理将宽高预测值进行放缩
  28. output[...,2:4] = torch.exp(output[...,2:4])
  29. # 将平移放缩后的结果恢复到输入图像大小
  30. output[...,:4] = output[...,:4] * self.strides[k]
  31. return output,output_origin,grid,fh,fw

引出问题:

为什么既要保留经过偏移缩放等处理的output值,也要同时保留网络输出的原始的预测值output_origin?

这个问题非常好,之所以同时保留两个结果,是因为在最后计算关于box的损失时,会同时计算box的iou_loss(Ciou Loss 或 Siou Loss)和 box的回归损失如L1 Loss

  1. loss_iou += (self.iou_loss(box_preds.view(-1, 4)[fg_masks].T, reg_targets)).sum() / num_fg
  2. loss_l1 += (self.l1_loss(box_preds_org.view(-1, 4)[fg_masks], l1_targets)).sum() / num_fg

这与YOLOv3,v4,v5是不同的,这三种算法对于box的损失只是采用一种,如yolov3的L1 Loss,以及yolov4,yolov5的Ciou Loss;

本方法采用iou loss和L1 Loss相结合,相当于时双重保障。

代码继续:

  1. # ------- 对输出数据进行解析 -------- #
  2. outputs, outputs_origin, gt_boxes_scale, xy_shifts, expanded_strides = self.__get_outputs_and_grids(outputs)
  3. #
  4. # 三个特征图的像素点总个数(像素点相当于anchor)
  5. total_num_anchors = outputs.shape[1] #8400
  6. # 解析后的预测框box,shape-[b,8400,4]
  7. box_preds = outputs[:,:,:4]
  8. # 网络预测值(偏移缩放值),shape-[b,8400,4]
  9. box_preds_org = outputs_origin[:,:,:4]
  10. # 预测的前景背景obj值,shape-[b,8400,1]
  11. obj_preds = outputs[:,:,4].unsqueeze(-1)
  12. # 预测的cls值,shape-[b,8400,cls_num]
  13. cls_preds = outputs[:,:,5:]
  14. # -------- 对targets进行解析 -------- #
  15. num_fg = 0
  16. cls_targets, reg_targets, l1_targets, obj_targets, fg_masks = [], [], [], [], []
  17. loss_cls, loss_obj, loss_iou, loss_l1 = torch.zeros(1, device=self.device), torch.zeros(1, device=self.device), \
  18. torch.zeros(1, device=self.device), torch.zeros(1, device=self.device)
  19. batch_size = box_preds.shape[0]
  20. for batch_idx in range(batch_size):
  21. # 获取每张图像的标注信息[cx,cy,w,h,cls]
  22. targets = targets_list[batch_idx]
  23. num_gt = targets.shape[0]
  24. # 每张图像的gtbox并恢复到原图中,shape:[n,4]
  25. gt_boxes_per_image = targets[:,:4].mul_(gt_boxes_scale)
  26. # 每张图像的gtbox类别,shape:[n]
  27. gt_classes = targets[:,4]
  28. # 每张图像的预测box信息,shape-[8400,4]
  29. box_preds_per_image = box_preds[batch_idx]
  30. # 每张图像的类别预测信息,shape-[8400,6]
  31. cls_preds_per_image = cls_preds[batch_idx]
  32. # 每张图像的obj预测信息,shape-[8400,1]
  33. obj_preds_per_image = obj_preds[batch_idx]
  34. # -------------- !!!!!!!! --------------- #
  35. # 正样本筛选 粗筛+SimOTA精筛
  36. (gt_matched_classes,
  37. fg_mask,
  38. pred_ious_this_matching,
  39. match_gt_inds,
  40. num_fg_img
  41. ) = self.__get_assignments(
  42. gt_boxes_per_image,
  43. expanded_strides,
  44. xy_shifts,
  45. total_num_anchors,
  46. num_gt,
  47. gt_classes,
  48. box_preds_per_image,
  49. cls_preds_per_image,
  50. obj_preds_per_image
  51. )

如上面代码所示:

  1. # 获取每张图像的标注信息[cx,cy,w,h,cls]
  2. targets = targets_list[batch_idx]
  3. num_gt = targets.shape[0]
  4. # 每张图像的gtbox并恢复到原图中,shape:[n,4]
  5. gt_boxes_per_image = targets[:,:4].mul_(gt_boxes_scale)
  6. # 每张图像的gtbox类别,shape:[n]
  7. gt_classes = targets[:,4]
  8. # 每张图像的预测box信息,shape-[8400,4]
  9. box_preds_per_image = box_preds[batch_idx]
  10. # 每张图像的类别预测信息,shape-[8400,6]
  11. cls_preds_per_image = cls_preds[batch_idx]
  12. # 每张图像的obj预测信息,shape-[8400,1]
  13. obj_preds_per_image = obj_preds[batch_idx]

获得了每张图像的标注信息gt_boxes_per_imagegt_classes,也获得了每张图像的预测信息:box_preds_per_imagecls_preds_per_imageobj_preds_per_image,但是这三个预测信息是针对于该图像(1*3*640*640)网络输出三个支路的8400个像素点的,对于box的预测、cls分类预测等这8400个像素点中是存在大量的负样本的,只有少数的正样本是有用的,所以是不能直接用来计算损失的,故需要采用一定的方式获取到8400个像素点中,正样本的mask,再通过mask对最终的损失Loss进行过滤

引出问题:

YOLOV6采用什么样的方式进行正样本的提取?与YOLOv3、YOLOv4、YOLOv5有什么不同?

(1) 首先,对于YOLOv3、v4、v5三种方法,对于正样本的提取规则是一致的,即将每个gtbox与该gtbox中心点所处像素点对应的三个anchors计算iou,判断其iou值是否超过设定的iou threshold阈值,如果超过阈值则该将对应的anchor设为正样本。

这样做是有一定弊端的:每个GT_Box的正样本的选取只限制在了该gtbox的中心点所在像素点,最终最多只有1个正样本,而其相邻的像素点的anchors与其iou也可能会有一个比较高的交并比,这样做会有一些本可以作为正样本的anchors并强制作为了负样本,最终在将测过程中可能会造成检测结果的疏漏。

(2)YOLOv6则采用了类似于CenterNet的方式,只能说是有那么一点像,但是计算流程更简单高效,即将落在gtbox内或落在以gtbox中心点为中心,以特征图stride值的2.5*2倍为边长的正方形区域内,则为正样本,这样获取的正样本只是粗略筛选,然后通过计算每个粗筛正样本的损失(iou loss + cls loss),通过一定策略进行自动筛选。

这样就相当于对正样本进行自动筛选,去除掉了人工的干预,且每个gtbox获得的正样本数目也是不固定的。

2. 正样本初步筛选(粗筛)

  1. def __get_assignments(self,
  2. gt_boxes_per_image,
  3. expanded_strides,
  4. xy_shifts,
  5. total_num_anchors,
  6. num_gt,
  7. gt_classes,
  8. box_preds_per_image,
  9. cls_preds_per_image,
  10. obj_preds_per_image
  11. ):
  12. '''
  13. gt_boxes_per_image: bacth内每幅图像的GT_box;
  14. expanded_strides: 每幅图像三个分支相对输入图像的缩放值的合并结果,shape[1,8400,1]
  15. 8400 = 80*80(stride=8) + 40*40(stride=16) + 20*20(stride=32)
  16. xy_shifts: 8400个像素点,每个像素点的坐标值(1,8400,2)
  17. total_num_anchors: 总的anchors数目,8400个像素点,每个像素点anchors个数为1,故为8400
  18. num_gt: 正样本GT_Box的数目,例num_gt=12
  19. gt_classes: 每个gt_box的对应类别,如12个gt_box标注框的标注类别
  20. box_preds_per_image: 每幅图像所有像素点(8400个像素点)的box预测,shape=[8400,4]
  21. cls_preds_per_image: 每幅图像所有像素点(8400个像素点)的cls预测,shape=[8400,6]
  22. cls_preds_per_image: 每幅图像所有像素点(8400个像素点)的box预测,shape=[8400,1]
  23. '''
  24. # 基于每张图像的gt_box得到对应可作为正样本anchors的mask - is_in_boxes_and_fix_center
  25. # 以及所有8400个anchors是否为正阳的mask - fg_mask
  26. # ---------------- 初步筛选 ---------------------- #
  27. fg_mask,is_in_boxes_and_fix_center = self.__get_in_boxes_info(gt_boxes_per_image,expanded_strides,
  28. xy_shifts,total_num_anchors,num_gt)
  29. # ------- 对预测结果进行过滤 ------------ #
  30. # 过滤掉负样本部分,例957个正anchors
  31. # box preds 过滤 shape-[957,4]
  32. box_preds_per_image = box_preds_per_image[fg_mask]
  33. # cls preds 过滤 shape-[957,6]
  34. cls_preds_per_image = cls_preds_per_image[fg_mask]
  35. # obj preds 过滤 shape-[957,1]
  36. obj_preds_per_image = obj_preds_per_image[fg_mask]
  37. # 计算正样本anchors的个数
  38. num_in_boxes_anchor = box_preds_per_image.shape[0]

初步筛选分成两种方式:根据中心点目标框判断

目标框:anchor box的中心点落在人工标注框GT Box的矩形范围中的所有anchor。

     图中绿色框为yolov6网络提取的特征方格(即特征图),在yolov6中每个方格即每个像素点表示一个anchor,红色方框表示GT box,则红色点落在GT box中的小方格(特征图像素点)内的,可用于预测正样本。

  中心点

   以GT Box中心点为基准,四周向外扩展2.5倍stride,构成边长为5倍stride的正方形,挑选anchor box中心点(即像素点)落在正方形内的所有anchor,挑选落在正方形内的所有像素点。

  图中以箭头扩大2.5倍的正方形为边界,anchor box中心点落在正方形中的anchor box,可能作为正样本的预测。

注意!!!这些计算都是将数值还原到输入图像大小后,再进行计算的。

代码如下:

  1. def __get_in_boxes_info(self,
  2. gt_boxes_per_image,
  3. expanded_strides,
  4. xy_shifts,
  5. total_num_anchors,
  6. num_gt
  7. ):
  8. '''
  9. !!!标签的初级筛选!!!
  10. gt_boxes_per_image: bacth内每幅图像的GT_box;
  11. expanded_strides: 每幅图像三个分支相对输入图像的缩放值的合并结果,shape[1,8400,1]
  12. 8400 = 80*80(stride=8) + 40*40(stride=16) + 20*20(stride=32)
  13. xy_shifts: 8400个像素点,每个像素点的坐标值(1,8400,2)
  14. total_num_anchors: 总的anchors数目,8400个像素点,每个像素点anchors个数为1,故为8400
  15. num_gt: 正样本GT_Box的数目,例num_gt=12
  16. '''
  17. # (1,8400,1) -> (8400,1)
  18. expanded_strides_per_image = expanded_strides[0]
  19. # 将像素点坐标转换到原图像大小 (1,8400,2) -> (8400,2)
  20. xy_shifts_per_image = xy_shifts[0] * expanded_strides_per_image
  21. # 上面像素点坐标为每个像素点左上角值,将其转换为每个像素的中心
  22. # (8400,2)->(1,8400,2)->(12,8400,2)
  23. # 因为有12个正样本,如此转换方便计算
  24. xy_center_per_image = (xy_shifts_per_image + 0.5*expanded_strides_per_image).unsqueeze(0).repeat(num_gt,1,1)
  25. # 计算所有gt_box的左上角坐标
  26. # (12,4) -> (12,1,2) ->(12,8400,2)
  27. # 相当于将每个GTBOX的坐标复制anchors_num个
  28. gt_boxes_per_image_lt = (gt_boxes_per_image[:,0:2] - 0.5*gt_boxes_per_image[:,2:4]).unsqueeze(1).repeat(1,total_num_anchors,1)
  29. # 同理计算所有gt_box的右下角坐标
  30. gt_boxes_per_image_rb = (gt_boxes_per_image[:,0:2] + 0.5*gt_boxes_per_image[:,2:4]).unsqueeze(1).repeat(1,total_num_anchors,1)
  31. # ------------ 判断anchor是否在gtbox范围内 --------------- #
  32. # 通过anchors的中心点坐标的范围是否落在gtbox范围内进行判断
  33. # 计算距离左上角(左边与上边)的距离
  34. b_lt = xy_center_per_image - gt_boxes_per_image_lt
  35. # 计算距离右下角(右边与下边)的距离
  36. b_rb = gt_boxes_per_image_rb - xy_center_per_image
  37. # 合并 (12,8400,4)
  38. bbox_deltas = torch.cat([b_lt,b_rb],2)
  39. # 对每个anchors的四个值进行判断,只要有一个值小于0,即为负样本
  40. # 这里采用用四个值中的最小值与0比较,大于零则说明都大于0,说明该anchor在gtbox范围内,shape(12,8400)
  41. is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0 # 每个anchor的mask
  42. # 判断每个anchor是否都找到了与之对应的gt_box,shape为(8400)
  43. is_in_boxes_all = is_in_boxes.sum(dim=0)>0
  44. # --------------- 判断anchor是否在以gtbox中心点为中心,2.5倍stride的fixbox范围内 --------------- #
  45. # 计算所有fixbox的左上角坐标
  46. # (12,2)->(12,8400,2)
  47. gt_boxes_per_image_fix_lt = gt_boxes_per_image[:,0:2].unsqueeze(1).repeat(1,total_num_anchors,1) - \
  48. self.center_radius * expanded_strides_per_image.unsqueeze(0)
  49. # 计算所有fixbox的右下角坐标
  50. gt_boxes_per_image_fix_rb = gt_boxes_per_image[:,0:2].unsqueeze(1).repeat(1,total_num_anchors,1) + \
  51. self.center_radius * expanded_strides_per_image.unsqueeze(0)
  52. # 判断anchor是否在fixbox范围内
  53. # 计算距离左上角(左边与上边)的距离
  54. c_lt = xy_center_per_image - gt_boxes_per_image_fix_lt
  55. # 计算距离右下角(右边与下边)的距离
  56. c_rb = gt_boxes_per_image_fix_rb - xy_center_per_image
  57. # 合并(12,8400,4)
  58. center_deltas = torch.cat([c_lt, c_rb], 2)
  59. # 判断与上面判断是否在gtbox内同理
  60. is_in_centers = center_deltas.min(dim=-1).values > 0.0
  61. is_in_centers_all = is_in_centers.sum(dim=0) > 0
  62. # 判断每个anchor是否在[所有gtbox或所有fixbox]任意一个内部(即是否在fixbox和gtbox的并集内部)
  63. # 亦为每个anchors判断是否为正样本,True or False
  64. is_in_boxes_or_fix_center = is_in_boxes_all | is_in_centers_all
  65. # 针对每个gtbox筛选出可能为正样本的anchors
  66. # 因为is_in_boxes_or_fix_center代表存在并集中的anchors,并非一定在某个gtbox或fixbox内
  67. is_in_gt_boxes = is_in_boxes[:,is_in_boxes_or_fix_center]
  68. # 同理针对每个fixbox筛选出可能为正样本的anchors
  69. is_in_fix_boxes = is_in_centers[:,is_in_boxes_or_fix_center]
  70. # 上面两个集合求与,则得到最终每个gtbox和对应fixbox的正样本(交集区域)
  71. is_in_boxes_and_fix_center = is_in_gt_boxes & is_in_fix_boxes
  72. return is_in_boxes_or_fix_center,is_in_boxes_and_fix_center

3. 正样本精细筛选

  1. # ------------------------- SimOTA 正样本分配与精细化筛选 ------------------------ #
  2. # ----------- 代价矩阵计算 -------------
  3. # (1) 计算 每个gt_box和当前初筛特征点预测框的IOU重合度
  4. # 为了计算gt_boxes与预测框的代价矩阵,表示每个gt_box与所有预测box的代价关系
  5. # shape - [12,957]
  6. # 其中每行表示每个gtbox对于每个初筛正样本的iou值
  7. pair_wise_ious = predbox_gtbox_iou(gt_boxes_per_image,box_preds_per_image,box_format='xywh')
  8. # 计算box iou loss,其是代价矩阵的一部分,shape-[12,957]
  9. # 其中每行表示每个gtbox对于每个初筛正样本的iou loss,其损失越大表示iou越小,与gtbox的匹配度就越低。
  10. pair_wise_ious_loss = -torch.log(pair_wise_ious+1e-8)
  11. # (2) 计算gtbox和当前初筛特征点预测框的种类预测准确度
  12. # 将gtbox的标注类别转换为one-hot形式,[12]->[12,6]
  13. # 然后将其复制957份 [12,6] -> [12,957,6],表示每个gt_box相对于957个初筛预测框的标注cls信息
  14. gt_cls_per_image = (F.one_hot(gt_classes.to(torch.int64),self.class_num).
  15. float().
  16. unsqueeze(1).
  17. repeat(1,num_in_boxes_anchor,1)
  18. )
  19. # 得到针对每个gt_box相对于957个预测框的预测cls信息
  20. with torch.cuda.amp.autocast(enabled=False):
  21. # 当前每个像素点的预测cls值为cls预测值与obj值得乘积,并进行维度转换
  22. # shape-[957,6] -> [12,957,6]
  23. cls_preds_per_image = (
  24. cls_preds_per_image.float().sigmoid_().unsqueeze(0).repeat(num_gt, 1, 1)
  25. * obj_preds_per_image.float().sigmoid_().unsqueeze(0).repeat(num_gt, 1, 1)
  26. )
  27. # 计算cls loss,其是代价矩阵的又一部分 shape[12,957]
  28. # 其中每行表示每个gtbox对于每个初筛正样本的类别损失,其损失值越大表示匹配度越低
  29. pair_wise_cls_loss = F.binary_cross_entropy(
  30. cls_preds_per_image.sqrt_(), gt_cls_per_image, reduction="none"
  31. ).sum(-1)
  32. del cls_preds_per_image, obj_preds_per_image
  33. # 计算代价矩阵,同时负样本设置一个lamda=100000的值,shape-[12,957]
  34. cost = (self.cls_weight*pair_wise_cls_loss
  35. + self.iou_weight * pair_wise_ious_loss
  36. +100000.0*(~is_in_boxes_and_fix_center))
  37. # 这个cost代价矩阵是用来进行下面的标签分配策略的
  38. # -------------- SimOTA求解 --------------
  39. '''
  40. num_fg: 标签分配完成后,总共存在的候选框个数(matrix_matching每列保证一个候选框)
  41. matched_gt_inds: matrix_matching矩阵中存在候选框的位置idx(对于gtbox的索引,如idx=2,对应第二个gtbox),shape:[16]
  42. gt_matched_classes: 标签分配后,每列候选框预测目标的类别编号
  43. pred_ious_this_matching: 由标签分配的mask筛选真实框与预测框构成的IoU矩阵对应的IoU值
  44. '''
  45. (num_fg,
  46. gt_matched_classes,
  47. pred_ious_this_matching,
  48. match_gt_inds
  49. ) = self.__dynamic_k_matching(
  50. cost,
  51. pair_wise_ious,
  52. gt_classes,
  53. num_gt,
  54. fg_mask
  55. )
  56. del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
  57. return (gt_matched_classes,fg_mask,pred_ious_this_matching,match_gt_inds,num_fg)

 利用SimOTA进行正样本锚点框的精细化筛选的标签分配方法总体分成4个步骤:

1) 初步正样本锚点框筛选;前面已经解读相关代码,在get_assignments中调用get_in_boxes_info函数,获取以中心点和目标框筛选交集与并集mask。

  1. # ---------------- 初步筛选 ---------------------- #
  2. fg_mask,is_in_boxes_and_fix_center = self.__get_in_boxes_info(gt_boxes_per_image,expanded_strides,
  3. xy_shifts,total_num_anchors,num_gt)
  4. # ------- 对预测结果进行过滤 ------------ #
  5. # 过滤掉负样本部分,例957个正anchors
  6. # box preds 过滤 shape-[957,4]
  7. box_preds_per_image = box_preds_per_image[fg_mask]
  8. # cls preds 过滤 shape-[957,6]
  9. cls_preds_per_image = cls_preds_per_image[fg_mask]
  10. # obj preds 过滤 shape-[957,1]
  11. obj_preds_per_image = obj_preds_per_image[fg_mask]
  12. # 计算正样本anchors的个数
  13. num_in_boxes_anchor = box_preds_per_image.shape[0]

 2) 代价计算用于标签分配;这里计算的是bbox的损失与类别损失,计算Loss如下(包含两部分:边界框IOU损失+类别损失):

边界框损失:

  1. # ----------- 代价矩阵计算 -------------
  2. # (1) 计算 每个gt_box和当前初筛特征点预测框的IOU重合度
  3. # 为了计算gt_boxes与预测框的代价矩阵,表示每个gt_box与所有预测box的代价关系
  4. # shape - [12,957]
  5. pair_wise_ious = predbox_gtbox_iou(gt_boxes_per_image,box_preds_per_image,box_format='xywh')
  6. # 计算box iou loss,其是代价矩阵的一部分
  7. pair_wise_ious_loss = -torch.log(pair_wise_ious+1e-8)

  类别损失:

  1. # (2) 计算gtbox和当前初筛特征点预测框的种类预测准确度
  2. # 将gtbox的标注类别转换为one-hot形式,[12]->[12,6]
  3. # 然后将其复制957份 [12,6] -> [12,957,16],表示每个gt_box相对于957个初筛预测框的标注cls信息
  4. gt_cls_per_image = (F.one_hot(gt_classes.to(torch.int64),self.class_num).
  5. float().
  6. unsqueeze(1).
  7. repeat(1,num_in_boxes_anchor,1)
  8. )
  9. # 得到针对每个gt_box相对于957个预测框的预测cls信息
  10. with torch.cuda.amp.autocast(enabled=False):
  11. # 当前每个像素点的预测cls值为cls预测值与obj值得乘积,并进行维度转换
  12. # shape-[957,6] -> [12,957,6]
  13. cls_preds_per_image = (
  14. cls_preds_per_image.float().sigmoid_().unsqueeze(0).repeat(num_gt, 1, 1)
  15. * obj_preds_per_image.float().sigmoid_().unsqueeze(0).repeat(num_gt, 1, 1)
  16. )
  17. # 计算cls loss,其是代价矩阵的又一部分 shape[12,957]
  18. pair_wise_cls_loss = F.binary_cross_entropy(
  19. cls_preds_per_image.sqrt_(), gt_cls_per_image, reduction="none"
  20. ).sum(-1)

   总损失:

  1. # 计算代价矩阵,同时负样本设置一个lamda=100000的值
  2. cost = (self.cls_weight*pair_wise_cls_loss
  3. + self.iou_weight * pair_wise_ious_loss
  4. +100000.0*(~is_in_boxes_and_fix_center))

3) SimOTA求解。

  1. # -------------- SimOTA求解 --------------
  2. '''
  3. num_fg: 标签分配完成后,总共存在的候选框个数(matrix_matching每列保证一个候选框)
  4. matched_gt_inds: matrix_matching矩阵中存在候选框的位置idx(对于gtbox的索引,如idx=2,对应第二个gtbox),shape:[16]
  5. gt_matched_classes: 标签分配后,每列候选框预测目标的类别编号
  6. pred_ious_this_matching: 由标签分配的mask筛选真实框与预测框构成的IoU矩阵对应的IoU值
  7. '''
  8. (num_fg,
  9. gt_matched_classes,
  10. pred_ious_this_matching,
  11. match_gt_inds
  12. ) = self.__dynamic_k_matching(
  13. cost,
  14. pair_wise_ious,
  15. gt_classes,
  16. num_gt,
  17. fg_mask
  18. )

这一步需要详细介绍:

        参数作用提示:

    (1) pair_wise_ious: 其shape为[12,957],每行表示每个gtbox与所有初筛正样本的iou值,其作用是用于判断最终选择的正样本个数

          判断规则:

          对所有的iou值进行排序,然后从中选取Top10,即十个最大的iou值,然后将这10个值相加后并取整得到数值K,这个K就表示该gtbox的正样本个数。后面从Cost代价矩阵中,对应的gtbox那一行,选择出k个cost最小的位置,这些位置就是正样本位置,然后将相应的mask矩阵位置置1。

    (2) cost: 如上面(1)所讲,cost是与K结合,通过选择出每个gtbox的cost值最小的K个位置,并在mask矩阵中相应位置置为1。

  1. def __dynamic_k_matching(self,cost,pair_wise_ious,gt_classes,num_gt,fg_mask):
  2. '''
  3. !!!标签的精筛!!!
  4. cost: 由iou loss + cls loss,得到的代价矩阵cost,shape-[12,957]
  5. pair_wise_ious: 每个gt_box相对于所有正样本预测box的ious,shape-[12,957],即所有真实框与预测框的IOU
  6. gt_classes: 一张图像gtbox标注框的类别
  7. num_gt: gtbox个数
  8. fg_mask: 根据中心点与gtbox初步筛选的并集掩码(即初筛的正样本mask)(在中心点区域或gtbox区域的像素点为True)
  9. '''
  10. # 生成一个全0矩阵大小与cost一致,shape-[12,957]
  11. # 用于:记录cost中选择的具体位置,即记录最终正样本选择的位置
  12. matching_matrix = torch.zeros_like(cost,dtype=torch.uint8)
  13. # iou矩阵shape[12,957],表示每个gt与所有初筛预测box的iou
  14. ious_in_boxes_matrix = pair_wise_ious
  15. # 对每个gt_box,将其与所有初筛预测box的iou值按从大到小进行排序并设置topn
  16. n_candidate_k = min(10, ious_in_boxes_matrix.size(1)) # 从排序结果中选择10个候选box
  17. # 得到top10的ious,shape-[12,10]
  18. topk_ious,_ = torch.topk(ious_in_boxes_matrix,n_candidate_k,dim=1)
  19. # 对得到的top10 ious进行求和,最后每个gtbox得到一个值
  20. # 其中每个gtbox得到的值,将作为其真正候选框的数量,如值为2,则从957个初筛box中
  21. # 选择2个作为最终的候选框
  22. # 保证每个gtbox至少存在一个正样本
  23. dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
  24. dynamic_ks = dynamic_ks.tolist()
  25. for gt_idx in range(num_gt):
  26. # 根据每行分配的候选框数量num=dynamic_ks[gt_idx]由cost找出num个最小的cost位置,也就是iou最大
  27. _,pos_idx = torch.topk(
  28. cost[gt_idx],k=dynamic_ks[gt_idx],largest=False
  29. )
  30. # 将对应的位置设置为1, shape:[12,957]
  31. matching_matrix[gt_idx][pos_idx] = 1
  32. # 但是这样难免会出现,某个候选box被多个gtbox共用,这是不可以的,故从中保留cost最小的,其余设为0
  33. # 对matching_matrix每列求和 ,shape[957]
  34. # 如果某列结果大于1,则说明该anchorbox被多个gtbox共用
  35. del topk_ious, dynamic_ks, pos_idx
  36. # 对每列
  37. anchor_matching_gt = matching_matrix.sum(0)# shape [957]
  38. if (anchor_matching_gt>1).sum()>0:# 说明存在上述情况
  39. # 选出共用列中,最小的位置,,即从每个存在共用情况的列中选出最小cost的位置
  40. _,cost_argmin = torch.min(cost[:,anchor_matching_gt>1],dim=0)
  41. # 先将该列所有位置设置为0
  42. matching_matrix[:,anchor_matching_gt>1]*=0
  43. # 然后将该列对应的行,即对应的cost最小的的gtbox位置设置为1
  44. matching_matrix[cost_argmin,anchor_matching_gt>1] = 1
  45. # 以上就得到了每个gtbox对所有初筛候选box的mask
  46. # 以下就是选择出正样本位置
  47. # 每列进行求和,找出大于0的列,每列对应一个anchor,大于0则说明该anchor为正样本
  48. # shape-[12,957],其中每列最多只有一位位置为True,为True则说明是正样本
  49. fg_mask_inboxes = matching_matrix.sum(0) > 0
  50. # 得到所有正样本个数,例如num_fg=16
  51. num_fg = fg_mask_inboxes.sum().item()
  52. # 把通过标签分配处理的mask,赋值给初筛选的mask
  53. fg_mask[copy.deepcopy(fg_mask)] = fg_mask_inboxes
  54. # 筛选出有候选框的列,并找出筛选列中最大值索引,即找出每个正样本对应的gt_box索引
  55. matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
  56. # 根据索引,找出每个正样本anchor 对应的cls
  57. gt_matched_classes = gt_classes[matched_gt_inds]
  58. # 通过pair_wise_ious和matching_matrix相乘过滤负样本,求和
  59. # 然后与标签分配的mask(fg_mask_inboxes), 筛选存在候选框的IoU
  60. pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
  61. fg_mask_inboxes
  62. ]
  63. #num_fg正样本的个数,gt_matched_classes每个正样本类别,matched_gt_inds每个正样本对应的gt_box索引
  64. return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds

第一步:某幅图像真实框,与通过初筛获取的预测框计算IoU,然后通过计算的IoU找出最大top10的数据,尺寸大小为[num_gt, 10]。再由最大top10的数据统计这幅图像每个目标分配的候选框,通过找出cost最小位置分配某个候选框。

  1. # 生成一个全0矩阵大小与cost一致,shape-[12,957]
  2. # 用于:记录cost中选择的具体位置
  3. matching_matrix = torch.zeros_like(cost,dtype=torch.uint8)
  4. # iou矩阵shape[12,957],表示每个gt与所有初筛预测box的iou
  5. ious_in_boxes_matrix = pair_wise_ious
  6. # 对每个gt_box,将其与所有初筛预测box的iou值按从大到小进行排序并设置topn
  7. n_candidate_k = min(10, ious_in_boxes_matrix.size(1)) # 从排序结果中选择10个候选box
  8. # 得到top10的ious,shape-[12,10]
  9. topk_ious,_ = torch.topk(ious_in_boxes_matrix,n_candidate_k,dim=1)
  10. # 对得到的top10 ious进行求和,最后每个gtbox得到一个值
  11. # 其中每个gtbox得到的值,将作为其真正候选框的数量,如值为2,则从957个初筛box中
  12. # 选择2个作为最终的候选框
  13. # 保证每个gtbox至少存在一个正样本
  14. dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
  15. dynamic_ks = dynamic_ks.tolist()
  16. for gt_idx in range(num_gt):
  17. # 根据每行分配的候选框数量num=dynamic_ks[gt_idx]由cost找出num个最小的cost位置,也就是iou最大
  18. _,pos_idx = torch.topk(
  19. cost[gt_idx],k=dynamic_ks[gt_idx],largest=False
  20. )
  21. # 将对应的位置设置为1, shape:[12,957]
  22. matching_matrix[gt_idx][pos_idx] = 1
  23. # 但是这样难免会出现,某个候选box被多个gtbox共用,这是不可以的,故从中保留cost最小的,其余设为0
  24. # 对matching_matrix每列求和 ,shape[957]
  25. # 如果某列结果大于1,则说明该anchorbox被多个gtbox共用
  26. del topk_ious, dynamic_ks, pos_idx

计算每个目标框分配的候选框个数,假设ious_in_boxes_matrix为[3,13]的矩阵,则得到候选框个数如下:

 根据cost分配计算候选框的位置(找出每行中最小的cost),大致流程如下:

 第二步:过滤掉共用的候选框,即matching_matrix同列中 有多个1的情况,也就是某列候选框被多个gtbox关联。

  1. # 但是这样难免会出现,某个候选box被多个gtbox共用,这是不可以的,故从中保留cost最小的,其余设为0
  2. # 对matching_matrix每列求和 ,shape[957]
  3. # 如果某列结果大于1,则说明该anchorbox被多个gtbox共用
  4. del topk_ious, dynamic_ks, pos_idx
  5. # 对每列
  6. anchor_matching_gt = matching_matrix.sum(0)# shape [957]
  7. if (anchor_matching_gt>1).sum()>0:# 说明存在上述情况
  8. # 选出共用列中,最小的位置,,即从每个存在共用情况的列中选出最小cost的位置
  9. _,cost_argmin = torch.min(cost[:,anchor_matching_gt>1],dim=0)
  10. # 先将该列所有位置设置为0
  11. matching_matrix[:,anchor_matching_gt>1]*=0
  12. # 然后将该列对应的行,即对应的cost最小的的gtbox位置设置为1
  13. matching_matrix[cost_argmin,anchor_matching_gt>1] = 1
  14. # 以上就得到了每个gtbox对所有初筛候选box的mask
  15. # 以下就是选择出正样本位置
  16. # 每列进行求和,找出大于0的列,每列对应一个anchor,大于0则说明该anchor为正样本
  17. # shape-[12,957],其中每列最多只有一位位置为True,为True则说明是正样本
  18. fg_mask_inboxes = matching_matrix.sum(0) > 0
  19. # 得到所有正样本个数,例如num_fg=16
  20. num_fg = fg_mask_inboxes.sum().item()

  通过cost矩阵,找出共用候选框所在列中cost损失值最小的位置,mask设置为1,其余为0,具体过程如下:

  1. # 每列进行求和,找出大于0的列,每列对应一个anchor,大于0则说明该anchor为正样本
  2. # shape-[12,957],其中每列最多只有一位位置为True,为True则说明是正样本
  3. fg_mask_inboxes = matching_matrix.sum(0) > 0
  4. # 得到所有正样本个数,例如num_fg=16
  5. num_fg = fg_mask_inboxes.sum().item()
  6. # 把通过标签分配处理的mask,赋值给初筛选的mask
  7. fg_mask[copy.deepcopy(fg_mask)] = fg_mask_inboxes
  8. # 筛选出有候选框的列,并找出筛选列中最大值索引,即找出每个正样本对应的gt_box索引
  9. matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
  10. # 根据索引,找出每个正样本anchor 对应的cls
  11. gt_matched_classes = gt_classes[matched_gt_inds]
  12. # 通过pair_wise_ious和matching_matrix相乘去除负样本,求和
  13. # 然后与标签分配的mask(fg_mask_inboxes), 筛选存在候选框的IoU
  14. pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
  15. fg_mask_inboxes
  16. ]
  17. #num_fg正样本的个数,gt_matched_classes每个正样本类别,matched_gt_inds每个正样本对应的gt_box索引
  18. return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds

matched_gt_inds的查找,如图所示:

pred_ious_this_matching的计算如下所示:

    每个正样本的iou值,表示其与对应gtbox的交并比程度,iou越大重合度越高。

    作用:在计算分类损失的过程用到,通过与预测cls值相乘,相当于添加了一个权重,当iou值越小时,惩罚越高,loss值越大。

 三、损失函数计算

       通过标签分配之后得到的匹配类别标号(gt_matched_classes),候选框掩码(fg_mask),匹配之后的交并比(pred_ious_this_matching)计算真实的类别概率(cls_target),真实的置信度obj_target(即标签分配后的掩码fg_mask),再由matched_gt_inds筛选目标box,即reg_target。    

  1. # one_hot构成size为[num_gt,80]的矩阵, pred_ious_this_matching为num_gt的一维向量,unsqueeze(-1)表示reshape为[num_gt,1]
  2. # perd_ious_this_matching作为惩罚项
  3. cls_target = F.one_hot(gt_matched_classes.to(torch.int64), self.num_classes) * pred_ious_this_matching.unsqueeze(-1)
  4. obj_target = fg_mask.unsqueeze(-1) # 目标置信度
  5. reg_target = gt_bboxes_per_image[matched_gt_inds] # 通过匹配索引筛选目标box
  6. # 计算前景box正样本的中心点的偏移值和宽高的缩放值
  7. l1_target = self.__get_l1_target(
  8. outputs.new_zeros((num_fg_img,4)),
  9. reg_target,
  10. expanded_strides[0][fg_mask],
  11. xy_shifts=xy_shifts[0][fg_mask]
  12. )

统计一个batch下,三个损失结果,然后计算一个batch的损失。       

  1. for ...
  2. # 一个batch的每幅图像三个损失append
  3. cls_targets.append(cls_target)
  4. reg_targets.append(reg_target)
  5. obj_targets.append(obj_target.to(dtype))
  6. l1_targets.append(l1_target)
  7. fg_masks.append(fg_mask) # 目标置信度添加
  8. # cat操作
  9. cls_targets = torch.cat(cls_targets, 0)
  10. reg_targets = torch.cat(reg_targets, 0)
  11. obj_targets = torch.cat(obj_targets, 0)
  12. l1_targets = torch.cat(l1_targets, 0)
  13. fg_masks = torch.cat(fg_masks, 0)
  14. num_fg = max(num_fg, 1)
  15. # loss
  16. loss_iou += (self.iou_loss(box_preds.view(-1, 4)[fg_masks].T, reg_targets)).sum() / num_fg
  17. loss_l1 += (self.l1_loss(box_preds_org.view(-1, 4)[fg_masks], l1_targets)).sum() / num_fg
  18. loss_obj += (self.bce_loss(obj_preds.view(-1, 1), obj_targets*1.0)).sum() / num_fg
  19. loss_cls += (self.bce_loss(cls_preds.view(-1, self.class_num)[fg_masks], cls_targets)).sum() / num_fg
  20. total_losses = self.reg_weight * loss_iou + loss_l1 + loss_obj + loss_cls
  21. return total_losses, self.reg_weight * loss_iou, loss_l1, loss_obj, loss_cls

四、网络的输出解析

    网络输出为三个分支各自整合的结果,存储于一个List中,shape分别为[b,11,80,80],[b,11,40,40],[b,11,20,20],其中11表示特征图每个像素点对应的anchor个数为1个,11=(4+1+6)x 1

     将所有结果整合并进行reshape变换后为:output-shape->[b,8400,11],其中11 = (4+1+6)x 1

    4:表示对box的预测结果,前两个值为中心点cx,cy相对于所在像素左上角的偏移值后两个值为宽高经过对数处理后的值

          所以在经过解析的时候中心点需要与所在像素左上角的坐标值相加,宽高值需要采用指数torch.exp(val)进行处理,然后将得到的四个值乘以对应的strides.

    1:  每个预测框是前景的概率值,在进行预测是,首先会通过该值判断前景背景,并对背景进行过滤。

    6:每个预测框的类别概率值,这里是六个类别,每个类别会给出一个值,最大者即为对应的类别。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/698582
推荐阅读
相关标签
  

闽ICP备14008679号