This post is a record I made while reading the Deformable DETR source code in mmdet. If there are mistakes or questions, corrections are welcome.
First, z_q is the object query. A linear layer first predicts sampling offsets from it, and these offsets are added to the reference point to obtain the sampling locations. The object query also goes through another linear layer plus a softmax to obtain the attention weights (which shows that deformable attention does not compute attention weights via a Q·K dot product at all: its attention weights are learned directly from the object query). The attention weights are multiplied with the features at the sampling locations to get the aggregated value, which then passes through one more linear layer to produce the output.
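To make this data flow concrete, here is a minimal single-level sketch of the idea (this is not the mmcv implementation; the module name and argument names are made up for illustration), with the bilinear sampling done by F.grid_sample:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyDeformableAttn(nn.Module):
    """Minimal single-level deformable attention sketch (illustrative only)."""

    def __init__(self, dim=256, num_heads=8, num_points=4):
        super().__init__()
        self.num_heads, self.num_points = num_heads, num_points
        self.head_dim = dim // num_heads
        self.sampling_offsets = nn.Linear(dim, num_heads * num_points * 2)
        self.attention_weights = nn.Linear(dim, num_heads * num_points)
        self.value_proj = nn.Linear(dim, dim)
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, query, value, reference_points, spatial_shape):
        # query: [nq, dim]; value: [h*w, dim]; reference_points: [nq, 2] with (x, y) in [0, 1]
        h, w = spatial_shape
        nq, dim = query.shape
        # offsets and weights both come from the query, not from a Q.K dot product
        offsets = self.sampling_offsets(query).view(nq, self.num_heads, self.num_points, 2)
        weights = self.attention_weights(query).view(nq, self.num_heads, self.num_points)
        weights = weights.softmax(-1)
        # sampling locations = reference point + offsets normalized by the map size
        locs = reference_points[:, None, None, :] + offsets / offsets.new_tensor([w, h])
        # split value into heads and reshape to a feature map so grid_sample can read it
        v = self.value_proj(value).view(h, w, self.num_heads, self.head_dim)
        v = v.permute(2, 3, 0, 1)                              # [heads, head_dim, h, w]
        grid = locs.permute(1, 0, 2, 3) * 2 - 1                # [heads, nq, points, 2] in [-1, 1]
        sampled = F.grid_sample(v, grid, align_corners=False)  # [heads, head_dim, nq, points]
        # weighted sum over the sampling points, then merge the heads
        out = (sampled * weights.permute(1, 0, 2)[:, None]).sum(-1)  # [heads, head_dim, nq]
        out = out.permute(2, 0, 1).reshape(nq, dim)
        return self.output_proj(out)

The key point matches the text above: the weights come from a softmax over a linear projection of the query, not from comparing the query with keys.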
One improvement of Deformable DETR over DETR is the use of multi-scale feature maps, which we can also see from the config file:
backbone=dict(
    type='ResNet',
    depth=50,
    num_stages=4,
    out_indices=(1, 2, 3),  # take 3 feature maps from ResNet
    frozen_stages=1,
    norm_cfg=dict(type='BN', requires_grad=False),
    norm_eval=True,
    style='pytorch',
    init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
    type='ChannelMapper',
    in_channels=[512, 1024, 2048],
    kernel_size=1,
    out_channels=256,  # map the three feature maps to a unified 256 channels
    act_cfg=None,
    norm_cfg=dict(type='GN', num_groups=32),
    num_outs=4),
At the code level, just like DETR, we first enter SingleStageDetector's forward_train to extract the feature maps:
super(SingleStageDetector, self).forward_train(img, img_metas)
x = self.extract_feat(img)
losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
gt_labels, gt_bboxes_ignore)
Here x contains the four multi-scale feature maps produced by the backbone plus the neck: ResNet outputs 3 levels (out_indices=(1, 2, 3)), and ChannelMapper generates one extra level because num_outs=4.
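As a rough picture of the shapes in x (a sketch under assumptions: the 800x1216 padded input size is made up, and I am assuming ChannelMapper produces the extra level with a stride-2 conv on the coarsest backbone output):

# Illustrative only: assumed padded input of 800 x 1216. The first three strides come from
# ResNet stages C3-C5; the stride-64 level is the extra one added by ChannelMapper
# because num_outs=4 is larger than the number of backbone outputs.
input_h, input_w = 800, 1216
for stride in (8, 16, 32, 64):
    h, w = -(-input_h // stride), -(-input_w // stride)   # ceil division
    print(f'stride {stride:2d}: [1, 256, {h}, {w}]')
# stride  8: [1, 256, 100, 152]
# stride 16: [1, 256, 50, 76]
# stride 32: [1, 256, 25, 38]
# stride 64: [1, 256, 13, 19]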
We then enter DETRHead's forward_train() (DeformableDETRHead inherits from DETRHead). Inside DETRHead's forward_train(), the following line takes us into DeformableDETRHead's forward:
outs = self(x, img_metas)
The overall logic of DeformableDETRHead is almost the same as DETRHead's; the difference is that multi-scale feature maps are used.
batch_size = mlvl_feats[0].size(0)
input_img_h, input_img_w = img_metas[0]['batch_input_shape']
img_masks = mlvl_feats[0].new_ones(
    (batch_size, input_img_h, input_img_w))
# For every image in the batch, build a mask at the padded input resolution:
# positions belonging to the original image are set to 0, positions equal to 1 are padding
for img_id in range(batch_size):
    img_h, img_w, _ = img_metas[img_id]['img_shape']
    img_masks[img_id, :img_h, :img_w] = 0

mlvl_masks = []
mlvl_positional_encodings = []
# Downsample img_masks so each mask matches the size of its feature map
for feat in mlvl_feats:
    mlvl_masks.append(
        # Indexing with None adds a dimension: img_masks goes from [b, h, w] to [1, b, h, w]
        F.interpolate(img_masks[None],
                      size=feat.shape[-2:]).to(torch.bool).squeeze(0))
    # Generate the positional encoding; since mlvl_masks is appended to every iteration,
    # indexing with -1 always picks the mask that was just created
    mlvl_positional_encodings.append(
        self.positional_encoding(mlvl_masks[-1]))
In my run the batch_size is 1, so every entry of mlvl_feats has a leading batch dimension of 1.
One point worth noting is why img_masks[None] is used to add a dimension before calling F.interpolate. The reason is that F.interpolate expects its input in the form batch_size x channel x [optional depth] x [optional height] x width; the first two dimensions have a special meaning and are not resampled.
Reference: F.interpolate — array sampling operation
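A tiny self-contained example of this mask downsampling step (the mask values and the 4x6 target size are made up for illustration):

import torch
import torch.nn.functional as F

# One 8x12 padding mask for a batch of size 1: 0 = real image, 1 = padding
img_masks = torch.ones(1, 8, 12)
img_masks[0, :6, :9] = 0          # the original image occupies the top-left 6x9 region

# F.interpolate needs an N x C x H x W input, so add a leading dim: [1, 8, 12] -> [1, 1, 8, 12]
small = F.interpolate(img_masks[None], size=(4, 6))
mask = small.to(torch.bool).squeeze(0)   # back to [1, 4, 6], matching a smaller feature map
print(mask.shape)                        # torch.Size([1, 4, 6])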
Inside DeformableDETRHead's forward, the following code takes us into the transformer:
query_embeds = None
if not self.as_two_stage:
query_embeds = self.query_embedding.weight
hs, init_reference, inter_references, \
enc_outputs_class, enc_outputs_coord = self.transformer(
mlvl_feats,
mlvl_masks,
query_embeds, #[300,512] [num_query,embed_dims * 2]
mlvl_positional_encodings,
reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501
cls_branches=self.cls_branches if self.as_two_stage else None # noqa:E501
)
The code then jumps to DeformableDetrTransformer's forward, where some preparation work is done before entering the transformer:
feat_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
# Flatten the feature map, mask, etc. of every level
for lvl, (feat, mask, pos_embed) in enumerate(
        zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
    bs, c, h, w = feat.shape
    spatial_shape = (h, w)
    spatial_shapes.append(spatial_shape)
    feat = feat.flatten(2).transpose(1, 2)  # [bs, h*w, c]
    mask = mask.flatten(1)  # [bs, h*w]
    pos_embed = pos_embed.flatten(2).transpose(1, 2)  # [bs, h*w, c]
    lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
    lvl_pos_embed_flatten.append(lvl_pos_embed)
    feat_flatten.append(feat)
    mask_flatten.append(mask)
feat_flatten = torch.cat(feat_flatten, 1)  # [bs, sum of h*w over the 4 levels, c]
mask_flatten = torch.cat(mask_flatten, 1)  # [bs, sum of h*w over the 4 levels]
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)  # [bs, sum of h*w over the 4 levels, c]
# Convert to a tensor
spatial_shapes = torch.as_tensor(
    spatial_shapes, dtype=torch.long, device=feat_flatten.device)
# Record the start index of every level's feature map in the flattened sequence
level_start_index = torch.cat((spatial_shapes.new_zeros(
    (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
# Valid (non-padded) width/height ratio of every feature map: [bs, 4 (num_levels), 2 (w and h)]
valid_ratios = torch.stack(
    [self.get_valid_ratio(m) for m in mlvl_masks], 1)
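As a quick sanity check of level_start_index with some assumed feature map sizes (the numbers are made up; they only illustrate the cumulative-sum logic):

import torch

# Assume four feature levels with (h, w) = (100, 152), (50, 76), (25, 38), (13, 19)
spatial_shapes = torch.tensor([[100, 152], [50, 76], [25, 38], [13, 19]])
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)),
                               spatial_shapes.prod(1).cumsum(0)[:-1]))
print(spatial_shapes.prod(1))    # tensor([15200,  3800,   950,   247])
print(level_start_index)         # tensor([    0, 15200, 19000, 19950])
# So tokens 0..15199 belong to level 0, tokens 15200..18999 to level 1, and so on.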
The reference points are obtained with the function below; the resulting reference points are values on the 0-1 scale:
def get_reference_points(spatial_shapes, valid_ratios, device):
    """Get the reference points used in decoder.

    Args:
        spatial_shapes (Tensor): The shape of all feature maps,
            has shape (num_level, 2).
        valid_ratios (Tensor): The ratios of valid points on the
            feature map, has shape (bs, num_levels, 2)
        device (obj:`device`): The device where
            reference_points should be.

    Returns:
        Tensor: reference points used in decoder, has \
            shape (bs, num_keys, num_levels, 2).
    """
    reference_points_list = []
    for lvl, (H, W) in enumerate(spatial_shapes):
        # TODO check this 0.5
        # Take the center of every pixel as an initial reference point;
        # the 0.5 offset places each point at the pixel center
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(
                0.5, H - 0.5, H, dtype=torch.float32, device=device),
            torch.linspace(
                0.5, W - 0.5, W, dtype=torch.float32, device=device))
        # Normalize the coordinates by the valid height / width
        ref_y = ref_y.reshape(-1)[None] / (
            valid_ratios[:, None, lvl, 1] * H)
        ref_x = ref_x.reshape(-1)[None] / (
            valid_ratios[:, None, lvl, 0] * W)
        ref = torch.stack((ref_x, ref_y), -1)
        reference_points_list.append(ref)
    reference_points = torch.cat(reference_points_list, 1)
    # Map the reference points into the valid (non-padded) region of every level
    reference_points = reference_points[:, :, None] * valid_ratios[:, None]
    return reference_points
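A toy call makes the output layout easier to see (a sketch assuming the function above is used standalone; in mmdet it is a static method of the transformer; the 2x2 level and fully valid masks are made up):

import torch

spatial_shapes = torch.tensor([[2, 2]])   # a single fake 2x2 level
valid_ratios = torch.ones(1, 1, 2)        # batch of 1, nothing padded
ref = get_reference_points(spatial_shapes, valid_ratios, device='cpu')
print(ref.shape)     # torch.Size([1, 4, 1, 2])  -> (bs, num_keys, num_levels, 2)
print(ref[0, :, 0])  # tensor([[0.2500, 0.2500],
                     #         [0.7500, 0.2500],
                     #         [0.2500, 0.7500],
                     #         [0.7500, 0.7500]])  normalized pixel centers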
memory = self.encoder(
    query=feat_flatten,  # input query: the flattened multi-scale feature maps [sum of H*W, bs, 256]
    key=None,  # in self-attention, k and v are derived from q, so None is passed in
    value=None,
    query_pos=lvl_pos_embed_flatten,  # positional encoding of the query [sum of H*W, bs, 256]
    query_key_padding_mask=mask_flatten,  # padding mask [bs, sum of H*W]
    spatial_shapes=spatial_shapes,  # h and w of every feature level [num_levels, 2]
    reference_points=reference_points,  # [bs, sum of H*W, num_levels, 2]
    level_start_index=level_start_index,  # index of the first element of each flattened level [num_levels]
    valid_ratios=valid_ratios,  # valid width/height ratio of each level's mask [bs, num_levels, 2]
    **kwargs)
# memory: the encoder output, i.e. the multi-scale feature maps after self-attention [sum of H*W, bs, 256]
Inside the encoder, each layer follows the operation order given in the config file:
encoder=dict(
type='DetrTransformerEncoder',
num_layers=6,
transformerlayers=dict(
type='BaseTransformerLayer',
attn_cfgs=dict(
type='MultiScaleDeformableAttention', embed_dims=256),
feedforward_channels=1024,
ffn_dropout=0.1,
operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
Here the self_attn has become MultiScaleDeformableAttention.
The code of MultiScaleDeformableAttention is in mmcv/ops/multi_scale_deform_attn.py:
if value is None:
    value = query
if identity is None:
    identity = query
if query_pos is not None:
    query = query + query_pos
if not self.batch_first:
    # change to (bs, num_query, embed_dims)
    query = query.permute(1, 0, 2)
    value = value.permute(1, 0, 2)

bs, num_query, _ = query.shape
bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value

# In the encoder, value starts as None and is set to query above;
# the real value is then obtained through a linear layer [bs, sum of H*W, 256]
value = self.value_proj(value)
if key_padding_mask is not None:
    value = value.masked_fill(key_padding_mask[..., None], 0.0)
# [bs, sum of H*W, 256] ---> [bs, sum of H*W, 8, 32]
value = value.view(bs, num_value, self.num_heads, -1)
'''
self.sampling_offsets: Linear(in_features=256, out_features=256, bias=True)
self.attention_weights: Linear(in_features=256, out_features=128, bias=True)
'''
# sampling_offsets: [bs, sum of H*W, 8, 4, 4, 2]
sampling_offsets = self.sampling_offsets(query).view(
    bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
# attention_weights: [1, 10458, 8, 16] in my run
attention_weights = self.attention_weights(query).view(
    bs, num_query, self.num_heads, self.num_levels * self.num_points)
# The attention weights come from a linear projection of the query; the softmax
# normalizes them over all num_levels * num_points sampling points of each head
attention_weights = attention_weights.softmax(-1)
# [1, sum of H*W, 8, 16] ---> [1, sum of H*W, 8, 4, 4]
attention_weights = attention_weights.view(bs, num_query,
                                           self.num_heads,
                                           self.num_levels,
                                           self.num_points)
if reference_points.shape[-1] == 2:
    # First normalize sampling_offsets by offset_normalizer (the W, H of each level),
    # then add them to reference_points
    offset_normalizer = torch.stack(
        [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
    sampling_locations = reference_points[:, :, None, :, None, :] \
        + sampling_offsets \
        / offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
    sampling_locations = reference_points[:, :, None, :, None, :2] \
        + sampling_offsets / self.num_points \
        * reference_points[:, :, None, :, None, 2:] \
        * 0.5
else:
    raise ValueError(
        f'Last dim of reference_points must be'
        f' 2 or 4, but get {reference_points.shape[-1]} instead.')
# Call the CUDA kernel to perform the deformable attention
if torch.cuda.is_available() and value.is_cuda:
    output = MultiScaleDeformableAttnFunction.apply(
        value, spatial_shapes, level_start_index, sampling_locations,
        attention_weights, self.im2col_step)
else:
    output = multi_scale_deformable_attn_pytorch(
        value, spatial_shapes, sampling_locations, attention_weights)

output = self.output_proj(output)
if not self.batch_first:
    # (num_query, bs, embed_dims)
    output = output.permute(1, 0, 2)
# identity is the query before this attention layer (residual connection)
return self.dropout(output) + identity
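The sizes of the two printed linear layers can be checked against the hyper-parameters (a small sketch assuming the defaults num_heads=8, num_levels=4, num_points=4):

num_heads, num_levels, num_points = 8, 4, 4

# sampling_offsets predicts an (x, y) offset for every head / level / point:
print(num_heads * num_levels * num_points * 2)   # 256 -> Linear(256, 256)
# attention_weights predicts one scalar weight for every head / level / point:
print(num_heads * num_levels * num_points)       # 128 -> Linear(256, 128)
# After .view() they become [bs, num_query, 8, 4, 4, 2] and [bs, num_query, 8, 4*4]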
After multi_scale_deformable_attn, the norm, ffn and norm steps follow, which completes one encoder layer; this is repeated 6 times. The code then returns to DeformableDetrTransformer's forward. The return value memory is the encoder output, i.e. the multi-scale feature maps after multi_scale_deformable_attn, with shape [sum of H*W, bs, 256]. The decoder inputs are then prepared from query_embed and the decoder is called:
query_pos, query = torch.split(query_embed, c, dim=1)
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)  # [bs, 300, 256]
query = query.unsqueeze(0).expand(bs, -1, -1)  # [bs, 300, 256]
# A linear layer plus sigmoid on query_pos gives the initial reference point coordinates
reference_points = self.reference_points(query_pos).sigmoid()
init_reference_out = reference_points

# decoder
query = query.permute(1, 0, 2)  # [300 (num_query), bs, 256]
memory = memory.permute(1, 0, 2)  # [sum of H*W, bs, 256]
query_pos = query_pos.permute(1, 0, 2)  # [300 (num_query), bs, 256]
inter_states, inter_references = self.decoder(
    query=query,  # [300 (num_query), bs, 256]
    key=None,
    value=memory,  # the encoder output, i.e. the feature maps after the encoder
    query_pos=query_pos,
    key_padding_mask=mask_flatten,
    reference_points=reference_points,
    spatial_shapes=spatial_shapes,
    level_start_index=level_start_index,
    valid_ratios=valid_ratios,
    reg_branches=reg_branches,  # None here
    **kwargs)
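The self.reference_points module used above is just a small linear head that maps each positional query to a 2-D point in (0, 1) (a minimal sketch; the layer name matches the snippet, and the sizes assume embed_dims=256 and num_query=300):

import torch
import torch.nn as nn

embed_dims, num_query, bs = 256, 300, 1
reference_points_layer = nn.Linear(embed_dims, 2)   # plays the role of self.reference_points

query_pos = torch.randn(bs, num_query, embed_dims)
reference_points = reference_points_layer(query_pos).sigmoid()
print(reference_points.shape)   # torch.Size([1, 300, 2]), every entry in (0, 1)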
Inside self.decoder, the code jumps to the forward of DeformableDetrTransformerDecoder, in mmdetection/mmdet/models/utils/transformer.py:
output = query
intermediate = []  # stores the query output of every decoder layer
intermediate_reference_points = []  # stores the reference_points of every decoder layer
for lid, layer in enumerate(self.layers):
    if reference_points.shape[-1] == 4:
        reference_points_input = reference_points[:, :, None] * \
            torch.cat([valid_ratios, valid_ratios], -1)[:, None]
    else:
        assert reference_points.shape[-1] == 2
        reference_points_input = reference_points[:, :, None] * \
            valid_ratios[:, None]
    output = layer(
        output,  # query
        *args,
        reference_points=reference_points_input,
        **kwargs)
    # kwargs contains ['key', 'value', 'query_pos', 'key_padding_mask',
    # 'spatial_shapes', 'level_start_index'];
    # key is None, value is the memory obtained from the encoder
    output = output.permute(1, 0, 2)

    # reg_branches is None by default (no iterative box refinement)
    if reg_branches is not None:
        tmp = reg_branches[lid](output)
        if reference_points.shape[-1] == 4:
            new_reference_points = tmp + inverse_sigmoid(
                reference_points)
            new_reference_points = new_reference_points.sigmoid()
        else:
            assert reference_points.shape[-1] == 2
            new_reference_points = tmp
            new_reference_points[..., :2] = tmp[
                ..., :2] + inverse_sigmoid(reference_points)
            new_reference_points = new_reference_points.sigmoid()
        reference_points = new_reference_points.detach()

    output = output.permute(1, 0, 2)
    # Store the intermediate query and reference_points; the query changes layer by layer,
    # while the reference_points stay the same here because reg_branches is None
    if self.return_intermediate:
        intermediate.append(output)
        intermediate_reference_points.append(reference_points)

if self.return_intermediate:  # True
    return torch.stack(intermediate), torch.stack(
        intermediate_reference_points)

return output, reference_points
The decoder returns two values: the queries and the reference_points of all six decoder layers. The query differs from layer to layer, but the reference_points are identical across layers here (since reg_branches is None, i.e. box refinement is not used).
Finally, the whole transformer returns three meaningful values: inter_states, init_reference_out and inter_references_out (enc_outputs_class and enc_outputs_coord are None when as_two_stage is False).
inter_states: [num_dec_layers, bs, num_query, embed_dims], the query of every decoder layer
init_reference_out: [bs, num_query, 2], the initial reference_points
inter_references_out: [num_dec_layers, bs, num_query, 2], the reference points of every layer
After the transformer part, the code returns to DeformableDETRHead:
hs = hs.permute(0, 2, 1, 3)
outputs_classes = []
outputs_coords = []
# Make predictions from every decoder layer's output
for lvl in range(hs.shape[0]):
    if lvl == 0:
        reference = init_reference
    else:
        reference = inter_references[lvl - 1]
    reference = inverse_sigmoid(reference)  # apply the inverse sigmoid
    outputs_class = self.cls_branches[lvl](hs[lvl])
    # tmp predicted here is an offset relative to the reference point
    tmp = self.reg_branches[lvl](hs[lvl])
    if reference.shape[-1] == 4:
        tmp += reference
    else:
        assert reference.shape[-1] == 2
        tmp[..., :2] += reference  # add the reference point to the predicted offset
    outputs_coord = tmp.sigmoid()
    outputs_classes.append(outputs_class)
    outputs_coords.append(outputs_coord)

outputs_classes = torch.stack(outputs_classes)
outputs_coords = torch.stack(outputs_coords)
if self.as_two_stage:
    return outputs_classes, outputs_coords, \
        enc_outputs_class, \
        enc_outputs_coord.sigmoid()
else:
    return outputs_classes, outputs_coords, \
        None, None
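The box offsets are added in inverse-sigmoid (logit) space and then squashed back with sigmoid, so the final coordinates stay in [0, 1]. A minimal sketch of what inverse_sigmoid does (mmdet's version also clamps the input for numerical stability; the eps default here is an assumption):

import torch

def inverse_sigmoid(x, eps=1e-5):
    """Inverse of the sigmoid function, with clamping for numerical stability."""
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)

ref = torch.tensor([0.25, 0.75])       # a reference point in [0, 1]
offset = torch.tensor([0.4, -0.2])     # offset predicted by the reg branch, in logit space
new_box_xy = (inverse_sigmoid(ref) + offset).sigmoid()
print(new_box_xy)                      # still inside (0, 1)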
What follows is the loss computation, which should be the same as in DETR. I already covered it in my DETR source-code reading post, so I will not repeat it here; if you are interested, see my other post: DETR源码阅读 (DETR source code reading).
To summarize: in the encoder there is only self_attn, and Q, K, V are all the feature maps.
In the decoder's self_attn, Q, K, V are all the object queries ([num_query, bs, 256]).
In the decoder's cross_attn, Q is the object query and V is the feature map; K is None here, because deformable attention does not obtain the attention_weight via a Q·K dot product; the attention_weight is learned from the object query.