
Deformable-DETR Code Study Notes

First, the Deformable-DETR architecture diagram:

Paper: https://arxiv.org/pdf/2010.04159.pdf

Code: https://github.com/fundamentalvision/Deformable-DETR

Deformable-DETR introduces two major improvements over DETR: it speeds up convergence, and it improves small-object detection without using an FPN.

I. Backbone

A good memory is no match for written notes, so some things are worth writing down for later review. The Deformable-DETR code largely reuses the DETR project and mainly changes the transformer part. For how the input image passes through the backbone and how the data is prepared before entering the encoder, see DETR Code Study Notes (Part 1).

Most of that is shared between the two. The only difference is the feature maps output by the ResNet backbone: DETR outputs a single [N,2048,H,W] tensor, whereas Deformable-DETR outputs the last three ResNet stages, with 512, 1024 and 2048 channels and strides of 8, 16 and 32 respectively.

The difference in the backbone code is sketched below.
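
A minimal sketch (not the verbatim repo code) of the difference: DETR returns only the last ResNet stage, while Deformable-DETR keeps the last three stages via torchvision's IntermediateLayerGetter. The return_layers mapping below reflects what backbone.py selects; the printed shapes assume a hypothetical 512x512 input.

import torch
import torchvision
from torchvision.models._utils import IntermediateLayerGetter

resnet = torchvision.models.resnet50()
# Deformable-DETR keeps C3-C5; DETR would use return_layers = {"layer4": "0"}
return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
body = IntermediateLayerGetter(resnet, return_layers=return_layers)

feats = body(torch.rand(1, 3, 512, 512))
for name, f in feats.items():
    print(name, f.shape)
# 0 torch.Size([1, 512, 64, 64])    stride 8
# 1 torch.Size([1, 1024, 32, 32])   stride 16
# 2 torch.Size([1, 2048, 16, 16])   stride 32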

As in DETR, Deformable-DETR generates a mask not only for the last feature map but for every level, recording which positions of the padded tensor are occupied by the original image. This is covered in detail in DETR Code Study Notes (Part 1), so it is not repeated here; jump over there if anything is unclear. The extracted feature maps and masks are then passed into the transformer as its input.

II. Encoder

Start from the main DeformableDETR module:

class DeformableDETR(nn.Module):
    """ This is the Deformable DETR module that performs object detection """
    def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_levels,
                 aux_loss=True, with_box_refine=False, two_stage=False):
        """ Initializes the model.
        Parameters:
            backbone: torch module of the backbone to be used. See backbone.py
            transformer: torch module of the transformer architecture. See transformer.py
            num_classes: number of object classes
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         DETR can detect in a single image. For COCO, we recommend 100 queries.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
            with_box_refine: iterative bounding box refinement
            two_stage: two-stage Deformable DETR
        """
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes)
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.num_feature_levels = num_feature_levels
        if not two_stage:
            self.query_embed = nn.Embedding(num_queries, hidden_dim*2)  # num_queries is 300 in the code
        if num_feature_levels > 1:
            num_backbone_outs = len(backbone.strides)
            input_proj_list = []
            for _ in range(num_backbone_outs):
                in_channels = backbone.num_channels[_]
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
                in_channels = hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            self.input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                )])
        self.backbone = backbone
        self.aux_loss = aux_loss
        self.with_box_refine = with_box_refine
        self.two_stage = two_stage

        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        self.class_embed.bias.data = torch.ones(num_classes) * bias_value
        nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
        for proj in self.input_proj:
            nn.init.xavier_uniform_(proj[0].weight, gain=1)
            nn.init.constant_(proj[0].bias, 0)

        # if two-stage, the last class_embed and bbox_embed is for region proposal generation
        num_pred = (transformer.decoder.num_layers + 1) if two_stage else transformer.decoder.num_layers
        if with_box_refine:
            self.class_embed = _get_clones(self.class_embed, num_pred)
            self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
            nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
            # hack implementation for iterative bounding box refinement
            self.transformer.decoder.bbox_embed = self.bbox_embed
        else:
            nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
            self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)])
            self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
            self.transformer.decoder.bbox_embed = None
        if two_stage:
            # hack implementation for two-stage
            self.transformer.decoder.class_embed = self.class_embed
            for box_embed in self.bbox_embed:
                nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)

    def forward(self, samples: NestedTensor):
        """ The forward expects a NestedTensor, which consists of:
               - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
               - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
            It returns a dict with the following elements:
               - "pred_logits": the classification logits (including no-object) for all queries.
                                Shape= [batch_size x num_queries x (num_classes + 1)]
               - "pred_boxes": The normalized boxes coordinates for all queries, represented as
                               (center_x, center_y, height, width). These values are normalized in [0, 1],
                               relative to the size of each individual image (disregarding possible padding).
                               See PostProcess for information on how to retrieve the unnormalized bounding box.
               - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
                                dictionnaries containing the two above keys for each decoder layer.
        """
        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        srcs = []
        masks = []
        for l, feat in enumerate(features):
            src, mask = feat.decompose()
            srcs.append(self.input_proj[l](src))  # each level is projected by a 1x1 conv: [N,512/1024/2048,H,W] -> [N,256,H,W], with H and W the size of that level
            masks.append(mask)  # the mask always has shape [N,H,W]
            assert mask is not None
        if self.num_feature_levels > len(srcs):  # self.num_feature_levels == 4
            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.num_feature_levels):
                if l == _len_srcs:
                    src = self.input_proj[l](features[-1].tensors)  # downsample the last backbone level with a stride-2 3x3 conv: [N,2048,H,W] -> [N,256,H//2,W//2]
                else:
                    src = self.input_proj[l](srcs[-1])
                m = samples.mask
                mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0]  # the extra level needs its own mask, interpolated to the new spatial size
                pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype)  # positional encodings of the first three levels were produced in the backbone forward; here the fourth level gets its own
                srcs.append(src)
                masks.append(mask)
                pos.append(pos_l)

        query_embeds = None
        if not self.two_stage:
            query_embeds = self.query_embed.weight  # shape [300,512]
        hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact = self.transformer(srcs, masks, pos, query_embeds)

        outputs_classes = []
        outputs_coords = []
        for lvl in range(hs.shape[0]):
            if lvl == 0:
                reference = init_reference
            else:
                reference = inter_references[lvl - 1]
            reference = inverse_sigmoid(reference)
            outputs_class = self.class_embed[lvl](hs[lvl])  # classification [N,300,91]
            tmp = self.bbox_embed[lvl](hs[lvl])  # bounding boxes from an MLP (several Linear layers)
            if reference.shape[-1] == 4:
                tmp += reference
            else:
                assert reference.shape[-1] == 2
                tmp[..., :2] += reference
            outputs_coord = tmp.sigmoid()  # [N,300,4]
            outputs_classes.append(outputs_class)
            outputs_coords.append(outputs_coord)
        outputs_class = torch.stack(outputs_classes)  # [6,N,300,91]
        outputs_coord = torch.stack(outputs_coords)  # [6,N,300,4]

        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)

        if self.two_stage:
            enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
            out['enc_outputs'] = {'pred_logits': enc_outputs_class, 'pred_boxes': enc_outputs_coord}
        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_boxes': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

self.backbone is a ResNet-50. It produces the three feature maps (C3-C5) together with, for each level, the padding mask marking where the real image sits inside the padded tensor.

C6 is generated from C5 with a stride-2 3x3 convolution on the 2048-channel input, and a mask matching C6's spatial size is generated at the same time; a small shape check is sketched below.
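
A tiny shape check of that extra level (hypothetical sizes, mirroring the last input_proj entry from the code above): a stride-2 3x3 conv on the 2048-channel C5, plus a mask interpolated to the new resolution.

import torch
import torch.nn as nn
import torch.nn.functional as F

c5 = torch.rand(1, 2048, 16, 16)  # C5 of a hypothetical 512x512 input
extra = nn.Sequential(nn.Conv2d(2048, 256, kernel_size=3, stride=2, padding=1),
                      nn.GroupNorm(32, 256))
c6 = extra(c5)  # [1, 256, 8, 8]
img_mask = torch.zeros(1, 512, 512, dtype=torch.bool)  # image-level padding mask
c6_mask = F.interpolate(img_mask[None].float(), size=c6.shape[-2:]).to(torch.bool)[0]  # [1, 8, 8]
print(c6.shape, c6_mask.shape)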

Since num_queries defaults to 300 in the code, query_embed has shape [300,512]. It is later split into two halves; more on that below.

Next, into self.transformer (only part of the key code is shown here):

class DeformableTransformer(nn.Module):
    def __init__(self, d_model=256, nhead=8,
                 num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
                 activation="relu", return_intermediate_dec=False,
                 num_feature_levels=4, dec_n_points=4, enc_n_points=4,
                 two_stage=False, two_stage_num_proposals=300):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.two_stage = two_stage
        self.two_stage_num_proposals = two_stage_num_proposals

        encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
                                                          dropout, activation,
                                                          num_feature_levels, nhead, enc_n_points)
        self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)

        decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward,
                                                          dropout, activation,
                                                          num_feature_levels, nhead, dec_n_points)
        self.decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec)

        self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))

        if two_stage:
            self.enc_output = nn.Linear(d_model, d_model)
            self.enc_output_norm = nn.LayerNorm(d_model)
            self.pos_trans = nn.Linear(d_model * 2, d_model * 2)
            self.pos_trans_norm = nn.LayerNorm(d_model * 2)
        else:
            self.reference_points = nn.Linear(d_model, 2)

        self._reset_parameters()

    def get_valid_ratio(self, mask):
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)  # height of the non-padded part of the feature map
        valid_W = torch.sum(~mask[:, 0, :], 1)  # width of the non-padded part of the feature map
        valid_ratio_h = valid_H.float() / H  # ratio of the non-padded height to the padded height of this level
        valid_ratio_w = valid_W.float() / W  # ratio of the non-padded width to the padded width of this level
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def forward(self, srcs, masks, pos_embeds, query_embed=None):
        assert self.two_stage or query_embed is not None

        # prepare input for encoder
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            src = src.flatten(2).transpose(1, 2)  # flatten H and W: [N,256,H,W] -> [N,H*W,256]
            mask = mask.flatten(1)  # [N,H,W] -> [N,H*W]
            pos_embed = pos_embed.flatten(2).transpose(1, 2)  # flatten H and W as well: [N,256,H,W] -> [N,H*W,256]
            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)  # self.level_embed is a [4,256] tensor, one learned embedding per level
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)  # concatenate the flattened tensors of all levels
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)  # [H,W] of each level, shape [4,2]
        level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))  # start index of each level after concatenation: 0, H1*W1, H1*W1+H2*W2, H1*W1+H2*W2+H3*W3 (4 values)
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)  # [N,4,2]: for each level, the ratio of the valid (non-padded) width/height to the padded width/height

        # encoder: memory has shape [N,len_q,256], where len_q is the sum of H*W over the four levels
        memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)

        # prepare input for decoder
        bs, _, c = memory.shape
        if self.two_stage:
            output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)

            # hack implementation for two-stage Deformable DETR
            enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
            enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals

            topk = self.two_stage_num_proposals
            topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
            topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
            topk_coords_unact = topk_coords_unact.detach()
            reference_points = topk_coords_unact.sigmoid()
            init_reference_out = reference_points
            pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
            query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
        else:
            query_embed, tgt = torch.split(query_embed, c, dim=1)  # split query_embed ([300,512]) into two [300,256] tensors, query_embed and tgt; query_embed plays a role similar to anchors in conv-based detectors, providing a position
            query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)  # [300,256] -> [N,300,256]
            tgt = tgt.unsqueeze(0).expand(bs, -1, -1)  # [300,256] -> [N,300,256]
            reference_points = self.reference_points(query_embed).sigmoid()  # a Linear(256,2) applied to query_embed; reference_points has shape [N,300,2]
            init_reference_out = reference_points
            # query_embed is split into a halved query_embed and tgt; the halved query_embed then goes through a Linear to produce reference_points

        # decoder
        hs, inter_references = self.decoder(tgt, reference_points, memory,
                                            spatial_shapes, level_start_index, valid_ratios, query_embed, mask_flatten)
        # hs has shape [6,N,300,256], inter_references has shape [6,N,300,2]

        inter_references_out = inter_references
        if self.two_stage:
            return hs, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact
        return hs, init_reference_out, inter_references_out, None, None

Before the data enters the encoder, some preprocessing is done:

1. The four feature maps are flattened and concatenated into the query. Assuming the largest level has size [H,W], len_q = H*W + H//2*W//2 + H//4*W//4 + H//8*W//8, and the result has shape [N,len_q,256], where N is the batch size.

2. The mask is aligned with the query and has shape [N,len_q].

3. spatial_shapes records the [H,W] of each of the four feature maps.

4. level_start_index records where each level starts in the concatenated sequence: 0 for the first level, H1*W1 for the second, H1*W1+H2*W2 for the third, and H1*W1+H2*W2+H3*W3 for the last (4 values in total).

5. valid_ratios is an [N,4,2] tensor giving, for each level, the ratio of the valid (non-padded) width and height to the padded feature map's width and height. A small numeric example follows this list.
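
A toy example (hypothetical feature-map sizes) of the quantities above, assuming the largest level is 64x64 and each subsequent level halves the resolution:

import torch

spatial_shapes = torch.as_tensor([[64, 64], [32, 32], [16, 16], [8, 8]], dtype=torch.long)
len_q = int(spatial_shapes.prod(1).sum())  # 4096 + 1024 + 256 + 64 = 5440
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)),
                               spatial_shapes.prod(1).cumsum(0)[:-1]))
print(len_q)              # 5440
print(level_start_index)  # tensor([   0, 4096, 5120, 5376])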

class DeformableTransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):  # iterate over levels; level 0 is the largest feature map
            # build a grid over the feature map: the normalized x,y coordinates of every pixel center
            ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                                          torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)  # concatenate the normalized center coordinates of all levels
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]  # multiply by the valid ratios of every level, giving each center's normalized position on the valid region of every feature map
        return reference_points

    def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
        output = src
        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
        for _, layer in enumerate(self.layers):  # output: [N,len_q,256]; reference_points: [N,len_q,4,2], one reference point per query on each of the four levels
            output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
        return output

reference_points has shape [N,len_q,4,2]: for every query, its normalized position on each of the four feature levels. A toy example of how the centers are normalized is sketched below.
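
A minimal sketch (hypothetical 2x2 level, valid_ratio assumed to be 1.0, i.e. no padding) of how the per-level reference points are built: pixel centers normalized to [0,1].

import torch

H_, W_ = 2, 2
ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_),
                              torch.linspace(0.5, W_ - 0.5, W_))
ref_x = ref_x.reshape(-1) / W_  # divided by (valid_ratio * W_), with valid_ratio = 1 here
ref_y = ref_y.reshape(-1) / H_
print(torch.stack((ref_x, ref_y), -1))
# tensor([[0.2500, 0.2500],
#         [0.7500, 0.2500],
#         [0.2500, 0.7500],
#         [0.7500, 0.7500]])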

class DeformableTransformerEncoderLayer(nn.Module):
    def __init__(self,
                 d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4):
        super().__init__()

        # self attention
        self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, src):
        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
        src = src + self.dropout3(src2)
        src = self.norm2(src)
        return src

    def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
        # self attention
        src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask)
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        # ffn
        src = self.forward_ffn(src)

        return src

Diagram of the encoder:

The encoder layer is roughly the same as in DETR (the small inset in the diagram is enlarged below; it is also the figure from the paper). The part worth a closer look is MSDeformAttn:

class MSDeformAttn(nn.Module):
    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
        """
        Multi-Scale Deformable Attention Module
        :param d_model      hidden dimension
        :param n_levels     number of feature levels
        :param n_heads      number of attention heads
        :param n_points     number of sampling points per attention head per feature level
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
        _d_per_head = d_model // n_heads
        # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_head):
            warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
                          "which is more efficient in our CUDA implementation.")

        self.im2col_step = 64

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points

        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)  # each head predicts n_points (4 in the paper) offsets for every level
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)  # the weight of each sampling point, predicted directly by a Linear layer
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        self._reset_parameters()

    def _reset_parameters(self):
        constant_(self.sampling_offsets.weight.data, 0.)
        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1  # each head starts from a different direction, each point index from a different radius
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))  # initialize the offset biases: different points get different offsets, but the same offsets across levels
        constant_(self.attention_weights.weight.data, 0.)
        constant_(self.attention_weights.bias.data, 0.)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

    def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
        """
        :param query                       (N, Length_{query}, C)
        :param reference_points            (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
                                        or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
        :param input_flatten               (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
        :param input_spatial_shapes        (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
        :param input_level_start_index     (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
        :param input_padding_mask          (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
        :return output                     (N, Length_{query}, C)
        """
        # In the encoder, the query is srcs plus pos_embeds.
        # Each query corresponds to one reference_point on the feature maps; around each reference_point, n = 4 (in the code)
        # keys are sampled, and the features are fused using attention_weights produced by a Linear layer
        # (the attention weights are not computed as Q*K, but predicted directly from the query).
        # This greatly speeds up convergence by attending only to a sparse set of locations.
        N, Len_q, _ = query.shape
        N, Len_in, _ = input_flatten.shape  # Len_in depends on the four feature-map sizes in this batch; if the first level is H*W, then Len_in = H*W + H//2*W//2 + H//4*W//4 + H//8*W//8
        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in

        value = self.value_proj(input_flatten)  # a Linear layer: [N,Len_in,256] -> [N,Len_in,256], producing the value
        if input_padding_mask is not None:
            value = value.masked_fill(input_padding_mask[..., None], float(0))  # positions where the mask is True are filled with 0
        value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)  # [N,Len_in,256] -> [N,Len_in,8,32]
        sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)  # each query predicts offsets for every head and level: [N,Len_q,256] -> [N,Len_q,8,4,4,2]
        attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)  # weight of each sampling point, via Linear(256,128): [N,Len_q,256] -> [N,Len_q,8,16]
        attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)  # for each query, the weights over all levels and points are normalized per head: [N,Len_q,8,16] -> [N,Len_q,8,4,4]
        # N, Len_q, n_heads, n_levels, n_points, 2
        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)  # convert [H,W] in input_spatial_shapes to [W,H]
            sampling_locations = reference_points[:, :, None, :, None, :] \
                                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]  # sampling point coordinates [N,Len_q,8,4,4,2]
        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                                 + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
        else:
            raise ValueError(
                'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
        output = MSDeformAttnFunction.apply(
            value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
        output = self.output_proj(output)  # a final Linear layer: [N,Len_q,256] -> [N,Len_q,256]
        return output

In the source code, n_heads is 8, d_model is 256, n_levels is 4 and n_points is 4.

MSDeformAttn takes srcs plus pos_embeds as the query. Each query corresponds to one reference_point on the feature maps; around each reference_point, n = 4 keys (as set in the code) are sampled, and the features are fused using attention_weights produced by a Linear layer (the attention weights are not computed as Q*K, but predicted directly from the query). The details of sampling_offsets and attention_weights are annotated in the code above, so they are not repeated here.

MSDeformAttnFunction calls a CUDA kernel, but the repo also ships a pure-PyTorch reference implementation:

def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
    # for debug and test only,
    # need to use cuda version instead
    N_, S_, M_, D_ = value.shape  # value shape [N,Len_in,8,32]
    _, Lq_, M_, L_, P_, _ = sampling_locations.shape  # shape [N,Len_q,8,4,4,2]
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)  # split value back into the individual feature levels
    sampling_grids = 2 * sampling_locations - 1  # grid_sample expects coordinates in [-1, 1]
    sampling_value_list = []
    for lid_, (H_, W_) in enumerate(value_spatial_shapes):
        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
        value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)  # [N,H_*W_,8,32] -> [N*8,32,H_,W_]
        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
        sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
        # N_*M_, D_, Lq_, P_
        # F.grid_sample takes an input and a grid and, for each grid position, samples (with interpolation if needed) the corresponding value from the input.
        sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
                                          mode='bilinear', padding_mode='zeros', align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
    attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)  # [N,Len_q,8,4,4] -> [N*8,1,Len_q,16]
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)  # this is exactly the equation from the paper
    return output.transpose(1, 2).contiguous()
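
A quick shape check of this reference implementation (hypothetical sizes; assumes ms_deform_attn_core_pytorch above is importable):

import torch

N, M, D, Lq, P = 2, 8, 32, 100, 4
spatial_shapes = [(32, 32), (16, 16), (8, 8), (4, 4)]
L = len(spatial_shapes)
S = sum(h * w for h, w in spatial_shapes)  # 1024 + 256 + 64 + 16 = 1360

value = torch.rand(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 2)  # normalized to [0, 1]
attention_weights = torch.softmax(torch.rand(N, Lq, M, L * P), -1).view(N, Lq, M, L, P)

out = ms_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
print(out.shape)  # torch.Size([2, 100, 256]), i.e. [N, Lq, M*D]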

Schematic from the paper:

The corresponding formula:
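
Written out (following the notation of the paper), single-scale deformable attention and its multi-scale extension are:

\[
\mathrm{DeformAttn}(z_q, p_q, x) = \sum_{m=1}^{M} W_m \Big[ \sum_{k=1}^{K} A_{mqk} \cdot W'_m\, x(p_q + \Delta p_{mqk}) \Big]
\]

\[
\mathrm{MSDeformAttn}(z_q, \hat{p}_q, \{x^l\}_{l=1}^{L}) = \sum_{m=1}^{M} W_m \Big[ \sum_{l=1}^{L} \sum_{k=1}^{K} A_{mlqk} \cdot W'_m\, x^l\big(\phi_l(\hat{p}_q) + \Delta p_{mlqk}\big) \Big]
\]

where \(M\) is the number of heads, \(L\) the number of levels, \(K\) the number of sampling points per head per level, \(\Delta p_{mlqk}\) the predicted offsets, \(A_{mlqk}\) the predicted attention weights (normalized so that \(\sum_{l}\sum_{k} A_{mlqk} = 1\)), and \(\phi_l\) rescales the normalized reference point \(\hat{p}_q\) to the \(l\)-th feature level. In ms_deform_attn_core_pytorch above, the inner double sum corresponds to the grid_sample plus the weighted sum over the L_*P_ dimension, while the per-head projection \(W_m\) is realized by output_proj applied to the concatenated heads.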

III. Decoder

After the encoder produces memory (shape [N,len_q,256]), the data goes into the decoder; before that, some more preprocessing is needed.

Only part of DeformableTransformer is shown here:

bs, _, c = memory.shape
if self.two_stage:
    pass
else:
    query_embed, tgt = torch.split(query_embed, c, dim=1)  # split query_embed ([300,512]) into two [300,256] tensors, query_embed and tgt; query_embed plays a role similar to anchors in conv-based detectors, providing a position
    query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)  # [300,256] -> [N,300,256]
    tgt = tgt.unsqueeze(0).expand(bs, -1, -1)  # [300,256] -> [N,300,256]
    reference_points = self.reference_points(query_embed).sigmoid()  # a Linear(256,2) applied to query_embed; reference_points has shape [N,300,2]
    init_reference_out = reference_points
    # query_embed is split into a halved query_embed and tgt; the halved query_embed then goes through a Linear to produce reference_points
Then into the decoder:

class DeformableTransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate
        # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
        self.bbox_embed = None
        self.class_embed = None

    def forward(self, tgt, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios,
                query_pos=None, src_padding_mask=None):
        output = tgt

        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            if reference_points.shape[-1] == 4:
                reference_points_input = reference_points[:, :, None] \
                                         * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]  # reference_points_input has shape [N,300,4,2]
            output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask)

            # hack implementation for iterative bounding box refinement
            if self.bbox_embed is not None:
                tmp = self.bbox_embed[lid](output)
                if reference_points.shape[-1] == 4:
                    new_reference_points = tmp + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                else:
                    assert reference_points.shape[-1] == 2
                    new_reference_points = tmp
                    new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                reference_points = new_reference_points.detach()

            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(intermediate_reference_points)

        return output, reference_points

In each decoder layer, multi-head self-attention is computed first; its output becomes the tgt of the cross-attention (with the positional embedding added back), and the cross-attention again uses MSDeformAttn:

class DeformableTransformerDecoderLayer(nn.Module):
    def __init__(self, d_model=256, d_ffn=1024,
                 dropout=0.1, activation="relu",
                 n_levels=4, n_heads=8, n_points=4):
        super().__init__()

        # cross attention
        self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # self attention
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # ffn
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask=None):
        # self attention (src is the memory produced by the encoder)
        q = k = self.with_pos_embed(tgt, query_pos)  # tgt and query_pos both come from query_embed (300,512)
        tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        # cross attention
        tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos),
                               reference_points,
                               src, src_spatial_shapes, level_start_index, src_padding_mask)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # ffn
        tgt = self.forward_ffn(tgt)

        return tgt

Diagram of the decoder:

Here tgt and query_pos are both split out of query_embed (300,512). The split-off tgt is only used as the tgt input to the first decoder layer; from then on, the output of each layer serves as the tgt of the next.

The full architecture diagram:

Finally, DeformableDETR performs classification and regresses the box positions: in effect the network predicts the width, the height, and an offset relative to the reference_point, as in the following code:

outputs_classes = []
outputs_coords = []
for lvl in range(hs.shape[0]):
    if lvl == 0:
        reference = init_reference
    else:
        reference = inter_references[lvl - 1]
    reference = inverse_sigmoid(reference)
    outputs_class = self.class_embed[lvl](hs[lvl])  # classification [N,300,91]
    tmp = self.bbox_embed[lvl](hs[lvl])  # bounding boxes from an MLP (several Linear layers)
    if reference.shape[-1] == 4:
        tmp += reference
    else:
        assert reference.shape[-1] == 2
        tmp[..., :2] += reference
    outputs_coord = tmp.sigmoid()  # [N,300,4]
    outputs_classes.append(outputs_class)
    outputs_coords.append(outputs_coord)
outputs_class = torch.stack(outputs_classes)  # [6,N,300,91]
outputs_coord = torch.stack(outputs_coords)  # [6,N,300,4]

The last step is computing the loss. This part is largely the same as in DETR; see DETR Code Study Notes (Part 3).
