当前位置:   article > 正文

YOLOV :基于YOLOX,使静态图像对象检测器在视频对象检测方面表现出色,注意力机制的魅力 Loss计算部分_oom runtimeerror is raised due to the huge memory

OOM RuntimeError is raised due to the huge memory cost during label assignment

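1.上一讲的注意力机制之后便是 loss 部分。训练时它的输入包括:图像 imgs;x_shifts / y_shifts,即每个网格在特征图上的横、纵坐标索引(乘以步长即为网格左上角在原图上的坐标);expanded_strides,即每个网格的步长(一个网格对应原图上的像素长度);真实标签 labels;拼接后的网络输出 outputs;fc_output,即注意力机制之后的特征(以 refined_cls 参数传入);pred_idx,即每张图片上被选中的预测框的索引;以及预测结果 pred_result。推理时不计算 loss,而是对预测结果做后处理(postprocess)。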

 

  1. if self.training:  # training: return the loss terms computed by get_losses
  2. return self.get_losses(
  3. imgs,
  4. x_shifts,
  5. y_shifts,
  6. expanded_strides,
  7. labels,
  8. torch.cat(outputs, 1),  # per-level outputs concatenated along dim 1
  9. origin_preds,
  10. dtype=xin[0].dtype,
  11. refined_cls=fc_output,  # output of the attention module
  12. idx=pred_idx,
  13. pred_res=pred_result,
  14. )
  15. else:  # inference: post-process the predictions instead of computing losses
  16. class_conf, class_pred = torch.max(fc_output, -1, keepdim=False) # max score and its index along the last dim; NOTE(review): neither value is used in this excerpt
  17. result, result_ori = postprocess(copy.deepcopy(pred_result), self.num_classes, fc_output,nms_thre=nms_thresh )  # deepcopy keeps pred_result unmodified by postprocess
  18. return result, result_ori # result

2.这部分比较简单:先把网络输出拆分为框(bbox)、置信度(obj)和类别(cls)预测,统计每张图片中真实目标的个数(去掉全零的填充标签),再逐张图片获取真实框、真实类别(标签张量形状为 [batch,120,class+xywh])以及该图片的预测框。

  1. bbox_preds = outputs[:, :, :4] # [batch, n_anchors_all, 4]
  2. obj_preds = outputs[:, :, 4].unsqueeze(-1) # [batch, n_anchors_all, 1]
  3. cls_preds = outputs[:, :, 5:] # [batch, n_anchors_all, n_cls]
  4. # calculate targets; mixup is True when each label row has more than 5 columns (class + xywh)
  5. mixup = labels.shape[2] > 5
  6. if mixup:
  7. label_cut = labels[..., :5]  # only the first 5 columns are used to count objects
  8. else:
  9. label_cut = labels
  10. nlabel = (label_cut.sum(dim=2) > 0).sum(dim=1) # number of objects per image: all-zero padding rows (e.g. of the 120 slots) are not counted
  11. total_num_anchors = outputs.shape[1] # n_anchors_all
  12. x_shifts = torch.cat(x_shifts, 1) # [1, n_anchors_all]
  13. y_shifts = torch.cat(y_shifts, 1) # [1, n_anchors_all]
  14. expanded_strides = torch.cat(expanded_strides, 1)
  15. if self.use_l1: # L1 loss branch, active only when self.use_l1 is set (author's note: after epoch 80 — TODO confirm)
  16. origin_preds = torch.cat(origin_preds, 1)
  17. cls_targets = []
  18. reg_targets = []
  19. l1_targets = []
  20. obj_targets = []
  21. fg_masks = []
  22. ref_targets = []
  23. num_fg = 0.0
  24. num_gts = 0.0
  25. ref_masks = []
  26. for batch_idx in range(outputs.shape[0]): # iterate over the images of the batch
  27. num_gt = int(nlabel[batch_idx])
  28. num_gts += num_gt
  29. if num_gt == 0:  # image without objects: empty targets, every anchor is background
  30. cls_target = outputs.new_zeros((0, self.num_classes))
  31. reg_target = outputs.new_zeros((0, 4))
  32. l1_target = outputs.new_zeros((0, 4))
  33. obj_target = outputs.new_zeros((total_num_anchors, 1))
  34. fg_mask = outputs.new_zeros(total_num_anchors).bool()
  35. ref_target = outputs.new_zeros((idx[batch_idx].shape[0], self.num_classes + 1))
  36. ref_target[:, -1] = 1  # last column set to 1 for every selected prediction (presumably a background slot — confirm)
  37. else:
  38. gt_bboxes_per_image = labels[batch_idx, :num_gt, 1:5] # GT boxes (xywh) of this image
  39. gt_classes = labels[batch_idx, :num_gt, 0] # GT class ids; labels are [batch, 120, class+xywh]
  40. bboxes_preds_per_image = bbox_preds[batch_idx] # predicted boxes of this image

3.这一步是获取每一帧图像上正样本的类别、mask掩码、iou、数量等。

输入依次是:当前图片在 batch 中的序号 batch_idx、真实框的数量、所有预测框的数量(5376)、真实框、真实框的类别、当前图像的预测框(5376x4)、各预测框所在特征图相对原图的步长(缩放比)、各网格的 x/y 坐标索引(乘以步长即为网格左上角坐标)、类别预测(8x5376x30)、框的预测(8x5376x4)、置信度预测(8x5376x1)、标签以及图像。

  1. try:
  2. (
  3. gt_matched_classes, # class of each positive (foreground) sample
  4. fg_mask, # foreground mask over all anchors (e.g. 5376)
  5. pred_ious_this_matching, # IoU between each positive and its matched GT box
  6. matched_gt_inds, # index of the GT box matched to each positive
  7. num_fg_img, # number of positives in this image
  8. ) = self.get_assignments( # noqa
  9. batch_idx,
  10. num_gt,
  11. total_num_anchors,
  12. gt_bboxes_per_image,
  13. gt_classes,
  14. bboxes_preds_per_image,
  15. expanded_strides,
  16. x_shifts,
  17. y_shifts,
  18. cls_preds,
  19. bbox_preds,
  20. obj_preds,
  21. labels,
  22. imgs,
  23. )
  24. except RuntimeError:  # NOTE(review): catches every RuntimeError, not only CUDA OOM; the retry below re-runs assignment in CPU mode
  25. logger.error(  # the backslash-continued literal below also embeds the next lines' leading spaces
  26. "OOM RuntimeError is raised due to the huge memory cost during label assignment. \
  27. CPU mode is applied in this batch. If you want to avoid this issue, \
  28. try to reduce the batch size or image size."
  29. )
  30. torch.cuda.empty_cache()  # release cached GPU memory before retrying
  31. (
  32. gt_matched_classes,
  33. fg_mask,
  34. pred_ious_this_matching,
  35. matched_gt_inds,
  36. num_fg_img,
  37. ) = self.get_assignments( # noqa
  38. batch_idx,
  39. num_gt,
  40. total_num_anchors,
  41. gt_bboxes_per_image,
  42. gt_classes,
  43. bboxes_preds_per_image,
  44. expanded_strides,
  45. x_shifts,
  46. y_shifts,
  47. cls_preds,
  48. bbox_preds,
  49. obj_preds,
  50. labels,
  51. imgs,
  52. "cpu",  # mode="cpu": run the assignment on the CPU
  53. )

3.1这部分代码主要分成这几个模块:

1.选出中心点落在真实框内,或落在以真实框中心为中心、半边长为 4.5 倍步长(即边长 9 个网格)的正方形区域内的预测框(两个条件取并集),同时记录哪些预测框同时满足这两个条件。

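2.计算 cost:分类损失 + 3 倍 IoU 损失,并对不同时位于真实框和中心区域内的预测框加上一个很大的惩罚(100000);随后由 dynamic_k_matching 根据 cost 为每个真实框选出正样本。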

  def get_assignments(
      self,
      batch_idx,
      num_gt,
      total_num_anchors,
      gt_bboxes_per_image,
      gt_classes,
      bboxes_preds_per_image,
      expanded_strides,
      x_shifts,
      y_shifts,
      cls_preds,
      bbox_preds,
      obj_preds,
      labels,
      imgs,
      mode="gpu",
  ):
      """Assign the predictions of one image to its ground-truth boxes.

      Steps:
        1. ``get_in_boxes_info`` keeps the anchors whose cell center lies in a
           GT box or in the fixed center region around a GT center.
        2. A cost matrix ``[num_gt, num_candidates]`` is built from a
           classification BCE term, ``3 * -log(IoU)`` and a 100000 penalty for
           candidates that are not inside both regions of that GT.
        3. ``dynamic_k_matching`` selects the positives from that cost.

      Args:
          batch_idx: index of the image within the batch.
          num_gt: number of GT boxes of this image.
          total_num_anchors: number of anchors over all feature levels.
          gt_bboxes_per_image: GT boxes ``[num_gt, 4]`` in (cx, cy, w, h) format.
          gt_classes: GT class ids ``[num_gt]``.
          bboxes_preds_per_image: predicted boxes of this image, one per anchor.
          expanded_strides, x_shifts, y_shifts: stride and grid x / y index of
              every anchor, each ``[1, n_anchors_all]``.
          cls_preds, obj_preds: class / objectness logits for the whole batch.
          bbox_preds, labels, imgs: not used here; kept for interface compatibility.
          mode: ``"gpu"`` or ``"cpu"``. ``"cpu"`` runs the assignment on the CPU
              (the OOM fallback) and moves the tensor results back to CUDA.

      Returns:
          ``(gt_matched_classes, fg_mask, pred_ious_this_matching,
          matched_gt_inds, num_fg)``. ``fg_mask`` is the candidate mask from
          step 1 after it was passed to ``dynamic_k_matching`` (which presumably
          refines it in place — confirm against that method).
      """
      if mode == "cpu":
          print("------------CPU Mode for This Batch-------------")
          gt_bboxes_per_image = gt_bboxes_per_image.cpu().float()
          bboxes_preds_per_image = bboxes_preds_per_image.cpu().float()
          gt_classes = gt_classes.cpu().float()
          expanded_strides = expanded_strides.cpu().float()
          x_shifts = x_shifts.cpu()
          y_shifts = y_shifts.cpu()

      # Step 1: candidates = cell center inside a GT box OR inside a GT center
      # region; is_in_boxes_and_center marks candidates inside both, per GT.
      fg_mask, is_in_boxes_and_center = self.get_in_boxes_info(
          gt_bboxes_per_image,
          expanded_strides,
          x_shifts,
          y_shifts,
          total_num_anchors,
          num_gt,
      )

      # Drop every anchor that is not a candidate. In CPU mode the boxes are
      # already on the CPU (moved above), only the score slices need moving.
      bboxes_preds_per_image = bboxes_preds_per_image[fg_mask]
      cls_preds_ = cls_preds[batch_idx][fg_mask]
      obj_preds_ = obj_preds[batch_idx][fg_mask]
      num_in_boxes_anchor = bboxes_preds_per_image.shape[0]
      if mode == "cpu":
          cls_preds_, obj_preds_ = cls_preds_.cpu(), obj_preds_.cpu()

      # Step 2: cost matrix [num_gt, num_in_boxes_anchor].
      pair_wise_ious = bboxes_iou(gt_bboxes_per_image, bboxes_preds_per_image, False)
      pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)

      # One-hot GT classes repeated for every candidate:
      # [num_gt, num_in_boxes_anchor, num_classes].
      gt_cls_per_image = (
          F.one_hot(gt_classes.to(torch.int64), self.num_classes)
          .float()
          .unsqueeze(1)
          .repeat(1, num_in_boxes_anchor, 1)
      )

      with torch.cuda.amp.autocast(enabled=False):
          # Joint score sqrt(sigmoid(cls) * sigmoid(obj)) in float32, computed
          # once per candidate and only then repeated for each GT. Casting
          # obj_preds_ to float (like cls_preds_) avoids a half-precision
          # sigmoid, which loses precision and is unsupported on the CPU in
          # some PyTorch versions (breaking the CPU fallback under fp16).
          # Repeating after the elementwise ops keeps the large
          # [num_gt, ...] temporaries out of this OOM-prone step.
          cls_preds_ = (
              cls_preds_.float().sigmoid_() * obj_preds_.float().sigmoid_()
          ).sqrt_()
          pair_wise_cls_loss = F.binary_cross_entropy(
              cls_preds_.unsqueeze(0).repeat(num_gt, 1, 1),
              gt_cls_per_image,
              reduction="none",
          ).sum(-1)
      del cls_preds_

      cost = (
          pair_wise_cls_loss  # classification cost
          + 3.0 * pair_wise_ious_loss  # localization cost
          # huge penalty so candidates outside box-and-center are never chosen
          + 100000.0 * (~is_in_boxes_and_center)
      )

      # Step 3: dynamic-k selection of the positives.
      (
          num_fg,
          gt_matched_classes,
          pred_ious_this_matching,
          matched_gt_inds,
      ) = self.dynamic_k_matching(cost, pair_wise_ious, gt_classes, num_gt, fg_mask)
      del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss

      if mode == "cpu":
          gt_matched_classes = gt_matched_classes.cuda()
          fg_mask = fg_mask.cuda()
          pred_ious_this_matching = pred_ious_this_matching.cuda()
          matched_gt_inds = matched_gt_inds.cuda()

      return (
          gt_matched_classes,
          fg_mask,
          pred_ious_this_matching,
          matched_gt_inds,
          num_fg,
      )

3.1.1由于每一个网格对应一个预测框,因此先计算每个网格的中心点,再判断该中心点是否落在真实框内,以及是否落在以真实框中心为中心、半边长为 4.5 倍步长(即 9x9 个网格)的正方形区域内;两个条件都不满足的预测框会被去除。

  def get_in_boxes_info(
      self,
      gt_bboxes_per_image,
      expanded_strides,
      x_shifts,
      y_shifts,
      total_num_anchors,
      num_gt,
      center_radius=4.5,
  ):
      """Select the candidate anchors for label assignment.

      An anchor is a candidate when the center of its grid cell lies strictly
      inside at least one GT box, or strictly inside the square of half-side
      ``center_radius * stride`` centered on at least one GT center.

      Args:
          gt_bboxes_per_image: GT boxes ``[num_gt, 4]`` in (cx, cy, w, h) format.
          expanded_strides: stride of every anchor, ``[1, n_anchors]``.
          x_shifts, y_shifts: grid x / y index of every anchor, ``[1, n_anchors]``.
          total_num_anchors: number of anchors. Not needed now that the masks
              are built by broadcasting; kept for interface compatibility.
          num_gt: number of GT boxes. Not needed for the same reason.
          center_radius: half-side of the center region in units of the
              anchor's stride (default 4.5, i.e. a 9x9-cell square).

      Returns:
          is_in_boxes_anchor: bool ``[n_anchors]``, True for candidate anchors
              (inside some GT box OR some center region).
          is_in_boxes_and_center: bool ``[num_gt, n_candidates]``, True where a
              candidate is inside both the box and the center region of that GT.
      """
      strides = expanded_strides[0]  # [n_anchors]

      # Cell centers in input-image pixels, (grid index * stride) + stride / 2,
      # as row vectors [1, n_anchors] that broadcast against per-GT columns.
      # Broadcasting replaces the original .repeat() copies, so no
      # [num_gt, n_anchors] coordinate tensors are materialized.
      x_centers = (x_shifts[0] * strides + 0.5 * strides).unsqueeze(0)
      y_centers = (y_shifts[0] * strides + 0.5 * strides).unsqueeze(0)

      # GT geometry as column vectors [num_gt, 1].
      gt_cx = gt_bboxes_per_image[:, 0:1]
      gt_cy = gt_bboxes_per_image[:, 1:2]
      gt_half_w = 0.5 * gt_bboxes_per_image[:, 2:3]
      gt_half_h = 0.5 * gt_bboxes_per_image[:, 3:4]

      # Distances from every cell center to the 4 edges of every GT box:
      # [num_gt, n_anchors, 4]. All positive <=> center strictly inside.
      bbox_deltas = torch.stack(
          [
              x_centers - (gt_cx - gt_half_w),  # left
              y_centers - (gt_cy - gt_half_h),  # top
              (gt_cx + gt_half_w) - x_centers,  # right
              (gt_cy + gt_half_h) - y_centers,  # bottom
          ],
          2,
      )
      is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0  # [num_gt, n_anchors]
      is_in_boxes_all = is_in_boxes.sum(dim=0) > 0  # inside at least one GT box

      # Same test against the fixed square around each GT center; its size
      # scales with the stride of the anchor's feature level.
      radius = center_radius * strides.unsqueeze(0)  # [1, n_anchors]
      center_deltas = torch.stack(
          [
              x_centers - (gt_cx - radius),
              y_centers - (gt_cy - radius),
              (gt_cx + radius) - x_centers,
              (gt_cy + radius) - y_centers,
          ],
          2,
      )
      is_in_centers = center_deltas.min(dim=-1).values > 0.0
      is_in_centers_all = is_in_centers.sum(dim=0) > 0

      # Candidates: union of both tests. Per-GT intersection restricted to them.
      is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
      is_in_boxes_and_center = (
          is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
      )
      return is_in_boxes_anchor, is_in_boxes_and_center

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/418200
推荐阅读
相关标签
  

闽ICP备14008679号