Unlike the bottom-up approach of LSS and BEVDepth, which first estimates depth and builds a dedicated 2D-to-3D module, DETR3D follows a top-down, 3D-to-2D idea. It first predefines a set of object queries, uses them to generate 3D reference points, projects those 3D reference points back into 2D image coordinates with the camera transformation matrices, samples the corresponding image features at those locations, runs cross-attention between the image features and the object queries, and keeps refining the object queries. Finally, two MLP branches output the classification and regression predictions. Positive/negative samples are assigned with the same bipartite matching as DETR: based on minimum cost, the N predictions that best match the GT boxes are selected from the 900 object queries. Because both the matching strategy and the object-query mechanism follow DETR, DETR3D can be viewed as an extension of DETR to 3D. The pipeline is broken down into the steps below; a minimal code sketch follows the list.
①. Use ResNet-101 + FPN to extract features from the 6 surround-view images, obtaining 4 output scales: 1/4, 1/8, 1/16, 1/32 (note that the 6 images are fed in by merging the batch dimension with N, the number of cameras).
②. Predefine 900 object queries (similar to prior boxes in 2D detection). Split each object query into query and query_pos, and use a fully connected layer to map query_pos from [900, 256] to [900, 3], which gives the (x, y, z) 3D reference points in BEV space.
③. Enter the transformer decoder, which has 6 decoder layers. In each layer, set q = k = v = query so that all object queries first run self-attention with each other, exchanging global information and preventing multiple queries from converging on the same object.
④. Left-multiply the predicted 3D reference points by the transformation matrix and divide by the depth Zc to convert them into the 2D image coordinate system, obtaining the 2D reference points.
⑤. A 3D reference point projected back to 2D may have no corresponding pixel or may be invisible in a given camera, so a mask records whether each 3D reference point falls inside the current camera view.
⑥. Iterate over the four FPN feature levels and, using the positions of the 2D reference points, sample each level with grid_sample (bilinear interpolation) to obtain the image features corresponding to the 2D reference points.
⑦. Use the query to produce attention weights and run cross-attention with the sampled image features.
⑧. Use the sampled features to refine the 3D reference points; the refinement is deliberately simple: the predicted offsets are just added.
⑨. Fully connected layers output the regression prediction branch and the classification prediction branch.
⑩. The Hungarian algorithm performs bipartite matching to obtain positive/negative samples, and the classification loss (focal loss) and regression loss (L1 loss) are computed.
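To tie the steps together, here is a minimal, runnable sketch of a single decoder step on dummy data. It is not the mmdetection3d implementation: it assumes a single camera, a single feature level, an identity lidar2img matrix, skips the image-size normalization, and all layer names (ref_head, attn_w, reg_head) are made up for illustration.
import torch
import torch.nn as nn
import torch.nn.functional as F

num_query, embed_dims = 900, 256
query = torch.randn(num_query, embed_dims)            # content part of the object queries
query_pos = torch.randn(num_query, embed_dims)        # positional part of the object queries

# ② a FC layer maps query_pos to normalized 3D reference points in (0, 1)
ref_head = nn.Linear(embed_dims, 3)
ref_points = ref_head(query_pos).sigmoid()            # [900, 3]

# ④ project to one dummy camera: homogeneous point, left-multiply lidar2img, divide by Zc
lidar2img = torch.eye(4)                              # identity projection as a stand-in
pts = torch.cat([ref_points, torch.ones(num_query, 1)], dim=-1)   # [900, 4]
cam = pts @ lidar2img.T
xy = cam[:, :2] / cam[:, 2:3].clamp(min=1e-5)         # 2D reference points

# ⑥ bilinear sampling from one dummy feature map; grid coords must be in [-1, 1]
# (normalization by image width/height is skipped here for brevity)
feat = torch.randn(1, embed_dims, 116, 200)
grid = (xy.view(1, num_query, 1, 2) - 0.5) * 2
sampled = F.grid_sample(feat, grid, align_corners=False)   # [1, 256, 900, 1]
sampled = sampled.squeeze(-1).squeeze(0).t()          # [900, 256]

# ⑦⑧ weight the sampled features with the query, then refine the query and the reference points
attn_w = nn.Linear(embed_dims, 1)
query = query + attn_w(query).sigmoid() * sampled     # simplified cross-attention update
reg_head = nn.Linear(embed_dims, 3)
ref_points = (ref_points + reg_head(query)).clamp(0, 1)    # refined reference points
print(query.shape, ref_points.shape)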
Pros:
①. Only the features corresponding to the object queries are sampled and there is no full, explicit BEV representation, which saves memory and computation and makes inference faster.
Cons:
①. Because the 3D-to-2D projection uses only the 3D reference point (the object center) to look up features in the FPN feature maps, the sampled features may be incomplete when the receptive field is insufficient; in practice, long objects such as buses may therefore get predicted boxes that are too small.
②. Reference points that land on the same BEV grid cell sample the same image features after projection to 2D, and there is no depth information; whether a reference point and its sampled image features really match has to be learned implicitly through repeated refinement.
Use ResNet-101 + FPN to extract features from the 6 surround-view images, obtaining 4 output scales (1/4, 1/8, 1/16, 1/32):
Note that the 6 images are fed in by concatenating the batch dimension with N (camera_nums).
def extract_img_feat(self, img, img_metas):
    """Extract features of images."""
    B = img.size(0)
    if img is not None:
        input_shape = img.shape[-2:]
        # update real input shape of each single img
        for img_meta in img_metas:
            img_meta.update(input_shape=input_shape)
        if img.dim() == 5 and img.size(0) == 1:
            img.squeeze_()
        elif img.dim() == 5 and img.size(0) > 1:
            B, N, C, H, W = img.size()
            # merge the batch and camera dimensions
            img = img.view(B * N, C, H, W)
        if self.use_grid_mask:
            img = self.grid_mask(img)
        # ResNet outputs feature maps at 1/4, 1/8, 1/16, 1/32 scales
        img_feats = self.img_backbone(img)
        if isinstance(img_feats, dict):
            img_feats = list(img_feats.values())
    else:
        return None
    # FPN
    if self.with_img_neck:
        img_feats = self.img_neck(img_feats)
    img_feats_reshaped = []
    for img_feat in img_feats:
        BN, C, H, W = img_feat.size()
        # split the merged dimension back into (B, N)
        img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))
    return img_feats_reshaped
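As a quick shape check of the batch/camera merge above, on dummy tensors (the 928x1600 input resolution and the 256-channel FPN outputs are assumptions):
import torch

B, N, C, H, W = 1, 6, 3, 928, 1600
img = torch.randn(B, N, C, H, W)
img = img.view(B * N, C, H, W)                   # [6, 3, 928, 1600]: cameras treated as batch
# after backbone + FPN, one level might look like:
feat = torch.randn(B * N, 256, H // 16, W // 16)
feat = feat.view(B, N, 256, H // 16, W // 16)    # [1, 6, 256, 58, 100]: split back into (B, N)
print(feat.shape)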
Split the object query into query and query_pos, and pass query_pos through a fully connected layer to obtain the 3D reference points.
class Detr3DTransformer(BaseModule):
    def forward(self, mlvl_feats, query_embed, reg_branches=None, **kwargs):
        """Forward function for `Detr3DTransformer`.
        Args:
            mlvl_feats: multi-level 2D image features
            query_embed: object queries, [num_query, 256]
        """
        assert query_embed is not None
        bs = mlvl_feats[0].size(0)
        # split the object query embedding into its positional part and content part,
        # query_pos: [900, 256], query: [900, 256]
        query_pos, query = torch.split(query_embed, self.embed_dims, dim=1)
        # query_pos: [900, 256] ---> [1, 900, 256]
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        # query: [900, 256] ---> [1, 900, 256]
        query = query.unsqueeze(0).expand(bs, -1, -1)
        # a fully connected layer predicts the BEV-space 3D reference point offsets (x, y, z),
        # query_pos: [1, 900, 256] ---> [1, 900, 3]
        reference_points = self.reference_points(query_pos)
        # sigmoid constrains the offsets to the (0, 1) range
        reference_points = reference_points.sigmoid()
        init_reference_out = reference_points
        # query: [1, 900, 256] ---> [900, 1, 256]
        query = query.permute(1, 0, 2)
        # query_pos: [1, 900, 256] ---> [900, 1, 256]
        query_pos = query_pos.permute(1, 0, 2)
        # decoder; inter_states: features sampled from the FPN at the 3D reference points,
        # inter_references: refined 3D reference points
        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=mlvl_feats,
            query_pos=query_pos,
            reference_points=reference_points,
            reg_branches=reg_branches,
            **kwargs)
        inter_references_out = inter_references
        return inter_states, init_reference_out, inter_references_out
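For reference, the two layers used above (the object query embedding split into query/query_pos, and self.reference_points) can be defined roughly as follows; the exact initialization details are an assumption:
import torch.nn as nn

num_query, embed_dims = 900, 256
# each object query stores query_pos and query concatenated, hence 2 * embed_dims
query_embedding = nn.Embedding(num_query, embed_dims * 2)
# maps query_pos (256-d) to a normalized 3D reference point (x, y, z)
reference_points = nn.Linear(embed_dims, 3)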
There are 6 decoder layers. In each layer, all object queries first run self-attention with each other to exchange global information and prevent multiple queries from converging on the same object; the object queries then run cross-attention with the image features.
At this point the MultiheadAttention inputs are q = k = v = query, where query is the query part of the object query from 2.2.
if layer == 'self_attn':
temp_key = temp_value = query
query = self.attentions[attn_index](
query,
temp_key,
temp_value,
identity if self.pre_norm else None,
query_pos=query_pos,
key_pos=query_pos,
attn_mask=attn_masks[attn_index],
key_padding_mask=query_key_padding_mask,
**kwargs)
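A minimal stand-alone sketch of this self-attention step with PyTorch's nn.MultiheadAttention; the 8-head setting and the residual update are assumptions made for illustration:
import torch
import torch.nn as nn

num_query, bs, embed_dims = 900, 1, 256
query = torch.randn(num_query, bs, embed_dims)        # [L, N, E], sequence-first layout
query_pos = torch.randn(num_query, bs, embed_dims)

self_attn = nn.MultiheadAttention(embed_dims, num_heads=8)
q = k = query + query_pos                             # positional embedding added to q and k
out, _ = self_attn(q, k, value=query)                 # all 900 queries attend to each other
query = query + out                                   # residual connection
print(query.shape)                                    # torch.Size([900, 1, 256])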
①. Left-multiply the predicted 3D reference points by the transformation matrix and divide by the depth Zc to convert them into the 2D image coordinate system, obtaining the 2D reference points.
②. Filter out points that fall outside the image, obtaining a mask for the valid points.
③. Iterate over the four FPN feature levels and, using the positions of the 2D reference points, sample each level with grid_sample to obtain the image features corresponding to the 2D reference points.
④. Use the query to produce attention weights and run cross-attention with the image features.
elif layer == 'cross_attn':
query = self.attentions[attn_index](
query,
key,
value,
identity if self.pre_norm else None,
query_pos=query_pos,
key_pos=key_pos,
attn_mask=attn_masks[attn_index],
key_padding_mask=key_padding_mask,
**kwargs)
class Detr3DCrossAtten(BaseModule):
    def forward(self,
                query,
                key,
                value,
                residual=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        if key is None:
            key = query
        if value is None:
            value = key
        if residual is None:
            inp_residual = query
        if query_pos is not None:
            query = query + query_pos
        # query: [900, 1, 256] ---> [1, 900, 256]
        query = query.permute(1, 0, 2)
        bs, num_query, _ = query.size()
        # fully connected layer, query: [1, 900, 256] ---> [1, 900, 24] ---> [1, 1, 900, 6, 1, 4]
        attention_weights = self.attention_weights(query).view(
            bs, 1, num_query, self.num_cams, self.num_points, self.num_levels)
        # project the 3D reference points to 2D image coordinates with the transformation
        # matrices, then sample the FPN feature maps at those 2D locations with grid_sample
        reference_points_3d, output, mask = feature_sampling(
            value, reference_points, self.pc_range, kwargs['img_metas'])
        # replace nan values with 0
        output = torch.nan_to_num(output)
        mask = torch.nan_to_num(mask)
        # keep only the attention weights of points that satisfy the boundary conditions
        attention_weights = attention_weights.sigmoid() * mask
        # cross-attention between the query and the image features
        output = output * attention_weights
        # output: [1, 256, 900, 6, 1, 4] ---> [1, 256, 900]
        output = output.sum(-1).sum(-1).sum(-1)
        # output: [900, 1, 256]
        output = output.permute(2, 0, 1)
        # fully connected layer, output: [900, 1, 256]
        output = self.output_proj(output)
        # pos_feat: [900, 1, 256]
        pos_feat = self.position_encoder(
            inverse_sigmoid(reference_points_3d)).permute(1, 0, 2)
        return self.dropout(output) + inp_residual + pos_feat
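Note that Detr3DCrossAtten is not a standard dot-product attention: the query predicts one sigmoid weight per (camera, point, level), and the sampled features are simply weighted and summed. A stripped-down sketch of just that weighting, on dummy tensors with assumed shapes:
import torch
import torch.nn as nn

bs, num_query, embed_dims = 1, 900, 256
num_cams, num_points, num_levels = 6, 1, 4

query = torch.randn(bs, num_query, embed_dims)
# features sampled at the projected reference points: [B, C, num_query, num_cams, num_points, num_levels]
sampled = torch.randn(bs, embed_dims, num_query, num_cams, num_points, num_levels)
mask = torch.ones(bs, 1, num_query, num_cams, 1, 1)    # visibility mask from feature_sampling

weight_head = nn.Linear(embed_dims, num_cams * num_points * num_levels)
w = weight_head(query).view(bs, 1, num_query, num_cams, num_points, num_levels)
w = w.sigmoid() * mask                                 # one learned weight per camera/point/level
out = (sampled * w).sum(-1).sum(-1).sum(-1)            # weighted sum ---> [1, 256, 900]
print(out.shape)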
3D-to-2D projection + sampling code:
def feature_sampling(mlvl_feats, reference_points, pc_range, img_metas):
    lidar2img = []
    # collect the lidar-to-image transformation matrices
    for img_meta in img_metas:
        lidar2img.append(img_meta['lidar2img'])
    lidar2img = np.asarray(lidar2img)
    lidar2img = reference_points.new_tensor(lidar2img)  # (B, N, 4, 4)
    reference_points = reference_points.clone()
    reference_points_3d = reference_points.clone()
    # map the normalized (0, 1) reference points back to real coordinates within pc_range,
    # e.g. [-51.2, 51.2] for x/y and [-5, 3] for z
    # offset_x * (51.2 + 51.2) - 51.2
    reference_points[..., 0:1] = reference_points[..., 0:1] * (pc_range[3] - pc_range[0]) + pc_range[0]
    # offset_y * (51.2 + 51.2) - 51.2
    reference_points[..., 1:2] = reference_points[..., 1:2] * (pc_range[4] - pc_range[1]) + pc_range[1]
    # offset_z * (3 + 5) - 5
    reference_points[..., 2:3] = reference_points[..., 2:3] * (pc_range[5] - pc_range[2]) + pc_range[2]
    # homogeneous coordinates, reference_points: [1, 900, 4]
    reference_points = torch.cat((reference_points, torch.ones_like(reference_points[..., :1])), -1)
    B, num_query = reference_points.size()[:2]
    num_cam = lidar2img.size(1)
    # reference_points: [1, 900, 4] ---> [1, 1, 900, 4] ---> [1, 6, 900, 4, 1]
    reference_points = reference_points.view(B, 1, num_query, 4).repeat(1, num_cam, 1, 1).unsqueeze(-1)
    # lidar2img: [1, 6, 4, 4] ---> [1, 6, 1, 4, 4] ---> [1, 6, 900, 4, 4]
    lidar2img = lidar2img.view(B, num_cam, 1, 4, 4).repeat(1, 1, num_query, 1, 1)
    # left-multiply by the transformation matrix to get image coordinates,
    # reference_points_cam: [1, 6, 900, 4, 1] ---> [1, 6, 900, 4]
    reference_points_cam = torch.matmul(lidar2img, reference_points).squeeze(-1)
    eps = 1e-5
    # keep only points in front of the camera (depth Zc > 0)
    mask = (reference_points_cam[..., 2:3] > eps)
    # divide by the depth Zc to obtain the 2D coordinates (x, y)
    reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
        reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3]) * eps)
    # normalize to [0, 1] by image width / height
    reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
    reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]
    # map from [0, 1] to the [-1, 1] range expected by grid_sample
    reference_points_cam = (reference_points_cam - 0.5) * 2
    # filter out points that fall outside the image
    mask = (mask & (reference_points_cam[..., 0:1] > -1.0)
                 & (reference_points_cam[..., 0:1] < 1.0)
                 & (reference_points_cam[..., 1:2] > -1.0)
                 & (reference_points_cam[..., 1:2] < 1.0))
    # mask: [1, 6, 900, 1] ---> [1, 6, 1, 900, 1, 1] ---> [1, 1, 900, 6, 1, 1]
    mask = mask.view(B, num_cam, 1, num_query, 1, 1).permute(0, 2, 3, 1, 4, 5)
    mask = torch.nan_to_num(mask)
    sampled_feats = []
    for lvl, feat in enumerate(mlvl_feats):
        B, N, C, H, W = feat.size()
        feat = feat.view(B * N, C, H, W)
        # reference_points_cam_lvl: [1, 6, 900, 2] ---> [6, 900, 1, 2]
        reference_points_cam_lvl = reference_points_cam.view(B * N, num_query, 1, 2)
        # F.grid_sample samples feat at the given locations,
        # feat: [6, 256, 116, 200], reference_points_cam_lvl: [6, 900, 1, 2], sampled_feat: [6, 256, 900, 1]
        sampled_feat = F.grid_sample(feat, reference_points_cam_lvl)
        # sampled_feat: [6, 256, 900, 1] ---> [1, 6, 256, 900, 1] ---> [1, 256, 900, 6, 1]
        sampled_feat = sampled_feat.view(B, N, C, num_query, 1).permute(0, 2, 3, 1, 4)
        # keep the sampling result of each feature level
        sampled_feats.append(sampled_feat)
    # stack along the last dimension, sampled_feats: [1, 256, 900, 6, 1, 4]
    sampled_feats = torch.stack(sampled_feats, -1)
    sampled_feats = sampled_feats.view(B, C, num_query, num_cam, 1, len(mlvl_feats))
    return reference_points_3d, sampled_feats, mask
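To make the projection step concrete, here is a toy example of projecting a single homogeneous 3D point with a 4x4 lidar2img matrix. The matrix and the point are made up; the matrix is chosen so that its third row directly reads off the depth, which is not generally true of a real lidar2img matrix:
import torch

# an illustrative 4x4 lidar2img matrix (values made up): it behaves like a pinhole camera
# with fx = fy = 800, cx = 600, cy = 350, whose optical axis is the 3rd coordinate
lidar2img = torch.tensor([
    [800.0,   0.0, 600.0, 0.0],
    [  0.0, 800.0, 350.0, 0.0],
    [  0.0,   0.0,   1.0, 0.0],
    [  0.0,   0.0,   0.0, 1.0],
])
point = torch.tensor([2.0, 1.0, 10.0, 1.0])      # homogeneous 3D reference point
cam = lidar2img @ point                          # left-multiply by the transformation matrix
eps = 1e-5
visible = cam[2] > eps                           # the mask keeps points in front of the camera
u = cam[0] / cam[2].clamp(min=eps)               # divide by the depth Zc
v = cam[1] / cam[2].clamp(min=eps)
print(u.item(), v.item(), bool(visible))         # 760.0 430.0 True
# u, v would then be normalized by image width/height and mapped to [-1, 1] for grid_sample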
Similar to DETR, the Hungarian algorithm performs bipartite matching between the boxes predicted from all object queries and all GT boxes, finding the optimal matching that minimizes the total cost.
①. Compute the classification cost (focal loss) and the regression cost (L1 loss) separately; their sum forms the cost matrix.
②. Run bipartite matching to obtain the matching result with minimum cost.
def assign(self, bbox_pred, cls_pred, gt_bboxes, gt_labels, gt_bboxes_ignore=None, eps=1e-7):
    num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)
    # 1. initialize GT indices and label indices to -1, assigned_gt_inds: [900], assigned_labels: [900]
    assigned_gt_inds = bbox_pred.new_full((num_bboxes, ), -1, dtype=torch.long)
    assigned_labels = bbox_pred.new_full((num_bboxes, ), -1, dtype=torch.long)
    # No ground truth or boxes, return empty assignment
    if num_gts == 0 or num_bboxes == 0:
        if num_gts == 0:
            # no GT: declare everything as background (0)
            assigned_gt_inds[:] = 0
        return AssignResult(num_gts, assigned_gt_inds, None, labels=assigned_labels)
    # 2. compute the classification cost and the regression cost
    # classification cost: focal loss
    cls_cost = self.cls_cost(cls_pred, gt_labels)
    # normalize the GT boxes
    normalized_gt_bboxes = normalize_bbox(gt_bboxes, self.pc_range)
    # regression cost: L1 cost
    reg_cost = self.reg_cost(bbox_pred[:, :8], normalized_gt_bboxes[:, :8])
    # weighted sum of above two costs
    cost = cls_cost + reg_cost
    # 3. bipartite matching to obtain the minimum-cost matching
    cost = cost.detach().cpu()
    if linear_sum_assignment is None:
        raise ImportError('Please run "pip install scipy" '
                          'to install scipy first.')
    matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
    matched_row_inds = torch.from_numpy(matched_row_inds).to(bbox_pred.device)
    matched_col_inds = torch.from_numpy(matched_col_inds).to(bbox_pred.device)
    # 4. assign backgrounds and foregrounds
    # assign all indices to backgrounds first
    assigned_gt_inds[:] = 0
    # assign foregrounds based on matching results
    assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
    assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
    return AssignResult(num_gts, assigned_gt_inds, None, labels=assigned_labels)
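A toy run of the same scipy routine used above, with 3 predictions, 2 GT boxes, and a made-up cost matrix:
import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([
    [0.9, 0.2],   # prediction 0
    [0.1, 0.8],   # prediction 1
    [0.5, 0.5],   # prediction 2
])
row_inds, col_inds = linear_sum_assignment(cost)
print(row_inds, col_inds)   # [0 1] [1 0]: pred 0 ↔ GT 1, pred 1 ↔ GT 0, pred 2 becomes background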
Classification loss: focal loss
Regression loss: L1 loss
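For reference, a minimal sketch of the two losses on dummy inputs. The sigmoid focal loss below uses the common gamma = 2, alpha = 0.25 defaults and a plain mean reduction; the actual training uses the mmdet implementations, which normalize differently, and the 8-d regression target is just the slice used in the matching cost above:
import torch
import torch.nn.functional as F

def sigmoid_focal_loss(logits, targets, gamma=2.0, alpha=0.25):
    # targets are one-hot, same shape as logits
    p = logits.sigmoid()
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    p_t = p * targets + (1 - p) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * ce).mean()

logits = torch.randn(900, 10)              # class logits for 900 queries, 10 classes
targets = torch.zeros(900, 10)
targets[0, 3] = 1.0                        # query 0 matched to a GT of class 3
cls_loss = sigmoid_focal_loss(logits, targets)

bbox_pred = torch.randn(1, 8)              # regression output of the matched query
bbox_gt = torch.randn(1, 8)                # normalized GT box
reg_loss = F.l1_loss(bbox_pred, bbox_gt)
print(cls_loss.item(), reg_loss.item())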