This article is a study note. Combining the DETR3D model structure with the way MMDetection3D builds models, it first introduces the parameter settings in the model dict, then walks through the sub-modules of DETR3D one by one, briefly explaining the MMDetection3D model construction flow along the way.
The model part: parameters are defined in the order backbone, neck, head.
# Key parameters are omitted here; refer to the actual config file
model = dict(
    type='Detr3D',
    use_grid_mask=True,
    # ResNet extracts the features of stages 0, 1, 2, 3
    img_backbone=dict(),
    img_neck=dict(),
    # transformer head definition: the class referenced by a dict at this level
    # is responsible for instantiating the nested dict entities it contains
    pts_bbox_head=dict(
        type='Detr3DHead',
        # the head contains only a decoder
        transformer=dict(),
        # loss, bbox coder, positional embedding
        bbox_coder=dict(),
        positional_encoding=dict(),
        loss_cls=dict(),
        train_cfg=dict()))
MMDetection3D exploits the containment relations between classes (the head contains the transformer, the transformer contains the decoder, and so on): after build_model is called, the registry mechanism recursively instantiates every registered component.
How exactly does the initialization happen? The author also ran into this when first reading the source code: the framework is highly abstract, but by drilling down into the low-level source and understanding how the registry registers, looks up and instantiates classes, the whole flow becomes clear. Take the transformer and decoder as an example:
@TRANSFORMER.register_module()
class Detr3DTransformer(BaseModule):
    def __init__(self,
                 num_feature_levels=4,
                 num_cams=6,
                 two_stage_num_proposals=300,
                 decoder=None,
                 **kwargs):
        super(Detr3DTransformer, self).__init__(**kwargs)
        # build the decoder from its config dict
        self.decoder = build_transformer_layer_sequence(decoder)
def build_from_cfg(cfg, registry, default_args=None):
    args = cfg.copy()
    # obj_type: e.g. 'Detr3DTransformer'
    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        # look up the class registered under this type name
        obj_cls = registry.get(obj_type)
    # instantiate the class with the remaining config entries;
    # nested dicts are built recursively inside the component's own __init__
    return obj_cls(**args)
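As a concrete usage sketch of this mechanism (a toy example assuming the mmcv 1.x Registry API; ToyTransformer and ToyDecoder are hypothetical stand-ins for the real classes), the nested config dict is turned into Python objects like this:

from mmcv.utils import Registry, build_from_cfg

TOY = Registry('toy')

@TOY.register_module()
class ToyDecoder:
    def __init__(self, num_layers=6):
        self.num_layers = num_layers

@TOY.register_module()
class ToyTransformer:
    def __init__(self, decoder=None):
        # the outer class builds its nested dict, mirroring how
        # Detr3DTransformer builds its decoder in __init__
        self.decoder = build_from_cfg(decoder, TOY)

cfg = dict(type='ToyTransformer', decoder=dict(type='ToyDecoder', num_layers=6))
transformer = build_from_cfg(cfg, TOY)
print(type(transformer.decoder).__name__, transformer.decoder.num_layers)  # ToyDecoder 6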
In summary: build_model walks the nested config dicts top-down and, through the registry, recursively instantiates every registered component. Returning to the concrete config, the backbone and neck are a standard ResNet-101 + FPN:
img_backbone=dict(
    type='ResNet',
    # ResNet-101
    depth=101,
    # bottom-up feature maps of stages 0, 1, 2, 3
    num_stages=4,
    out_indices=(0, 1, 2, 3),
    frozen_stages=1,
    norm_cfg=dict(type='BN2d', requires_grad=False),
    norm_eval=True,
    style='caffe',
    dcn=dict(type='DCNv2', deform_groups=1, fallback_on_stride=False),
    stage_with_dcn=(False, False, True, True)),
img_neck=dict(
    type='FPN',
    # FPN input channels
    in_channels=[256, 512, 1024, 2048],
    # all four output feature maps have 256 channels
    out_channels=256,
    start_level=1,
    add_extra_convs='on_output',
    num_outs=4,
    relu_before_extra_convs=True)
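To make the tensor bookkeeping explicit, here is a minimal shape sketch (not the actual Detr3D.extract_img_feat code; the padded image size and the folding of cameras into the batch dimension are assumptions) of what this backbone + neck produces for 6 surround-view cameras:

import torch

bs, num_cams, C = 1, 6, 256
H, W = 928, 1600                          # assumed padded nuScenes image size
imgs = torch.randn(bs, num_cams, 3, H, W)
x = imgs.view(bs * num_cams, 3, H, W)     # cameras folded into the batch dim

# FPN with start_level=1 and num_outs=4 yields 4 levels of 256 channels;
# only the output shapes are faked here (strides 8, 16, 32, 64)
mlvl_feats = [torch.randn(bs * num_cams, C, H // s, W // s) for s in (8, 16, 32, 64)]

# before entering the head, each level is reshaped back to (bs, num_cams, C, h, w)
mlvl_feats = [f.view(bs, num_cams, C, *f.shape[-2:]) for f in mlvl_feats]
print([tuple(f.shape) for f in mlvl_feats])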
The head inherits from the DETRHead provided by MMDetection:
pts_bbox_head=dict(
    type='Detr3DHead',
    num_query=900,
    num_classes=10,
    in_channels=256,
    sync_cls_avg_factor=True,
    with_box_refine=True,
    as_two_stage=False,
    # the head contains only a decoder
    transformer=dict(),
    # loss, bbox coder, positional embedding
    bbox_coder=dict(
        type='NMSFreeCoder',
        post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
        pc_range=point_cloud_range,
        max_num=300,
        voxel_size=voxel_size,
        num_classes=10),
    positional_encoding=dict(
        type='SinePositionalEncoding',
        num_feats=128,
        normalize=True,
        offset=-0.5),
    loss_cls=dict(
        type='FocalLoss',
        use_sigmoid=True,
        gamma=2.0,
        alpha=0.25,
        loss_weight=2.0),
    loss_bbox=dict(type='L1Loss', loss_weight=0.25),
    loss_iou=dict(type='GIoULoss', loss_weight=0.0))
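For reference, this is a conceptual sketch of what the loss_cls entry above computes: sigmoid focal loss with gamma=2.0 and alpha=0.25, written in plain PyTorch (a simplification, not the mmdet FocalLoss implementation, which additionally handles per-sample weighting and the averaging factor):

import torch
import torch.nn.functional as F

def sigmoid_focal_loss(logits, targets, gamma=2.0, alpha=0.25):
    """logits: (N, num_classes) raw scores; targets: (N, num_classes) in {0, 1}."""
    prob = logits.sigmoid()
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    # p_t is the probability assigned to the true label; easy examples are down-weighted
    p_t = prob * targets + (1 - prob) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * ce).mean()

loss = sigmoid_focal_loss(torch.randn(4, 10), torch.zeros(4, 10))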
The lowest-level component implements the paper's main contributions:
transformer=dict(
    type='Detr3DTransformer',
    decoder=dict(
        type='Detr3DTransformerDecoder',
        num_layers=6,
        return_intermediate=True,
        # parameters of a single decoder layer
        transformerlayers=dict(
            type='DetrTransformerDecoderLayer',
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.1),
                dict(
                    type='Detr3DCrossAtten',
                    pc_range=point_cloud_range,
                    num_points=1,
                    embed_dims=256)
            ],
            feedforward_channels=512,
            ffn_dropout=0.1,
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                             'ffn', 'norm'))))
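The operation_order tuple tells mmcv's transformer-layer builder in which order to apply the sub-modules. Below is a loose, hypothetical sketch of the resulting control flow for one decoder layer, using toy stand-ins for the attention and FFN modules (residual handling simplified; the real layer is assembled by mmcv from the config above):

import torch
import torch.nn as nn

embed_dims = 256
self_attn = nn.MultiheadAttention(embed_dims, num_heads=8, dropout=0.1)
norms = nn.ModuleList(nn.LayerNorm(embed_dims) for _ in range(3))
ffn = nn.Sequential(nn.Linear(embed_dims, 512), nn.ReLU(), nn.Linear(512, embed_dims))

def decoder_layer(query, query_pos, cross_attn, **cross_kwargs):
    # operation_order = ('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')
    q = k = query + query_pos
    query = query + self_attn(q, k, query)[0]        # self_attn (with residual)
    query = norms[0](query)                          # norm
    query = cross_attn(query, **cross_kwargs)        # cross_attn: Detr3DCrossAtten
    query = norms[1](query)                          # norm
    query = query + ffn(query)                       # ffn (with residual)
    query = norms[2](query)                          # norm
    return query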
It implements the key parts of DETR3D: reference points, feature sampling, query refinement, and cross attention between object queries and image features.
@TRANSFORMER.register_module()
class Detr3DTransformer(BaseModule):
    def forward(self, mlvl_feats, query_embed, reg_branches=None, **kwargs):
        """
        mlvl_feats (list(Tensor)): each of shape [bs, embed_dims, h, w].
        query_embed (Tensor): [num_query, c].
        mlvl_pos_embeds (list(Tensor)): [bs, embed_dims, h, w].
        reg_branches (obj:`nn.ModuleList`): regression heads used with with_box_refine.
        """
        bs = mlvl_feats[0].size(0)
        # split the embedding into query_pos and query, each with embed_dims channels
        query_pos, query = torch.split(query_embed, self.embed_dims, dim=1)
        # -1 keeps that dimension unchanged
        query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
        query = query.unsqueeze(0).expand(bs, -1, -1)
        # a linear layer regresses the (normalized) 3D reference points from query_pos
        reference_points = self.reference_points(query_pos)
        reference_points = reference_points.sigmoid()
        init_reference_out = reference_points

        # decoder
        query = query.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=mlvl_feats,
            query_pos=query_pos,
            reference_points=reference_points,
            reg_branches=reg_branches,
            **kwargs)

        inter_references_out = inter_references
        return inter_states, init_reference_out, inter_references_out
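For context, query_embed and self.reference_points above are assumed to be created roughly as follows (a sketch following the deformable-DETR convention that DETR3D adopts; the exact placement of these modules in the head/transformer is simplified):

import torch
import torch.nn as nn

embed_dims, num_query = 256, 900
# one learnable embedding per object query; the 2*embed_dims vector is later
# split into query_pos (first half) and query (second half)
query_embedding = nn.Embedding(num_query, embed_dims * 2)
# maps query_pos to a normalized (x, y, z) reference point, squashed by sigmoid
reference_points = nn.Linear(embed_dims, 3)

ref = reference_points(query_embedding.weight[:, :embed_dims]).sigmoid()
print(ref.shape)  # torch.Size([900, 3])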
The key to the decoder is how it performs the object query refinement proposed in the paper, which is the focus here:
Decoder Block
The flow of each decoder block: the reference points corresponding to the previous layer's queries are predicted and the queries are refined; after self-attention the result serves as the input of the next block, i.e. self.dropout(output) + inp_residual + pos_feat, in other words output = original input + bilinearly sampled image feature + positional feature of the query's reference point (see the sketch below).
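A rough sketch of how the decoder is assumed to iterate over its layers and refine the reference points between them, following the with_box_refine scheme from deformable DETR (hypothetical and simplified; the real Detr3DTransformerDecoder also stores intermediate outputs when return_intermediate=True and keeps queries in (num_query, bs, dims) layout):

import torch
from mmdet.models.utils.transformer import inverse_sigmoid  # assumed available in mmdet 2.x

def decode(layers, query, value, query_pos, reference_points, reg_branches, **kwargs):
    for lid, layer in enumerate(layers):
        output = layer(query, value=value, query_pos=query_pos,
                       reference_points=reference_points, **kwargs)
        if reg_branches is not None:
            # regress a box from each refined query; its (x, y, z) center replaces
            # the reference point for the next layer (update done in logit space)
            tmp = reg_branches[lid](output.permute(1, 0, 2))   # (bs, num_query, 10)
            new_ref = torch.zeros_like(reference_points)
            new_ref[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points[..., :2])
            new_ref[..., 2:3] = tmp[..., 4:5] + inverse_sigmoid(reference_points[..., 2:3])
            reference_points = new_ref.sigmoid().detach()
        query = output
    return query, reference_points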
How are the sampled multi-scale features processed?
The sampled image feature goes from shape (bs, c, num_query, num_cam, 1, len(mlvl_feats)) to shape (bs, c, num_query): three consecutive sum(-1) calls sum the features over the feature levels, the sampling points and the camera views, giving the final image feature. It is then projected (output_proj) to the same dimension as the query and added directly to form the input of the next decoder block.
output = output.sum(-1).sum(-1).sum(-1)
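A tiny shape check of this reduction, with dummy tensors just to illustrate the dimensions involved:

import torch

bs, c, num_query, num_cam, num_points, num_levels = 1, 256, 900, 6, 1, 4
output = torch.randn(bs, c, num_query, num_cam, num_points, num_levels)
# sum over feature levels, then sampling points, then camera views
output = output.sum(-1).sum(-1).sum(-1)
print(output.shape)  # torch.Size([1, 256, 900])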
@ATTENTION.register_module()
class Detr3DCrossAtten(BaseModule):
    def forward(self,
                query,
                key,
                value,
                residual=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        inp_residual = query if residual is None else residual
        query = query.permute(1, 0, 2)
        bs, num_query, _ = query.size()
        attention_weights = self.attention_weights(query).view(
            bs, 1, num_query, self.num_cams, self.num_points, self.num_levels)
        # bilinear sampling of image features at the projected reference points
        reference_points_3d, output, mask = feature_sampling(
            value, reference_points, self.pc_range, kwargs['img_metas'])
        output = torch.nan_to_num(output)
        mask = torch.nan_to_num(mask)
        attention_weights = attention_weights.sigmoid() * mask
        output = output * attention_weights
        # three sums collapse the level, point and camera dims: [bs, c, num_query]
        output = output.sum(-1).sum(-1).sum(-1)
        output = output.permute(2, 0, 1)   # [num_query, bs, c]
        # project the sampled feature to embed_dims (256)
        output = self.output_proj(output)  # (num_query, bs, embed_dims)
        # the sampled feature, the residual query and the reference-point positional
        # feature are summed to produce the refined query for the next block
        pos_feat = self.position_encoder(
            inverse_sigmoid(reference_points_3d)).permute(1, 0, 2)
        return self.dropout(output) + inp_residual + pos_feat
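inverse_sigmoid above comes from mmdet's deformable-DETR utilities; conceptually it is just the logit function with clamping for numerical stability, roughly:

import torch

def inverse_sigmoid(x, eps=1e-5):
    """Map values in (0, 1) back to logit space: log(x / (1 - x)), clamped for stability."""
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)

print(inverse_sigmoid(torch.tensor([0.5])).item())  # 0.0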
# Feature sampling. mlvl_feats: multi-level image features, a list of length 4
# (feature levels), each element of shape [bs, num_cam, embed_dims, h, w]
def feature_sampling(mlvl_feats, reference_points, pc_range, img_metas):
    lidar2img = []
    # the 3D reference points are in the lidar frame; lidar2img is the
    # transformation from lidar coordinates to each camera image
    for img_meta in img_metas:
        lidar2img.append(img_meta['lidar2img'])
    lidar2img = np.asarray(lidar2img)
    # N = 6 cameras, reference_points: [bs, num_query, 3]
    lidar2img = reference_points.new_tensor(lidar2img)  # (B, N, 4, 4)
    reference_points = reference_points.clone()
    reference_points_3d = reference_points.clone()
    # denormalize the reference points to real-world coordinates using pc_range
    reference_points[..., 0:1] = reference_points[..., 0:1] * (pc_range[3] - pc_range[0]) + pc_range[0]
    reference_points[..., 1:2] = reference_points[..., 1:2] * (pc_range[4] - pc_range[1]) + pc_range[1]
    reference_points[..., 2:3] = reference_points[..., 2:3] * (pc_range[5] - pc_range[2]) + pc_range[2]
    # reference_points [bs, num_query, 3] -> homogeneous coordinates [bs, num_query, 4]
    reference_points = torch.cat((reference_points, torch.ones_like(reference_points[..., :1])), -1)
    B, num_query = reference_points.size()[:2]
    # num_cam = 6
    num_cam = lidar2img.size(1)
    # from [B, num_query, 4] to [B, num_cam, num_query, 4, 1]
    reference_points = reference_points.view(B, 1, num_query, 4).repeat(1, num_cam, 1, 1).unsqueeze(-1)
    # shape: [B, num_cam, num_query, 4, 4]
    lidar2img = lidar2img.view(B, num_cam, 1, 4, 4).repeat(1, 1, num_query, 1, 1)
    # project 3d -> 2d
    # shape: [B, num_cam, num_query, 4]
    reference_points_cam = torch.matmul(lidar2img, reference_points).squeeze(-1)
    eps = 1e-5
    mask = (reference_points_cam[..., 2:3] > eps)
    # perspective division: reference_points_cam.shape: [B, num_cam, num_query, 2]
    reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
        reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3]) * eps)
    # indices 0 and 1 are the x and y pixel coordinates in the camera image,
    # normalized by image width and height respectively
    reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
    reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]
    # map from [0, 1] to the [-1, 1] range expected by F.grid_sample
    reference_points_cam = (reference_points_cam - 0.5) * 2
    mask = (mask & (reference_points_cam[..., 0:1] > -1.0)
                 & (reference_points_cam[..., 0:1] < 1.0)
                 & (reference_points_cam[..., 1:2] > -1.0)
                 & (reference_points_cam[..., 1:2] < 1.0))
    mask = mask.view(B, num_cam, 1, num_query, 1, 1).permute(0, 2, 3, 1, 4, 5)
    mask = torch.nan_to_num(mask)
    sampled_feats = []
    # bilinearly sample a feature from each of the four levels; N here is num_cam
    for lvl, feat in enumerate(mlvl_feats):
        B, N, C, H, W = feat.size()
        feat = feat.view(B * N, C, H, W)
        # [B, num_cam, num_query, 2] -> [B*num_cam, num_query, 1, 2]
        reference_points_cam_lvl = reference_points_cam.view(B * N, num_query, 1, 2)
        # F.grid_sample returns [B*N, C, num_query, 1]: the bilinearly
        # interpolated feature at each query's grid location
        sampled_feat = F.grid_sample(feat, reference_points_cam_lvl)
        # -> [B, C, num_query, N, 1]
        sampled_feat = sampled_feat.view(B, N, C, num_query, 1).permute(0, 2, 3, 1, 4)
        sampled_feats.append(sampled_feat)
    # stack over levels: [B, C, num_query, num_cam, 1, len(mlvl_feats)]
    sampled_feats = torch.stack(sampled_feats, -1)
    sampled_feats = sampled_feats.view(B, C, num_query, num_cam, 1, len(mlvl_feats))
    return reference_points_3d, sampled_feats, mask
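Since the (x, y) coordinates passed to F.grid_sample must lie in [-1, 1] (which is why reference_points_cam is rescaled with (x - 0.5) * 2 above), here is a tiny self-contained demo of the sampling semantics (align_corners set explicitly just for the demonstration):

import torch
import torch.nn.functional as F

# a single-channel 4x4 feature map whose value equals its column index
feat = torch.arange(4.0).repeat(4, 1).view(1, 1, 4, 4)
# one sampling location per "query": grid values are (x, y) in [-1, 1]
grid = torch.tensor([[[[-1.0, 0.0]], [[1.0, 0.0]]]])      # shape (1, num_query=2, 1, 2)
sampled = F.grid_sample(feat, grid, align_corners=True)   # shape (1, 1, 2, 1)
print(sampled.view(-1))  # tensor([0., 3.]): left edge and right edge of the row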