当前位置:   article > 正文

Openpcdet 系列 Pointpillar代码逐行解析之POST_PROCESSING模块_nms_post_maxsize

nms_post_maxsize

Openpcdet 的POST_PROCESSING模块

在这里插入图片描述

在OpenPCDet中,POST_PROCESSING模块是用于在模型输出的点云检测结果上进行后处理的组件。

该模块主要负责对检测结果进行滤波、聚类、追踪等操作,以提高检测的准确性和稳定性。

POST_PROCESSING模块通常包含以下几个主要的子模块或步骤:

  1. 点云滤波(Point Cloud Filtering):这一步骤用于去除原始点云中的噪声和离群点,常用的滤波方法包括体素下采样(Voxel Downsampling)、统计滤波(Statistical Outlier Removal)等。

  2. 检测框聚类(Box Clustering):在一些场景中,模型可能会输出多个相似的检测框,这些框可能对应着同一个物体。通过聚类算法,可以将这些相似的框归为一类,从而得到更准确的检测结果。

  3. 对象追踪(Object Tracking):在连续帧的点云数据中,通过追踪算法可以将同一个物体在不同帧之间进行关联,从而实现物体的连续跟踪。常用的追踪算法包括卡尔曼滤波(Kalman Filtering)、匈牙利算法(Hungarian Algorithm)等。

  4. 检测结果过滤(Detection Result Filtering):根据应用需求,可以对最终的检测结果进行进一步过滤,例如根据置信度阈值进行筛选,去除不满足要求的检测结果。

Pointpillar POST_PROCESSING 配置文件

    POST_PROCESSING:
        RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
        SCORE_THRESH: 0.1
        OUTPUT_RAW_SCORE: False

        EVAL_METRIC: kitti

        NMS_CONFIG:
            MULTI_CLASSES_NMS: False
            NMS_TYPE: nms_gpu
            NMS_THRESH: 0.01
            NMS_PRE_MAXSIZE: 4096
            NMS_POST_MAXSIZE: 500
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  1. RECALL_THRESH_LIST: [0.3, 0.5, 0.7]:分别代表 [‘Car’, ‘Pedestrian’, ‘Cyclist’] 这是一个召回阈值列表,用于评估检测结果的召回率。在评估过程中,将计算不同召回阈值下的召回率,并输出相应的指标。

  2. SCORE_THRESH: 0.1:这是一个得分阈值,用于过滤检测结果。只有得分高于该阈值的检测结果才会被保留,低于阈值的结果将被丢弃。

  3. OUTPUT_RAW_SCORE: False:这是一个布尔值参数,用于指定是否在输出结果中包含原始的检测得分。如果设置为True,则在输出结果中将包含原始得分;如果设置为False,则只输出二值化的检测结果。

  4. EVAL_METRIC: kitti:这是一个评估指标的选择,用于衡量检测结果的性能。在这种情况下,选择了Kitti评估指标,该指标通常用于衡量目标检测在Kitti数据集上的性能。

  5. NMS_CONFIG:这是一个配置NMS(非极大值抑制)的子模块,用于在检测结果中进行框的合并和过滤。

  6. MULTI_CLASSES_NMS: False:这是一个布尔值参数,用于指定是否对多类别进行NMS。如果设置为True,则会对多个类别的检测框进行NMS;如果设置为False,则只对同一类别的检测框进行NMS。

  7. NMS_TYPE: nms_gpu:这是NMS算法的选择。在这种情况下,选择了nms_gpu算法,该算法使用GPU加速执行NMS操作。

  8. NMS_THRESH: 0.01:这是NMS的阈值,用于控制重叠度的判定。当两个框的重叠度高于该阈值时,较低得分的框将被抑制。

  9. NMS_PRE_MAXSIZE: 4096:这是NMS操作之前,每个类别最大保留的检测框数量。如果超过该数量,将根据得分进行排序并截断。

  10. NMS_POST_MAXSIZE: 500:这是NMS操作之后,每个类别最大保留的检测框数量。如果超过该数量,将根据得分进行排序并截断。

POST_PROCESSING 代码讲解

代码在OpenPCDet/pcdet/models/detectors/detector3d_template.py下面

    def post_processing(self, batch_dict):
        """
        Args:
            batch_dict:
                batch_size:
                batch_cls_preds: (B, num_boxes, num_classes | 1) or (N1+N2+..., num_classes | 1)
                                or [(B, num_boxes, num_class1), (B, num_boxes, num_class2) ...]
                multihead_label_mapping: [(num_class1), (num_class2), ...]
                batch_box_preds: (B, num_boxes, 7+C) or (N1+N2+..., 7+C)
                cls_preds_normalized: indicate whether batch_cls_preds is normalized
                batch_index: optional (N1+N2+...)
                has_class_labels: True/False
                roi_labels: (B, num_rois)  1 .. num_classes
                batch_pred_labels: (B, num_boxes, 1)
        Returns:
		        pred_dicts: 一个包含预测结果的列表,每个元素是一个字典,包含了预测框的坐标、得分和类别
        		recall_dict: 一个包含召回率信息的字典,用于评估检测结果的召回率
        """
        post_process_cfg = self.model_cfg.POST_PROCESSING
        batch_size = batch_dict['batch_size']
        recall_dict = {}
        pred_dicts = []
        
        for index in range(batch_size):
        # 根据是否包含batch_index来确定box_preds的形状
            if batch_dict.get('batch_index', None) is not None:
                assert batch_dict['batch_box_preds'].shape.__len__() == 2
                batch_mask = (batch_dict['batch_index'] == index)
            else:
                assert batch_dict['batch_box_preds'].shape.__len__() == 3
                batch_mask = index

            box_preds = batch_dict['batch_box_preds'][batch_mask]
            src_box_preds = box_preds
            # 处理分类预测结果
            if not isinstance(batch_dict['batch_cls_preds'], list):
                cls_preds = batch_dict['batch_cls_preds'][batch_mask]

                src_cls_preds = cls_preds
                assert cls_preds.shape[1] in [1, self.num_class]

                if not batch_dict['cls_preds_normalized']:
                    cls_preds = torch.sigmoid(cls_preds)
            else:
                cls_preds = [x[batch_mask] for x in batch_dict['batch_cls_preds']]
                src_cls_preds = cls_preds
                if not batch_dict['cls_preds_normalized']:
                    cls_preds = [torch.sigmoid(x) for x in cls_preds]
			 # 多类别NMS
            if post_process_cfg.NMS_CONFIG.MULTI_CLASSES_NMS:
                if not isinstance(cls_preds, list):
                    cls_preds = [cls_preds]
                    multihead_label_mapping = [torch.arange(1, self.num_class, device=cls_preds[0].device)]
                else:
                    multihead_label_mapping = batch_dict['multihead_label_mapping']

                cur_start_idx = 0
                pred_scores, pred_labels, pred_boxes = [], [], []
                for cur_cls_preds, cur_label_mapping in zip(cls_preds, multihead_label_mapping):
                    assert cur_cls_preds.shape[1] == len(cur_label_mapping)
                    cur_box_preds = box_preds[cur_start_idx: cur_start_idx + cur_cls_preds.shape[0]]
                    cur_pred_scores, cur_pred_labels, cur_pred_boxes = model_nms_utils.multi_classes_nms(
                        cls_scores=cur_cls_preds, box_preds=cur_box_preds,
                        nms_config=post_process_cfg.NMS_CONFIG,
                        score_thresh=post_process_cfg.SCORE_THRESH
                    )
                    cur_pred_labels = cur_label_mapping[cur_pred_labels]
                    pred_scores.append(cur_pred_scores)
                    pred_labels.append(cur_pred_labels)
                    pred_boxes.append(cur_pred_boxes)
                    cur_start_idx += cur_cls_preds.shape[0]

                final_scores = torch.cat(pred_scores, dim=0)
                final_labels = torch.cat(pred_labels, dim=0)
                final_boxes = torch.cat(pred_boxes, dim=0)
            else:
              # 单类别NMS
                cls_preds, label_preds = torch.max(cls_preds, dim=-1)
                if batch_dict.get('has_class_labels', False):
                    label_key = 'roi_labels' if 'roi_labels' in batch_dict else 'batch_pred_labels'
                    label_preds = batch_dict[label_key][index]
                else:
                    label_preds = label_preds + 1 
                selected, selected_scores = model_nms_utils.class_agnostic_nms(
                    box_scores=cls_preds, box_preds=box_preds,
                    nms_config=post_process_cfg.NMS_CONFIG,
                    score_thresh=post_process_cfg.SCORE_THRESH
                )

                if post_process_cfg.OUTPUT_RAW_SCORE:
                    max_cls_preds, _ = torch.max(src_cls_preds, dim=-1)
                    selected_scores = max_cls_preds[selected]

                final_scores = selected_scores
                final_labels = label_preds[selected]
                final_boxes = box_preds[selected]
             # 生成召回率记录    
            recall_dict = self.generate_recall_record(
                box_preds=final_boxes if 'rois' not in batch_dict else src_box_preds,
                recall_dict=recall_dict, batch_index=index, data_dict=batch_dict,
                thresh_list=post_process_cfg.RECALL_THRESH_LIST
            )        
 		# 构建预测结果字典,并添加到预测结果列表中
            record_dict = {
                'pred_boxes': final_boxes,
                'pred_scores': final_scores,
                'pred_labels': final_labels
            }
            pred_dicts.append(record_dict)

        return pred_dicts, recall_dict

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112

代码里面的一些小细节,后续有人问到再讲

引用

Openpcdet 系列 Pointpillar代码逐行解析
OpenPCDet 环境安装
OpenPCDet KITTI数据加载过程 (Pointpillar模型)
Openpcdet 系列 Pointpillar代码逐行解析之Voxel Feature Encoding (VFE)模块
Openpcdet 系列 Pointpillar代码逐行解析之MAP_TO_BEV模块

Openpcdet 系列 Pointpillar代码逐行解析之BACKBONE_2D模块

Openpcdet 系列 Pointpillar代码逐行解析之检测头(DENSE_HEAD)模块

Openpcdet 系列 Pointpillar代码逐行解析之POST_PROCESSING模块

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/212764
推荐阅读
相关标签
  

闽ICP备14008679号