
VoxelNeXt: A Fully Sparse 3D Object Detection Network


GitHub - dvlab-research/VoxelNeXt: VoxelNeXt: Fully Sparse VoxelNet for 3D Object Detection and Tracking (CVPR 2023)

https://arxiv.org/abs/2303.11301

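Abstract

Current 3D object detectors largely borrow their detection stage from 2D methods: they predict 3D boxes on a dense feature map using preset anchors or centers. The novelty of this paper is to exploit the sparsity of point clouds. After features are extracted with spconv, they are not converted to a dense feature map; 3D boxes are predicted directly on the sparse features. The method is reported to achieve strong results on the commonly used public datasets.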

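1. Introduction

Take the widely used CenterPoint model as an example. Its pipeline includes a sparse-to-dense conversion. This works well, but it brings several problems: wasted computation, a complicated pipeline, and the need for NMS post-processing.

The method proposed in this paper drops the center anchors, the sparse-to-dense conversion, the RPN, and NMS. It predicts directly, and only, at the sparse feature positions.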

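Figure: FLOPs of VoxelNeXt compared with CenterPoint.

Figure: latency of VoxelNeXt compared with CenterPoint and FSD under different detection ranges. VoxelNeXt is well suited to long-range object detection.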

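2. Related Work

        LiDAR Detectors

Current 3D detectors usually follow 2D detectors, for example the R-CNN family or the CenterPoint family. Although 3D point clouds are sparse compared with 2D data, these detectors still predict on dense feature maps. This paper changes that by predicting objects directly on sparse features.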

         Sparse Detectors

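The paper reviews several sparse detectors. Waymo's RSN, for example, first runs segmentation on the range image to extract foreground points and then detects objects on the sparse foreground points. SWFormer and FSD are further attempts at sparse detection, but their pipelines are rather complex. This paper uses ordinary sparse convolutions and keeps the pipeline as simple as possible.
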
[Figure: PillarNet]

[Figure: RSN]

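        Sparse Convolution Network

Because sparse convolution is efficient, it is now the mainstream choice for 3D network backbones, but it is generally not used directly in the detection head. Some works try to improve it, for example by enlarging the receptive field with transformers. This paper instead enlarges the receptive field through additional downsampling.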

        3D Object Tracking        

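A common approach tracks the detection results with a Kalman filter. Other methods, in the style of CenterTrack, predict velocity directly. This paper additionally uses the query voxels for association, which helps deal with the deviation of predicted object centers.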

3. Fully Sparse Voxel-based Network

Figure: VoxelNeXt network architecture.

3.1 backbone adaptation

additional down sampling

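On top of the original downsampling stages with strides {1, 2, 4, 8} and features {F1, F2, F3, F4}, two more stages with strides {16, 32} produce {F5, F6}. The spatial resolutions of F5 and F6 are then aligned to that of F4, and F4, F5, F6 are combined into Fc.

F denotes the sparse features and P their 3D coordinates. Fc is the union of the voxels of F4, F5, and F6: the features are concatenated along the voxel dimension, not the channel dimension. At the same time, the coordinates P5 and P6 are rescaled to the P4 grid (multiplied by 2 and 4 respectively).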

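    # Two extra downsampling stages (strides 16 and 32) on top of x_conv4 (stride 8).
    x_conv5 = self.conv5(x_conv4)
    x_conv6 = self.conv6(x_conv5)
    # Rescale the spatial indices of the coarser levels onto the x_conv4 grid.
    # Column 0 is the batch index and is left untouched.
    x_conv5.indices[:, 1:] *= 2
    x_conv6.indices[:, 1:] *= 4
    # Concatenate the voxels of all three levels along the voxel dimension (dim 0).
    x_conv4 = x_conv4.replace_feature(torch.cat([x_conv4.features, x_conv5.features, x_conv6.features]))
    x_conv4.indices = torch.cat([x_conv4.indices, x_conv5.indices, x_conv6.indices])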
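
To make the index alignment concrete, here is a minimal dependency-free sketch of the same idea. It is only an illustration: the function name `align_and_merge` and the toy coordinates are hypothetical, and plain Python lists stand in for the spconv tensors used by the real code.

```python
def align_and_merge(levels):
    """Merge sparse feature levels onto the grid of the finest level.

    levels: list of (scale, coords, feats) where
      scale  - stride of this level relative to the finest level (1, 2, 4),
      coords - list of integer (z, y, x) voxel coordinates,
      feats  - list of per-voxel feature vectors (same length as coords).
    """
    merged_coords, merged_feats = [], []
    for scale, coords, feats in levels:
        for (z, y, x), feat in zip(coords, feats):
            # Rescale coarse coordinates so that all levels share one grid.
            merged_coords.append((z * scale, y * scale, x * scale))
            merged_feats.append(feat)
    return merged_coords, merged_feats


# Toy stand-ins for F4, F5, F6 with one voxel each (hypothetical values).
f4 = (1, [(0, 2, 3)], [[1.0, 0.0]])
f5 = (2, [(0, 1, 1)], [[0.0, 2.0]])
f6 = (4, [(0, 0, 1)], [[3.0, 3.0]])
coords, feats = align_and_merge([f4, f5, f6])
```

Note that the merged set may contain several voxels at the same position. In this sketch they are simply kept side by side, as in the concatenation above.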

sparse height compression

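The conventional approach converts the sparse tensor to a dense one and then folds the z dimension into the channel dimension.

Here the sparse features are placed directly onto the BEV plane, and features that land on the same BEV position are summed. This is very efficient.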

    import torch
    import spconv.pytorch as spconv

    def bev_out(self, x_conv):
        features_cat = x_conv.features
        # Drop the z column: keep (batch, y, x) only.
        indices_cat = x_conv.indices[:, [0, 2, 3]]
        spatial_shape = x_conv.spatial_shape[1:]

        # Voxels that share the same (batch, y, x) collapse into one BEV cell.
        indices_unique, _inv = torch.unique(indices_cat, dim=0, return_inverse=True)
        features_unique = features_cat.new_zeros((indices_unique.shape[0], features_cat.shape[1]))
        # Sum the features of all voxels that fall into the same BEV cell.
        features_unique.index_add_(0, _inv, features_cat)

        x_out = spconv.SparseConvTensor(
            features=features_unique,
            indices=indices_unique,
            spatial_shape=spatial_shape,
            batch_size=x_conv.batch_size
        )
        return x_out
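
The unique-then-sum step can be hard to picture, so the following is a small dependency-free sketch of the same operation. The function name `sparse_height_compression` and the toy inputs are hypothetical; a dictionary keyed by (batch, y, x) plays the role of `torch.unique` plus `index_add_`.

```python
def sparse_height_compression(indices, features):
    """Collapse the z axis of sparse voxels by summing features per BEV cell.

    indices:  list of (batch, z, y, x) integer tuples.
    features: list of feature vectors, one per voxel.
    Returns the sorted unique (batch, y, x) cells and their summed features.
    """
    sums = {}
    for (b, _z, y, x), feat in zip(indices, features):
        key = (b, y, x)
        if key in sums:
            sums[key] = [a + v for a, v in zip(sums[key], feat)]
        else:
            sums[key] = list(feat)
    bev_indices = sorted(sums)
    bev_features = [sums[key] for key in bev_indices]
    return bev_indices, bev_features


# Two voxels stacked at different heights over cell (y=5, x=7), one elsewhere.
indices = [(0, 0, 5, 7), (0, 1, 5, 7), (0, 0, 2, 3)]
features = [[1.0, 2.0], [10.0, 20.0], [0.5, 0.5]]
bev_indices, bev_features = sparse_height_compression(indices, features)
```

No dense tensor is ever allocated: the cost scales with the number of non-empty voxels rather than with the size of the BEV grid.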

spatially voxel pruning

During downsampling, unimportant background features are pruned. This both emphasizes the foreground and improves computational efficiency.
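
As a rough illustration of pruning by importance, the sketch below ranks voxels by their mean absolute feature value and keeps only the top fraction. This is a simplification under stated assumptions: the function name `prune_voxels`, the `keep_ratio` parameter and its default, and the choice to drop voxels outright are all illustrative, and the actual criterion and placement inside the backbone should be checked against the paper and repository.

```python
def prune_voxels(indices, features, keep_ratio=0.5):
    """Keep the fraction `keep_ratio` of voxels with the largest mean |feature|.

    indices:  list of voxel coordinates (any hashable tuples).
    features: list of feature vectors, one per voxel.
    The surviving voxels are returned in their original order.
    """
    if not indices:
        return [], []
    # Importance score (assumption): mean absolute value over channels.
    magnitude = [sum(abs(v) for v in feat) / len(feat) for feat in features]
    num_keep = max(1, int(len(indices) * keep_ratio))
    order = sorted(range(len(indices)), key=lambda i: magnitude[i], reverse=True)
    keep = sorted(order[:num_keep])
    return [indices[i] for i in keep], [features[i] for i in keep]


indices = [(0, 0), (0, 1), (1, 0), (1, 1)]
features = [[0.1, -0.1], [2.0, 1.0], [0.0, 0.05], [-3.0, 0.5]]
kept_indices, kept_features = prune_voxels(indices, features, keep_ratio=0.5)
```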

3.2 sparse head

        1. class head

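Prediction: NxF => NxK (N voxels, F feature channels, K classes).

Target: the voxel nearest to the center of each GT box is the positive sample.

Loss: focal loss.

Inference: sparse max pooling is used. The voxels are already sparse enough, so the pooling operates only on non-empty positions. An open question is what happens when objects are very close to each other.

Experiments show that the query voxel is not necessarily at the box center, and not even necessarily inside the box.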
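
The sparse max pooling used at inference can be sketched without any dense grid. The version below is a minimal illustration, not the repository's implementation: the name `sparse_max_pool_peaks`, the 3x3 window, and the tie handling are assumptions. A voxel is kept when no non-empty neighbor inside its window has a strictly higher score.

```python
def sparse_max_pool_peaks(coords, scores, kernel_size=3):
    """Return the positions (list indices) of local score maxima.

    coords: list of (y, x) integer coordinates of non-empty BEV voxels.
    scores: list of heatmap scores, one per voxel, for a single class.
    Empty cells are never visited; ties are kept as peaks (assumption).
    """
    score_at = dict(zip(coords, scores))
    half = kernel_size // 2
    peaks = []
    for i, ((y, x), score) in enumerate(zip(coords, scores)):
        is_peak = True
        for dy in range(-half, half + 1):
            for dx in range(-half, half + 1):
                neighbor = score_at.get((y + dy, x + dx))
                if neighbor is not None and neighbor > score:
                    is_peak = False
        if is_peak:
            peaks.append(i)
    return peaks


coords = [(10, 10), (10, 11), (11, 11), (40, 40)]
scores = [0.9, 0.6, 0.3, 0.5]
peaks = sparse_max_pool_peaks(coords, scores)
```

In this sketch, two candidates that fall inside the same window compete and only the stronger one survives, which is the situation the question about nearby objects refers to.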

        2. regression head

Positive voxel selection: N -> n.

Prediction: nxF => nx2 (dx, dy), nx1 (z), nx3 (w, h, l), nx2 (cos, sin).

Loss: L1 loss.
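
The size and heading targets are not regressed as raw values. In the target assignment code of this section, box sizes are stored as logarithms and the heading as a (cos, sin) pair, and the decode code recovers the angle with `atan2`. A minimal round-trip sketch follows; the helper names `encode_box` and `decode_box` and the sample box are hypothetical.

```python
import math


def encode_box(dims, heading):
    """Encode sizes as logs and the heading angle as a (cos, sin) pair."""
    return [math.log(d) for d in dims] + [math.cos(heading), math.sin(heading)]


def decode_box(encoded):
    """Invert encode_box: exp for the sizes, atan2(sin, cos) for the heading."""
    dims = [math.exp(v) for v in encoded[:3]]
    heading = math.atan2(encoded[4], encoded[3])
    return dims, heading


# Hypothetical car-sized box with a heading of 2.5 rad.
target = encode_box((4.5, 1.9, 1.6), 2.5)
dims, heading = decode_box(target)
```

The (cos, sin) pair avoids the discontinuity at +/- pi that a raw angle target would have, and the log keeps predicted sizes positive after `exp`.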

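Relevant code:

Forward network structure: the overall structure matches the earlier CenterHead, except that the 2D convolutions are replaced by 2D SubMConv (submanifold sparse convolutions). The heatmap branch is still called hm.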

    import torch.nn as nn
    from torch.nn.init import kaiming_normal_
    import spconv.pytorch as spconv

    class SeparateHead(nn.Module):
        def __init__(self, input_channels, sep_head_dict, kernel_size, init_bias=-2.19, use_bias=False):
            super().__init__()
            self.sep_head_dict = sep_head_dict

            # One small sparse conv branch per output (hm, center, dim, rot, ...).
            for cur_name in self.sep_head_dict:
                output_channels = self.sep_head_dict[cur_name]['out_channels']
                num_conv = self.sep_head_dict[cur_name]['num_conv']

                fc_list = []
                for k in range(num_conv - 1):
                    fc_list.append(spconv.SparseSequential(
                        spconv.SubMConv2d(input_channels, input_channels, kernel_size, padding=int(kernel_size//2), bias=use_bias, indice_key=cur_name),
                        nn.BatchNorm1d(input_channels),
                        nn.ReLU()
                    ))
                fc_list.append(spconv.SubMConv2d(input_channels, output_channels, 1, bias=True, indice_key=cur_name+'out'))
                fc = nn.Sequential(*fc_list)

                if 'hm' in cur_name:
                    # Heatmap branch: negative bias so initial scores are low.
                    fc[-1].bias.data.fill_(init_bias)
                else:
                    for m in fc.modules():
                        if isinstance(m, spconv.SubMConv2d):
                            kaiming_normal_(m.weight.data)
                            if hasattr(m, "bias") and m.bias is not None:
                                nn.init.constant_(m.bias, 0)

                self.__setattr__(cur_name, fc)

        def forward(self, x):
            # Each branch returns an (N, out_channels) tensor over non-empty voxels.
            ret_dict = {}
            for cur_name in self.sep_head_dict:
                ret_dict[cur_name] = self.__getattr__(cur_name)(x).features
            return ret_dict

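Target encoding: previously this produced a dense heatmap plus the encoded target boxes for each GT.

Now it produces a sparse heatmap plus the corresponding encoded target boxes.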

    def assign_target_of_single_head(
            self, num_classes, gt_boxes, num_voxels, spatial_indices, spatial_shape, feature_map_stride, num_max_objs=500,
            gaussian_overlap=0.1, min_radius=2
    ):
        """
        Args:
            gt_boxes: (N, 8)
            spatial_indices: coordinates of the non-empty voxels, compared against `center`
            spatial_shape: (2), size of the sparse BEV map

        Returns:
            heatmap, ret_boxes, inds, mask
        """
        # The heatmap is sparse: one column per non-empty voxel.
        heatmap = gt_boxes.new_zeros(num_classes, num_voxels)
        ret_boxes = gt_boxes.new_zeros((num_max_objs, gt_boxes.shape[-1] - 1 + 1))
        inds = gt_boxes.new_zeros(num_max_objs).long()
        mask = gt_boxes.new_zeros(num_max_objs).long()

        x, y, z = gt_boxes[:, 0], gt_boxes[:, 1], gt_boxes[:, 2]
        coord_x = (x - self.point_cloud_range[0]) / self.voxel_size[0] / feature_map_stride
        coord_y = (y - self.point_cloud_range[1]) / self.voxel_size[1] / feature_map_stride

        coord_x = torch.clamp(coord_x, min=0, max=spatial_shape[1] - 0.5)  # bugfixed: 1e-6 does not work for center.int()
        coord_y = torch.clamp(coord_y, min=0, max=spatial_shape[0] - 0.5)  #

        center = torch.cat((coord_x[:, None], coord_y[:, None]), dim=-1)
        center_int = center.int()
        center_int_float = center_int.float()

        dx, dy, dz = gt_boxes[:, 3], gt_boxes[:, 4], gt_boxes[:, 5]
        dx = dx / self.voxel_size[0] / feature_map_stride
        dy = dy / self.voxel_size[1] / feature_map_stride

        radius = centernet_utils.gaussian_radius(dx, dy, min_overlap=gaussian_overlap)
        radius = torch.clamp_min(radius.int(), min=min_radius)

        for k in range(min(num_max_objs, gt_boxes.shape[0])):
            if dx[k] <= 0 or dy[k] <= 0:
                continue

            if not (0 <= center_int[k][0] <= spatial_shape[1] and 0 <= center_int[k][1] <= spatial_shape[0]):
                continue

            cur_class_id = (gt_boxes[k, -1] - 1).long()

            # The voxel nearest to the GT center is chosen as the query voxel,
            # and inds stores the position of that voxel in the voxel list.
            distance = self.distance(spatial_indices, center[k])
            inds[k] = distance.argmin()
            mask[k] = 1

            # Draw the Gaussian on the sparse heatmap.
            if 'gt_center' in self.gaussian_type:
                centernet_utils.draw_gaussian_to_heatmap_voxels(heatmap[cur_class_id], distance, radius[k].item() * self.gaussian_ratio)

            if 'nearst' in self.gaussian_type:
                centernet_utils.draw_gaussian_to_heatmap_voxels(heatmap[cur_class_id], self.distance(spatial_indices, spatial_indices[inds[k]]), radius[k].item() * self.gaussian_ratio)

            # (dx, dy) is the offset from the query (proxy) voxel's spatial index to the GT center.
            ret_boxes[k, 0:2] = center[k] - spatial_indices[inds[k]][:2]
            ret_boxes[k, 2] = z[k]
            ret_boxes[k, 3:6] = gt_boxes[k, 3:6].log()
            ret_boxes[k, 6] = torch.cos(gt_boxes[k, 6])
            ret_boxes[k, 7] = torch.sin(gt_boxes[k, 6])
            if gt_boxes.shape[1] > 8:
                ret_boxes[k, 8:] = gt_boxes[k, 7:-1]

        return heatmap, ret_boxes, inds, mask

Decoding the heatmap and the boxes:

    import torch

    # _topk_1d and gather_feat_idx are helper functions defined alongside this one.
    def decode_bbox_from_voxels_nuscenes(batch_size, indices, obj, rot_cos, rot_sin,
                                         center, center_z, dim, vel=None, iou=None, point_cloud_range=None, voxel_size=None, voxels_3d=None,
                                         feature_map_stride=None, K=100, score_thresh=None, post_center_limit_range=None, add_features=None):
        batch_idx = indices[:, 0]
        spatial_indices = indices[:, 1:]
        # Select the top-K scoring voxels per sample, then gather their predictions.
        scores, inds, class_ids = _topk_1d(None, batch_size, batch_idx, obj, K=K, nuscenes=True)

        center = gather_feat_idx(center, inds, batch_size, batch_idx)
        rot_sin = gather_feat_idx(rot_sin, inds, batch_size, batch_idx)
        rot_cos = gather_feat_idx(rot_cos, inds, batch_size, batch_idx)
        center_z = gather_feat_idx(center_z, inds, batch_size, batch_idx)
        dim = gather_feat_idx(dim, inds, batch_size, batch_idx)
        spatial_indices = gather_feat_idx(spatial_indices, inds, batch_size, batch_idx)

        if add_features is not None:
            add_features = [gather_feat_idx(add_feature, inds, batch_size, batch_idx) for add_feature in add_features]

        if not isinstance(feature_map_stride, int):
            feature_map_stride = gather_feat_idx(feature_map_stride.unsqueeze(-1), inds, batch_size, batch_idx)

        angle = torch.atan2(rot_sin, rot_cos)
        # Voxel index plus predicted offset, scaled back to metric coordinates.
        xs = (spatial_indices[:, :, -1:] + center[:, :, 0:1]) * feature_map_stride * voxel_size[0] + point_cloud_range[0]
        ys = (spatial_indices[:, :, -2:-1] + center[:, :, 1:2]) * feature_map_stride * voxel_size[1] + point_cloud_range[1]
        #zs = (spatial_indices[:, :, 0:1]) * feature_map_stride * voxel_size[2] + point_cloud_range[2] + center_z

        box_part_list = [xs, ys, center_z, dim, angle]

        if vel is not None:
            vel = gather_feat_idx(vel, inds, batch_size, batch_idx)
            box_part_list.append(vel)

        if iou is not None:
            iou = gather_feat_idx(iou, inds, batch_size, batch_idx)
            iou = torch.clamp(iou, min=0, max=1.)

        final_box_preds = torch.cat((box_part_list), dim=-1)
        final_scores = scores.view(batch_size, K)
        final_class_ids = class_ids.view(batch_size, K)
        if add_features is not None:
            add_features = [add_feature.view(batch_size, K, add_feature.shape[-1]) for add_feature in add_features]

        # Keep only boxes whose center lies inside the post-processing range.
        assert post_center_limit_range is not None
        mask = (final_box_preds[..., :3] >= post_center_limit_range[:3]).all(2)
        mask &= (final_box_preds[..., :3] <= post_center_limit_range[3:]).all(2)

        if score_thresh is not None:
            mask &= (final_scores > score_thresh)

        ret_pred_dicts = []
        for k in range(batch_size):
            cur_mask = mask[k]
            cur_boxes = final_box_preds[k, cur_mask]
            cur_scores = final_scores[k, cur_mask]
            cur_labels = final_class_ids[k, cur_mask]
            cur_add_features = [add_feature[k, cur_mask] for add_feature in add_features] if add_features is not None else None
            cur_iou = iou[k, cur_mask] if iou is not None else None

            ret_pred_dicts.append({
                'pred_boxes': cur_boxes,
                'pred_scores': cur_scores,
                'pred_labels': cur_labels,
                'pred_ious': cur_iou,
                'add_features': cur_add_features,
            })
        return ret_pred_dicts
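
The center encoding and decoding above are inverses of each other: the target is the offset from the query voxel to the GT center in feature-map cells, and decoding adds that offset back to the voxel index and rescales to meters. The dependency-free sketch below checks this round trip. The helper names, the (ix, iy) ordering, and the numeric settings (range minimum, voxel size, stride) are illustrative assumptions, not values taken from a specific config.

```python
def encode_center(x, y, voxel_xy, pc_range_min, voxel_size, stride):
    """Offset (in feature-map cells) from the query voxel to the GT center."""
    cx = (x - pc_range_min[0]) / voxel_size[0] / stride
    cy = (y - pc_range_min[1]) / voxel_size[1] / stride
    return cx - voxel_xy[0], cy - voxel_xy[1]


def decode_center(offset, voxel_xy, pc_range_min, voxel_size, stride):
    """Recover the metric (x, y) center from a voxel index and its offset."""
    x = (voxel_xy[0] + offset[0]) * stride * voxel_size[0] + pc_range_min[0]
    y = (voxel_xy[1] + offset[1]) * stride * voxel_size[1] + pc_range_min[1]
    return x, y


# Hypothetical settings for illustration only.
pc_range_min = (-54.0, -54.0)
voxel_size = (0.075, 0.075)
stride = 8
query_voxel = (110, 82)  # (ix, iy) of the voxel nearest to the GT center

offset = encode_center(12.3, -4.56, query_voxel, pc_range_min, voxel_size, stride)
x, y = decode_center(offset, query_voxel, pc_range_min, voxel_size, stride)
```

Because the offset is measured from the query voxel rather than from a dense grid cell, it is not bounded to one cell: it can be larger when the nearest non-empty voxel is far from the object center.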

3.3 object tracking

voxel association

The query voxel serves as a proxy for the object center, and query voxels are associated across frames by L2 distance.
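
A minimal sketch of distance-based association is shown below. It is a simplified illustration under stated assumptions: the function name `associate_by_query_voxel`, the greedy matching order, and the `max_dist` threshold are hypothetical, and any motion compensation between frames is left out.

```python
import math


def associate_by_query_voxel(tracks, detections, max_dist=2.0):
    """Greedily match detections to tracks by L2 distance of query voxels.

    tracks:     dict {track_id: (x, y)} of query voxel positions, previous frame.
    detections: list of (x, y) query voxel positions in the current frame.
    Returns {detection_index: track_id}; unmatched detections are omitted.
    """
    pairs = []
    for track_id, track_pos in tracks.items():
        for det_idx, det_pos in enumerate(detections):
            dist = math.dist(track_pos, det_pos)
            if dist <= max_dist:
                pairs.append((dist, track_id, det_idx))
    pairs.sort()  # closest pairs are matched first

    used_tracks, matches = set(), {}
    for _dist, track_id, det_idx in pairs:
        if track_id in used_tracks or det_idx in matches:
            continue
        used_tracks.add(track_id)
        matches[det_idx] = track_id
    return matches


tracks = {1: (0.0, 0.0), 2: (10.0, 10.0)}
detections = [(10.5, 9.8), (0.3, 0.4), (30.0, 30.0)]
matches = associate_by_query_voxel(tracks, detections)
```

Detections left unmatched, such as the third one here, would start new tracks in a full tracker.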

        
