
Mask2Former Code Walkthrough


1. Overall Pipeline

        The Mask2Former pipeline is shown in the figure. The input image first passes through a backbone such as ResNet to obtain multi-level features. These features are used in two ways: one branch goes through the pixel decoder (built on DetrTransformerEncoderLayer) to obtain per-pixel embeddings, and the other goes through the transformer decoder to obtain mask embeddings; a matrix multiplication between the two yields the mask prediction. For semantic segmentation, the class prediction and the mask prediction are combined by another matrix multiplication to produce the final result.
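The two matrix multiplications can be illustrated with a small PyTorch sketch (all shapes here are made up for illustration; the mask-prediction step mirrors the torch.einsum('bqc,bchw->bqhw', ...) call that appears in the head code later, and the semantic-segmentation combination follows the MaskFormer/Mask2Former formulation):

    import torch

    # toy shapes: B images, Q queries, C embedding dims, K classes (+1 background), H x W mask features
    B, Q, C, K, H, W = 2, 100, 256, 19, 128, 128
    per_pixel_embed = torch.randn(B, C, H, W)   # per-pixel embedding from the pixel decoder
    query_embed = torch.randn(B, Q, C)          # mask embedding from the transformer decoder
    cls_logits = torch.randn(B, Q, K + 1)       # class prediction for every query

    # mask prediction: dot product between each query embedding and every per-pixel embedding
    mask_logits = torch.einsum('bqc,bchw->bqhw', query_embed, per_pixel_embed)

    # semantic segmentation: combine class probabilities (background dropped) with mask probabilities
    sem_seg = torch.einsum('bqk,bqhw->bkhw',
                           cls_logits.softmax(-1)[..., :-1], mask_logits.sigmoid())
    print(mask_logits.shape, sem_seg.shape)  # (2, 100, 128, 128) (2, 19, 128, 128)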

2. Backbone

        A network such as ResNet can be used as the backbone to obtain multi-level features.
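As a minimal illustration (a hand-rolled sketch using torchvision rather than the mmdet ResNet backbone used in the actual config), the four stage outputs with strides 4/8/16/32 and channels 256/512/1024/2048 match the default in_channels of the pixel decoder below:

    import torch
    import torchvision

    resnet = torchvision.models.resnet50()  # the classification head is simply not used

    def extract_multi_level_feats(x):
        x = resnet.maxpool(resnet.relu(resnet.bn1(resnet.conv1(x))))
        c2 = resnet.layer1(x)    # 1/4  resolution, 256 channels
        c3 = resnet.layer2(c2)   # 1/8  resolution, 512 channels
        c4 = resnet.layer3(c3)   # 1/16 resolution, 1024 channels
        c5 = resnet.layer4(c4)   # 1/32 resolution, 2048 channels
        return [c2, c3, c4, c5]

    feats = extract_multi_level_feats(torch.randn(1, 3, 512, 512))
    print([tuple(f.shape) for f in feats])
    # [(1, 256, 128, 128), (1, 512, 64, 64), (1, 1024, 32, 32), (1, 2048, 16, 16)]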

3. Pixel Decoder

        This module performs feature extraction in the decoding stage. To reduce computation and speed up convergence, Mask2Former adopts the transformer design of Deformable DETR. Concretely, the pixel decoder:

  • preprocesses the multi-level features, including dimension rearrangement, sampling (reference) points and positional encodings
  • extracts features with the deformable transformer encoder
  • upsamples and fuses the feature maps, and learns a mask feature map from the final (highest-resolution) feature map

The full code is as follows:

  1. class MSDeformAttnPixelDecoder(BaseModule):
  2. """Pixel decoder with multi-scale deformable attention.
  3. Args:
  4. in_channels (list[int] | tuple[int]): Number of channels in the
  5. input feature maps.
  6. strides (list[int] | tuple[int]): Output strides of feature from
  7. backbone.
  8. feat_channels (int): Number of channels for feature.
  9. out_channels (int): Number of channels for output.
  10. num_outs (int): Number of output scales.
  11. norm_cfg (:obj:`ConfigDict` or dict): Config for normalization.
  12. Defaults to dict(type='GN', num_groups=32).
  13. act_cfg (:obj:`ConfigDict` or dict): Config for activation.
  14. Defaults to dict(type='ReLU').
  15. encoder (:obj:`ConfigDict` or dict): Config for transformer
  16. encoder. Defaults to None.
  17. positional_encoding (:obj:`ConfigDict` or dict): Config for
  18. transformer encoder position encoding. Defaults to
  19. dict(num_feats=128, normalize=True).
  20. init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
  21. dict], optional): Initialization config dict. Defaults to None.
  22. """
  23. def __init__(self,
  24. in_channels: Union[List[int],
  25. Tuple[int]] = [256, 512, 1024, 2048],
  26. strides: Union[List[int], Tuple[int]] = [4, 8, 16, 32],
  27. feat_channels: int = 256,
  28. out_channels: int = 256,
  29. num_outs: int = 3,
  30. norm_cfg: ConfigType = dict(type='GN', num_groups=32),
  31. act_cfg: ConfigType = dict(type='ReLU'),
  32. encoder: ConfigType = None,
  33. positional_encoding: ConfigType = dict(
  34. num_feats=128, normalize=True),
  35. init_cfg: OptMultiConfig = None) -> None:
  36. super().__init__(init_cfg=init_cfg)
  37. self.strides = strides
  38. self.num_input_levels = len(in_channels)
  39. self.num_encoder_levels = \
  40. encoder.layer_cfg.self_attn_cfg.num_levels
  41. assert self.num_encoder_levels >= 1, \
  42. 'num_levels in attn_cfgs must be at least one'
  43. input_conv_list = []
  44. # from top to down (low to high resolution)
  45. for i in range(self.num_input_levels - 1,
  46. self.num_input_levels - self.num_encoder_levels - 1,
  47. -1):
  48. input_conv = ConvModule(
  49. in_channels[i],
  50. feat_channels,
  51. kernel_size=1,
  52. norm_cfg=norm_cfg,
  53. act_cfg=None,
  54. bias=True)
  55. input_conv_list.append(input_conv)
  56. self.input_convs = ModuleList(input_conv_list)
  57. self.encoder = Mask2FormerTransformerEncoder(**encoder)
  58. self.postional_encoding = SinePositionalEncoding(**positional_encoding)
  59. # high resolution to low resolution
  60. self.level_encoding = nn.Embedding(self.num_encoder_levels,
  61. feat_channels)
  62. # fpn-like structure
  63. self.lateral_convs = ModuleList()
  64. self.output_convs = ModuleList()
  65. self.use_bias = norm_cfg is None
  66. # from top to down (low to high resolution)
  67. # fpn for the rest features that didn't pass in encoder
  68. for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1,
  69. -1):
  70. lateral_conv = ConvModule(
  71. in_channels[i],
  72. feat_channels,
  73. kernel_size=1,
  74. bias=self.use_bias,
  75. norm_cfg=norm_cfg,
  76. act_cfg=None)
  77. output_conv = ConvModule(
  78. feat_channels,
  79. feat_channels,
  80. kernel_size=3,
  81. stride=1,
  82. padding=1,
  83. bias=self.use_bias,
  84. norm_cfg=norm_cfg,
  85. act_cfg=act_cfg)
  86. self.lateral_convs.append(lateral_conv)
  87. self.output_convs.append(output_conv)
  88. self.mask_feature = Conv2d(
  89. feat_channels, out_channels, kernel_size=1, stride=1, padding=0)
  90. self.num_outs = num_outs
  91. self.point_generator = MlvlPointGenerator(strides)
  92. def init_weights(self) -> None:
  93. """Initialize weights."""
  94. for i in range(0, self.num_encoder_levels):
  95. xavier_init(
  96. self.input_convs[i].conv,
  97. gain=1,
  98. bias=0,
  99. distribution='uniform')
  100. for i in range(0, self.num_input_levels - self.num_encoder_levels):
  101. caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
  102. caffe2_xavier_init(self.output_convs[i].conv, bias=0)
  103. caffe2_xavier_init(self.mask_feature, bias=0)
  104. normal_init(self.level_encoding, mean=0, std=1)
  105. for p in self.encoder.parameters():
  106. if p.dim() > 1:
  107. nn.init.xavier_normal_(p)
  108. # init_weights defined in MultiScaleDeformableAttention
  109. for m in self.encoder.layers.modules():
  110. if isinstance(m, MultiScaleDeformableAttention):
  111. m.init_weights()
  112. def forward(self, feats: List[Tensor]) -> Tuple[Tensor, Tensor]:
  113. """
  114. Args:
  115. feats (list[Tensor]): Feature maps of each level. Each has
  116. shape of (batch_size, c, h, w).
  117. Returns:
  118. tuple: A tuple containing the following:
  119. - mask_feature (Tensor): shape (batch_size, c, h, w).
  120. - multi_scale_features (list[Tensor]): Multi scale \
  121. features, each in shape (batch_size, c, h, w).
  122. """
  123. # generate padding mask for each level, for each image
  124. batch_size = feats[0].shape[0]
  125. encoder_input_list = []
  126. padding_mask_list = []
  127. level_positional_encoding_list = []
  128. spatial_shapes = []
  129. reference_points_list = []
  130. for i in range(self.num_encoder_levels):
  131. level_idx = self.num_input_levels - i - 1
  132. feat = feats[level_idx]
  133. feat_projected = self.input_convs[i](feat)
  134. feat_hw = torch._shape_as_tensor(feat)[2:].to(feat.device)
  135. # no padding; padded regions are masked out here
  136. padding_mask_resized = feat.new_zeros(
  137. (batch_size, ) + feat.shape[-2:], dtype=torch.bool)
  138. pos_embed = self.postional_encoding(padding_mask_resized) # sinusoidal positional encoding, matching the feature map size
  139. level_embed = self.level_encoding.weight[i] # level embedding, a learned 256-d vector per level
  140. level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed
  141. # (h_i * w_i, 2) sampling (reference) points
  142. reference_points = self.point_generator.single_level_grid_priors(
  143. feat.shape[-2:], level_idx, device=feat.device)
  144. # normalize
  145. feat_wh = feat_hw.unsqueeze(0).flip(dims=[0, 1])
  146. factor = feat_wh * self.strides[level_idx]
  147. reference_points = reference_points / factor
  148. # shape (batch_size, c, h_i, w_i) -> (batch_size, h_i * w_i, c) dimension rearrangement
  149. feat_projected = feat_projected.flatten(2).permute(0, 2, 1)
  150. level_pos_embed = level_pos_embed.flatten(2).permute(0, 2, 1)
  151. padding_mask_resized = padding_mask_resized.flatten(1)
  152. # append each level to the lists
  153. encoder_input_list.append(feat_projected)
  154. padding_mask_list.append(padding_mask_resized)
  155. level_positional_encoding_list.append(level_pos_embed)
  156. spatial_shapes.append(feat_hw)
  157. reference_points_list.append(reference_points)
  158. # shape (batch_size, total_num_queries),
  159. # total_num_queries=sum([., h_i * w_i,.])
  160. padding_masks = torch.cat(padding_mask_list, dim=1)
  161. # shape (batch_size, total_num_queries, c): concatenate all levels
  162. encoder_inputs = torch.cat(encoder_input_list, dim=1)
  163. level_positional_encodings = torch.cat(
  164. level_positional_encoding_list, dim=1)
  165. # shape (num_encoder_levels, 2), from low
  166. # resolution to high resolution; marks the boundary of each level
  167. num_queries_per_level = [e[0] * e[1] for e in spatial_shapes]
  168. spatial_shapes = torch.cat(spatial_shapes).view(-1, 2) # feature map size of each level
  169. # shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...)
  170. level_start_index = torch.cat((spatial_shapes.new_zeros(
  171. (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
  172. reference_points = torch.cat(reference_points_list, dim=0) # sampling reference points
  173. reference_points = reference_points[None, :, None].repeat(
  174. batch_size, 1, self.num_encoder_levels, 1)
  175. valid_radios = reference_points.new_ones( # valid ratio of each level (all ones: no padding)
  176. (batch_size, self.num_encoder_levels, 2))
  177. # shape (num_total_queries, batch_size, c); feature extraction with the deformable transformer
  178. memory = self.encoder(
  179. query=encoder_inputs,
  180. query_pos=level_positional_encodings,
  181. key_padding_mask=padding_masks,
  182. spatial_shapes=spatial_shapes,
  183. reference_points=reference_points,
  184. level_start_index=level_start_index,
  185. valid_ratios=valid_radios)
  186. # (batch_size, c, num_total_queries)
  187. memory = memory.permute(0, 2, 1)
  188. # from low resolution to high resolution
  189. outs = torch.split(memory, num_queries_per_level, dim=-1) # split the levels apart again
  190. outs = [
  191. x.reshape(batch_size, -1, spatial_shapes[i][0],
  192. spatial_shapes[i][1]) for i, x in enumerate(outs)
  193. ]
  194. # upsampling and feature fusion
  195. for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1,
  196. -1):
  197. x = feats[i]
  198. cur_feat = self.lateral_convs[i](x)
  199. y = cur_feat + F.interpolate(
  200. outs[-1],
  201. size=cur_feat.shape[-2:],
  202. mode='bilinear',
  203. align_corners=False)
  204. y = self.output_convs[i](y)
  205. outs.append(y)
  206. multi_scale_features = outs[:self.num_outs]
  207. mask_feature = self.mask_feature(outs[-1]) # learn the mask feature map from the last (highest-resolution) feature map
  208. return mask_feature, multi_scale_features

 Deformable transformer

        In the deformable transformer, sampling offsets and attention weights are predicted from the query; V is then sampled at the reference points shifted by those offsets, and the output is the attention-weighted sum of the sampled values (attention_weights * V).
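A single-level, single-head sketch of this idea (hypothetical shapes and layer names, only meant to show the mechanism; the real implementation below handles multiple heads, levels and the CUDA kernel):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    bs, C, H, W, num_queries, num_points = 2, 256, 32, 32, 100, 4
    query = torch.randn(bs, num_queries, C)
    value = torch.randn(bs, C, H, W)
    reference_points = torch.rand(bs, num_queries, 2)   # normalized (x, y) in [0, 1]

    offset_head = nn.Linear(C, num_points * 2)           # offsets predicted from the query
    weight_head = nn.Linear(C, num_points)                # attention weights predicted from the query

    offsets = offset_head(query).view(bs, num_queries, num_points, 2) / torch.tensor([W, H])
    weights = weight_head(query).softmax(-1)               # (bs, num_queries, num_points)

    # sample V at the offset locations; grid_sample expects coordinates in [-1, 1]
    locations = reference_points[:, :, None, :] + offsets
    sampled = F.grid_sample(value, 2 * locations - 1, mode='bilinear',
                            padding_mode='zeros', align_corners=False)
    # (bs, C, num_queries, num_points) -> weighted sum over the sampling points
    output = (sampled * weights[:, None]).sum(-1).transpose(1, 2)   # (bs, num_queries, C)
    print(output.shape)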

  1. class MultiScaleDeformableAttention(BaseModule):
  2. """An attention module used in Deformable-Detr.
  3. `Deformable DETR: Deformable Transformers for End-to-End Object Detection.
  4. <https://arxiv.org/pdf/2010.04159.pdf>`_.
  5. Args:
  6. embed_dims (int): The embedding dimension of Attention.
  7. Default: 256.
  8. num_heads (int): Parallel attention heads. Default: 8.
  9. num_levels (int): The number of feature map used in
  10. Attention. Default: 4.
  11. num_points (int): The number of sampling points for
  12. each query in each head. Default: 4.
  13. im2col_step (int): The step used in image_to_column.
  14. Default: 64.
  15. dropout (float): A Dropout layer on `inp_identity`.
  16. Default: 0.1.
  17. batch_first (bool): Key, Query and Value are shape of
  18. (batch, n, embed_dim)
  19. or (n, batch, embed_dim). Default to False.
  20. norm_cfg (dict): Config dict for normalization layer.
  21. Default: None.
  22. init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
  23. Default: None.
  24. value_proj_ratio (float): The expansion ratio of value_proj.
  25. Default: 1.0.
  26. """
  27. def __init__(self,
  28. embed_dims: int = 256,
  29. num_heads: int = 8,
  30. num_levels: int = 4,
  31. num_points: int = 4,
  32. im2col_step: int = 64,
  33. dropout: float = 0.1,
  34. batch_first: bool = False,
  35. norm_cfg: Optional[dict] = None,
  36. init_cfg: Optional[mmengine.ConfigDict] = None,
  37. value_proj_ratio: float = 1.0):
  38. super().__init__(init_cfg)
  39. if embed_dims % num_heads != 0:
  40. raise ValueError(f'embed_dims must be divisible by num_heads, '
  41. f'but got {embed_dims} and {num_heads}')
  42. dim_per_head = embed_dims // num_heads
  43. self.norm_cfg = norm_cfg
  44. self.dropout = nn.Dropout(dropout)
  45. self.batch_first = batch_first
  46. # you'd better set dim_per_head to a power of 2
  47. # which is more efficient in the CUDA implementation
  48. def _is_power_of_2(n):
  49. if (not isinstance(n, int)) or (n < 0):
  50. raise ValueError(
  51. 'invalid input for _is_power_of_2: {} (type: {})'.format(
  52. n, type(n)))
  53. return (n & (n - 1) == 0) and n != 0
  54. if not _is_power_of_2(dim_per_head):
  55. warnings.warn(
  56. "You'd better set embed_dims in "
  57. 'MultiScaleDeformAttention to make '
  58. 'the dimension of each attention head a power of 2 '
  59. 'which is more efficient in our CUDA implementation.')
  60. self.im2col_step = im2col_step
  61. self.embed_dims = embed_dims
  62. self.num_levels = num_levels
  63. self.num_heads = num_heads
  64. self.num_points = num_points
  65. self.sampling_offsets = nn.Linear(
  66. embed_dims, num_heads * num_levels * num_points * 2)
  67. self.attention_weights = nn.Linear(embed_dims,
  68. num_heads * num_levels * num_points)
  69. value_proj_size = int(embed_dims * value_proj_ratio)
  70. self.value_proj = nn.Linear(embed_dims, value_proj_size)
  71. self.output_proj = nn.Linear(value_proj_size, embed_dims)
  72. self.init_weights()
  73. def init_weights(self) -> None:
  74. """Default initialization for Parameters of Module."""
  75. constant_init(self.sampling_offsets, 0.)
  76. device = next(self.parameters()).device
  77. thetas = torch.arange(
  78. self.num_heads, dtype=torch.float32,
  79. device=device) * (2.0 * math.pi / self.num_heads)
  80. grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
  81. grid_init = (grid_init /
  82. grid_init.abs().max(-1, keepdim=True)[0]).view(
  83. self.num_heads, 1, 1,
  84. 2).repeat(1, self.num_levels, self.num_points, 1)
  85. for i in range(self.num_points):
  86. grid_init[:, :, i, :] *= i + 1
  87. self.sampling_offsets.bias.data = grid_init.view(-1)
  88. constant_init(self.attention_weights, val=0., bias=0.)
  89. xavier_init(self.value_proj, distribution='uniform', bias=0.)
  90. xavier_init(self.output_proj, distribution='uniform', bias=0.)
  91. self._is_init = True
  92. @no_type_check
  93. @deprecated_api_warning({'residual': 'identity'},
  94. cls_name='MultiScaleDeformableAttention')
  95. def forward(self,
  96. query: torch.Tensor,
  97. key: Optional[torch.Tensor] = None,
  98. value: Optional[torch.Tensor] = None,
  99. identity: Optional[torch.Tensor] = None,
  100. query_pos: Optional[torch.Tensor] = None,
  101. key_padding_mask: Optional[torch.Tensor] = None,
  102. reference_points: Optional[torch.Tensor] = None,
  103. spatial_shapes: Optional[torch.Tensor] = None,
  104. level_start_index: Optional[torch.Tensor] = None,
  105. **kwargs) -> torch.Tensor:
  106. """Forward Function of MultiScaleDeformAttention.
  107. Args:
  108. query (torch.Tensor): Query of Transformer with shape
  109. (num_query, bs, embed_dims).
  110. key (torch.Tensor): The key tensor with shape
  111. `(num_key, bs, embed_dims)`.
  112. value (torch.Tensor): The value tensor with shape
  113. `(num_key, bs, embed_dims)`.
  114. identity (torch.Tensor): The tensor used for addition, with the
  115. same shape as `query`. Default None. If None,
  116. `query` will be used.
  117. query_pos (torch.Tensor): The positional encoding for `query`.
  118. Default: None.
  119. key_padding_mask (torch.Tensor): ByteTensor for `query`, with
  120. shape [bs, num_key].
  121. reference_points (torch.Tensor): The normalized reference
  122. points with shape (bs, num_query, num_levels, 2),
  123. all elements is range in [0, 1], top-left (0,0),
  124. bottom-right (1, 1), including padding area.
  125. or (N, Length_{query}, num_levels, 4), add
  126. additional two dimensions is (w, h) to
  127. form reference boxes.
  128. spatial_shapes (torch.Tensor): Spatial shape of features in
  129. different levels. With shape (num_levels, 2),
  130. last dimension represents (h, w).
  131. level_start_index (torch.Tensor): The start index of each level.
  132. A tensor has shape ``(num_levels, )`` and can be represented
  133. as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
  134. Returns:
  135. torch.Tensor: forwarded results with shape
  136. [num_query, bs, embed_dims].
  137. """
  138. if value is None:
  139. value = query
  140. if identity is None:
  141. identity = query
  142. if query_pos is not None:
  143. query = query + query_pos
  144. if not self.batch_first:
  145. # change to (bs, num_query ,embed_dims)
  146. query = query.permute(1, 0, 2)
  147. value = value.permute(1, 0, 2)
  148. bs, num_query, _ = query.shape
  149. bs, num_value, _ = value.shape
  150. assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
  151. value = self.value_proj(value) # linear projection to obtain V
  152. if key_padding_mask is not None: # mask out padded positions
  153. value = value.masked_fill(key_padding_mask[..., None], 0.0)
  154. value = value.view(bs, num_value, self.num_heads, -1)
  155. sampling_offsets = self.sampling_offsets(query).view( # predict sampling offsets from the query; the linear layer outputs num_heads*num_levels*num_points*2 channels
  156. bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
  157. attention_weights = self.attention_weights(query).view( # predict attention weights from the query: num_heads*num_levels*num_points channels
  158. bs, num_query, self.num_heads, self.num_levels * self.num_points)
  159. attention_weights = attention_weights.softmax(-1)
  160. attention_weights = attention_weights.view(bs, num_query,
  161. self.num_heads,
  162. self.num_levels,
  163. self.num_points)
  164. if reference_points.shape[-1] == 2: # compute the shifted (normalized) sampling-point coordinates
  165. offset_normalizer = torch.stack(
  166. [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
  167. sampling_locations = reference_points[:, :, None, :, None, :] \
  168. + sampling_offsets \
  169. / offset_normalizer[None, None, None, :, None, :]
  170. elif reference_points.shape[-1] == 4:
  171. sampling_locations = reference_points[:, :, None, :, None, :2] \
  172. + sampling_offsets / self.num_points \
  173. * reference_points[:, :, None, :, None, 2:] \
  174. * 0.5
  175. else:
  176. raise ValueError(
  177. f'Last dim of reference_points must be'
  178. f' 2 or 4, but get {reference_points.shape[-1]} instead.')
  179. if ((IS_CUDA_AVAILABLE and value.is_cuda)
  180. or (IS_MLU_AVAILABLE and value.is_mlu)):
  181. output = MultiScaleDeformableAttnFunction.apply( # CUDA op: sampling plus attention_weights * V
  182. value, spatial_shapes, level_start_index, sampling_locations,
  183. attention_weights, self.im2col_step)
  184. else:
  185. output = multi_scale_deformable_attn_pytorch(
  186. value, spatial_shapes, sampling_locations, attention_weights)
  187. output = self.output_proj(output) # output linear projection
  188. if not self.batch_first:
  189. # (num_query, bs ,embed_dims)
  190. output = output.permute(1, 0, 2)
  191. return self.dropout(output) + identity # dropout and residual connection
  1. def multi_scale_deformable_attn_pytorch(
  2. value: torch.Tensor, value_spatial_shapes: torch.Tensor,
  3. sampling_locations: torch.Tensor,
  4. attention_weights: torch.Tensor) -> torch.Tensor:
  5. """CPU version of multi-scale deformable attention.
  6. Args:
  7. value (torch.Tensor): The value has shape
  8. (bs, num_keys, num_heads, embed_dims//num_heads)
  9. value_spatial_shapes (torch.Tensor): Spatial shape of
  10. each feature map, has shape (num_levels, 2),
  11. last dimension 2 represent (h, w)
  12. sampling_locations (torch.Tensor): The location of sampling points,
  13. has shape
  14. (bs ,num_queries, num_heads, num_levels, num_points, 2),
  15. the last dimension 2 represent (x, y).
  16. attention_weights (torch.Tensor): The weight of sampling points used
  17. when calculate the attention, has shape
  18. (bs ,num_queries, num_heads, num_levels, num_points),
  19. Returns:
  20. torch.Tensor: has shape (bs, num_queries, embed_dims)
  21. """
  22. bs, _, num_heads, embed_dims = value.shape
  23. _, num_queries, num_heads, num_levels, num_points, _ =\
  24. sampling_locations.shape
  25. value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes],
  26. dim=1) # split the levels
  27. sampling_grids = 2 * sampling_locations - 1
  28. sampling_value_list = [] # sample from each level
  29. for level, (H_, W_) in enumerate(value_spatial_shapes):
  30. # bs, H_*W_, num_heads, embed_dims ->
  31. # bs, H_*W_, num_heads*embed_dims ->
  32. # bs, num_heads*embed_dims, H_*W_ ->
  33. # bs*num_heads, embed_dims, H_, W_
  34. value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(
  35. bs * num_heads, embed_dims, H_, W_)
  36. # bs, num_queries, num_heads, num_points, 2 ->
  37. # bs, num_heads, num_queries, num_points, 2 ->
  38. # bs*num_heads, num_queries, num_points, 2
  39. sampling_grid_l_ = sampling_grids[:, :, :,
  40. level].transpose(1, 2).flatten(0, 1)
  41. # bs*num_heads, embed_dims, num_queries, num_points
  42. sampling_value_l_ = F.grid_sample(
  43. value_l_,
  44. sampling_grid_l_,
  45. mode='bilinear',
  46. padding_mode='zeros',
  47. align_corners=False)
  48. sampling_value_list.append(sampling_value_l_)
  49. # (bs, num_queries, num_heads, num_levels, num_points) ->
  50. # (bs, num_heads, num_queries, num_levels, num_points) ->
  51. # (bs, num_heads, 1, num_queries, num_levels*num_points)
  52. attention_weights = attention_weights.transpose(1, 2).reshape(
  53. bs * num_heads, 1, num_queries, num_levels * num_points) # attention*V
  54. output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
  55. attention_weights).sum(-1).view(bs, num_heads * embed_dims,
  56. num_queries)
  57. return output.transpose(1, 2).contiguous()
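A small usage sketch of the pure-PyTorch fallback above with random tensors, just to make the expected shapes concrete (the same function also ships with mmcv; all numbers here are arbitrary):

    import torch

    bs, num_heads, head_dims = 2, 8, 32
    num_queries, num_points = 100, 4
    spatial_shapes = torch.tensor([[16, 16], [8, 8]])   # two feature levels
    num_levels = spatial_shapes.shape[0]
    num_keys = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())   # 16*16 + 8*8 = 320

    value = torch.randn(bs, num_keys, num_heads, head_dims)
    sampling_locations = torch.rand(
        bs, num_queries, num_heads, num_levels, num_points, 2)   # normalized (x, y) in [0, 1]
    attention_weights = torch.randn(
        bs, num_queries, num_heads, num_levels * num_points).softmax(-1).view(
            bs, num_queries, num_heads, num_levels, num_points)

    out = multi_scale_deformable_attn_pytorch(
        value, spatial_shapes, sampling_locations, attention_weights)
    print(out.shape)   # torch.Size([2, 100, 256]) = (bs, num_queries, num_heads * head_dims)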

4. Transformer Decoder

        The transformer decoder predicts a set of masks, each of which carries the information of one predicted instance. The procedure is:

  • first, a set of queries is initialized
  • class predictions and mask predictions are obtained from the queries, and at the same time the attention mask for cross attention is derived from the mask predictions (see the sketch after this list)
  • cross attention and self attention are then applied for feature extraction and feature fusion
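The attention mask in the second step is what makes this "masked attention": key positions that a query currently predicts as background are excluded from cross attention. A minimal sketch of how such a mask is derived from the mask logits (toy shapes; it mirrors the _forward_head method shown further below):

    import torch
    import torch.nn.functional as F

    batch_size, num_queries, num_heads = 2, 100, 8
    mask_pred = torch.randn(batch_size, num_queries, 128, 128)   # mask logits per query
    attn_target_size = (16, 16)   # spatial size of the feature level used as keys

    attn_mask = F.interpolate(
        mask_pred, attn_target_size, mode='bilinear', align_corners=False)
    # flatten the spatial dims and repeat for every attention head
    attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat(1, num_heads, 1, 1).flatten(0, 1)
    # True = this key position is ignored by the query (predicted foreground prob < 0.5)
    attn_mask = (attn_mask.sigmoid() < 0.5).detach()
    print(attn_mask.shape)   # (batch_size * num_heads, num_queries, 16 * 16) = (16, 100, 256)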

Cross attention and self attention:

  1. class Mask2FormerTransformerDecoderLayer(DetrTransformerDecoderLayer):
  2. """Implements decoder layer in Mask2Former transformer."""
  3. def forward(self,
  4. query: Tensor,
  5. key: Tensor = None,
  6. value: Tensor = None,
  7. query_pos: Tensor = None,
  8. key_pos: Tensor = None,
  9. self_attn_mask: Tensor = None,
  10. cross_attn_mask: Tensor = None,
  11. key_padding_mask: Tensor = None,
  12. **kwargs) -> Tensor:
  13. """
  14. Args:
  15. query (Tensor): The input query, has shape (bs, num_queries, dim).
  16. key (Tensor, optional): The input key, has shape (bs, num_keys,
  17. dim). If `None`, the `query` will be used. Defaults to `None`.
  18. value (Tensor, optional): The input value, has the same shape as
  19. `key`, as in `nn.MultiheadAttention.forward`. If `None`, the
  20. `key` will be used. Defaults to `None`.
  21. query_pos (Tensor, optional): The positional encoding for `query`,
  22. has the same shape as `query`. If not `None`, it will be added
  23. to `query` before forward function. Defaults to `None`.
  24. key_pos (Tensor, optional): The positional encoding for `key`, has
  25. the same shape as `key`. If not `None`, it will be added to
  26. `key` before forward function. If None, and `query_pos` has the
  27. same shape as `key`, then `query_pos` will be used for
  28. `key_pos`. Defaults to None.
  29. self_attn_mask (Tensor, optional): ByteTensor mask, has shape
  30. (num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
  31. Defaults to None.
  32. cross_attn_mask (Tensor, optional): ByteTensor mask, has shape
  33. (num_queries, num_keys), as in `nn.MultiheadAttention.forward`.
  34. Defaults to None.
  35. key_padding_mask (Tensor, optional): The `key_padding_mask` of
  36. `self_attn` input. ByteTensor, has shape (bs, num_value).
  37. Defaults to None.
  38. Returns:
  39. Tensor: forwarded results, has shape (bs, num_queries, dim).
  40. """
  41. query = self.cross_attn(
  42. query=query,
  43. key=key,
  44. value=value,
  45. query_pos=query_pos,
  46. key_pos=key_pos,
  47. attn_mask=cross_attn_mask,
  48. key_padding_mask=key_padding_mask,
  49. **kwargs)
  50. query = self.norms[0](query)
  51. query = self.self_attn(
  52. query=query,
  53. key=query,
  54. value=query,
  55. query_pos=query_pos,
  56. key_pos=query_pos,
  57. attn_mask=self_attn_mask,
  58. **kwargs)
  59. query = self.norms[1](query)
  60. query = self.ffn(query)
  61. query = self.norms[2](query)
  62. return query

Network flow of the pixel decoder and transformer decoder:

  1. class Mask2FormerHead(MaskFormerHead):
  2. """Implements the Mask2Former head.
  3. See `Masked-attention Mask Transformer for Universal Image
  4. Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details.
  5. Args:
  6. in_channels (list[int]): Number of channels in the input feature map.
  7. feat_channels (int): Number of channels for features.
  8. out_channels (int): Number of channels for output.
  9. num_things_classes (int): Number of things.
  10. num_stuff_classes (int): Number of stuff.
  11. num_queries (int): Number of query in Transformer decoder.
  12. pixel_decoder (:obj:`ConfigDict` or dict): Config for pixel
  13. decoder. Defaults to None.
  14. enforce_decoder_input_project (bool, optional): Whether to add
  15. a layer to change the embed_dim of tranformer encoder in
  16. pixel decoder to the embed_dim of transformer decoder.
  17. Defaults to False.
  18. transformer_decoder (:obj:`ConfigDict` or dict): Config for
  19. transformer decoder. Defaults to None.
  20. positional_encoding (:obj:`ConfigDict` or dict): Config for
  21. transformer decoder position encoding. Defaults to
  22. dict(num_feats=128, normalize=True).
  23. loss_cls (:obj:`ConfigDict` or dict): Config of the classification
  24. loss. Defaults to None.
  25. loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss.
  26. Defaults to None.
  27. loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss.
  28. Defaults to None.
  29. train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
  30. Mask2Former head.
  31. test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
  32. Mask2Former head.
  33. init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
  34. dict], optional): Initialization config dict. Defaults to None.
  35. """
  36. def __init__(self,
  37. in_channels: List[int],
  38. feat_channels: int,
  39. out_channels: int,
  40. num_things_classes: int = 80,
  41. num_stuff_classes: int = 53,
  42. num_queries: int = 100,
  43. num_transformer_feat_level: int = 3,
  44. pixel_decoder: ConfigType = ...,
  45. enforce_decoder_input_project: bool = False,
  46. transformer_decoder: ConfigType = ...,
  47. positional_encoding: ConfigType = dict(
  48. num_feats=128, normalize=True),
  49. loss_cls: ConfigType = dict(
  50. type='CrossEntropyLoss',
  51. use_sigmoid=False,
  52. loss_weight=2.0,
  53. reduction='mean',
  54. class_weight=[1.0] * 133 + [0.1]),
  55. loss_mask: ConfigType = dict(
  56. type='CrossEntropyLoss',
  57. use_sigmoid=True,
  58. reduction='mean',
  59. loss_weight=5.0),
  60. loss_dice: ConfigType = dict(
  61. type='DiceLoss',
  62. use_sigmoid=True,
  63. activate=True,
  64. reduction='mean',
  65. naive_dice=True,
  66. eps=1.0,
  67. loss_weight=5.0),
  68. train_cfg: OptConfigType = None,
  69. test_cfg: OptConfigType = None,
  70. init_cfg: OptMultiConfig = None,
  71. **kwargs) -> None:
  72. super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
  73. self.num_things_classes = num_things_classes
  74. self.num_stuff_classes = num_stuff_classes
  75. self.num_classes = self.num_things_classes + self.num_stuff_classes
  76. self.num_queries = num_queries
  77. self.num_transformer_feat_level = num_transformer_feat_level
  78. self.num_heads = transformer_decoder.layer_cfg.cross_attn_cfg.num_heads
  79. self.num_transformer_decoder_layers = transformer_decoder.num_layers
  80. assert pixel_decoder.encoder.layer_cfg. \
  81. self_attn_cfg.num_levels == num_transformer_feat_level
  82. pixel_decoder_ = copy.deepcopy(pixel_decoder)
  83. pixel_decoder_.update(
  84. in_channels=in_channels,
  85. feat_channels=feat_channels,
  86. out_channels=out_channels)
  87. self.pixel_decoder = MODELS.build(pixel_decoder_)
  88. self.transformer_decoder = Mask2FormerTransformerDecoder(
  89. **transformer_decoder)
  90. self.decoder_embed_dims = self.transformer_decoder.embed_dims
  91. self.decoder_input_projs = ModuleList()
  92. # from low resolution to high resolution
  93. for _ in range(num_transformer_feat_level):
  94. if (self.decoder_embed_dims != feat_channels
  95. or enforce_decoder_input_project):
  96. self.decoder_input_projs.append(
  97. Conv2d(
  98. feat_channels, self.decoder_embed_dims, kernel_size=1))
  99. else:
  100. self.decoder_input_projs.append(nn.Identity())
  101. self.decoder_positional_encoding = SinePositionalEncoding(
  102. **positional_encoding)
  103. self.query_embed = nn.Embedding(self.num_queries, feat_channels)
  104. self.query_feat = nn.Embedding(self.num_queries, feat_channels)
  105. # from low resolution to high resolution
  106. self.level_embed = nn.Embedding(self.num_transformer_feat_level,
  107. feat_channels)
  108. self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
  109. self.mask_embed = nn.Sequential(
  110. nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
  111. nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
  112. nn.Linear(feat_channels, out_channels))
  113. self.test_cfg = test_cfg
  114. self.train_cfg = train_cfg
  115. if train_cfg:
  116. self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
  117. self.sampler = TASK_UTILS.build(
  118. self.train_cfg['sampler'], default_args=dict(context=self))
  119. self.num_points = self.train_cfg.get('num_points', 12544)
  120. self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
  121. self.importance_sample_ratio = self.train_cfg.get(
  122. 'importance_sample_ratio', 0.75)
  123. self.class_weight = loss_cls.class_weight
  124. self.loss_cls = MODELS.build(loss_cls)
  125. self.loss_mask = MODELS.build(loss_mask)
  126. self.loss_dice = MODELS.build(loss_dice)
  127. def init_weights(self) -> None:
  128. for m in self.decoder_input_projs:
  129. if isinstance(m, Conv2d):
  130. caffe2_xavier_init(m, bias=0)
  131. self.pixel_decoder.init_weights()
  132. for p in self.transformer_decoder.parameters():
  133. if p.dim() > 1:
  134. nn.init.xavier_normal_(p)
  135. def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
  136. gt_instances: InstanceData,
  137. img_meta: dict) -> Tuple[Tensor]:
  138. """Compute classification and mask targets for one image.
  139. Args:
  140. cls_score (Tensor): Mask score logits from a single decoder layer
  141. for one image. Shape (num_queries, cls_out_channels).
  142. mask_pred (Tensor): Mask logits for a single decoder layer for one
  143. image. Shape (num_queries, h, w).
  144. gt_instances (:obj:`InstanceData`): It contains ``labels`` and
  145. ``masks``.
  146. img_meta (dict): Image informtation.
  147. Returns:
  148. tuple[Tensor]: A tuple containing the following for one image.
  149. - labels (Tensor): Labels of each image. \
  150. shape (num_queries, ).
  151. - label_weights (Tensor): Label weights of each image. \
  152. shape (num_queries, ).
  153. - mask_targets (Tensor): Mask targets of each image. \
  154. shape (num_queries, h, w).
  155. - mask_weights (Tensor): Mask weights of each image. \
  156. shape (num_queries, ).
  157. - pos_inds (Tensor): Sampled positive indices for each \
  158. image.
  159. - neg_inds (Tensor): Sampled negative indices for each \
  160. image.
  161. - sampling_result (:obj:`SamplingResult`): Sampling results.
  162. """
  163. gt_labels = gt_instances.labels
  164. gt_masks = gt_instances.masks
  165. # sample points
  166. num_queries = cls_score.shape[0]
  167. num_gts = gt_labels.shape[0]
  168. point_coords = torch.rand((1, self.num_points, 2),
  169. device=cls_score.device)
  170. # shape (num_queries, num_points)
  171. mask_points_pred = point_sample(
  172. mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
  173. 1)).squeeze(1)
  174. # shape (num_gts, num_points)
  175. gt_points_masks = point_sample(
  176. gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
  177. 1)).squeeze(1)
  178. sampled_gt_instances = InstanceData(
  179. labels=gt_labels, masks=gt_points_masks)
  180. sampled_pred_instances = InstanceData(
  181. scores=cls_score, masks=mask_points_pred)
  182. # assign and sample
  183. assign_result = self.assigner.assign(
  184. pred_instances=sampled_pred_instances,
  185. gt_instances=sampled_gt_instances,
  186. img_meta=img_meta)
  187. pred_instances = InstanceData(scores=cls_score, masks=mask_pred)
  188. sampling_result = self.sampler.sample(
  189. assign_result=assign_result,
  190. pred_instances=pred_instances,
  191. gt_instances=gt_instances)
  192. pos_inds = sampling_result.pos_inds
  193. neg_inds = sampling_result.neg_inds
  194. # label target
  195. labels = gt_labels.new_full((self.num_queries, ),
  196. self.num_classes,
  197. dtype=torch.long)
  198. labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
  199. label_weights = gt_labels.new_ones((self.num_queries, ))
  200. # mask target
  201. mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
  202. mask_weights = mask_pred.new_zeros((self.num_queries, ))
  203. mask_weights[pos_inds] = 1.0
  204. return (labels, label_weights, mask_targets, mask_weights, pos_inds,
  205. neg_inds, sampling_result)
  206. def _loss_by_feat_single(self, cls_scores: Tensor, mask_preds: Tensor,
  207. batch_gt_instances: List[InstanceData],
  208. batch_img_metas: List[dict]) -> Tuple[Tensor]:
  209. """Loss function for outputs from a single decoder layer.
  210. Args:
  211. cls_scores (Tensor): Mask score logits from a single decoder layer
  212. for all images. Shape (batch_size, num_queries,
  213. cls_out_channels). Note `cls_out_channels` should includes
  214. background.
  215. mask_preds (Tensor): Mask logits for a pixel decoder for all
  216. images. Shape (batch_size, num_queries, h, w).
  217. batch_gt_instances (list[obj:`InstanceData`]): each contains
  218. ``labels`` and ``masks``.
  219. batch_img_metas (list[dict]): List of image meta information.
  220. Returns:
  221. tuple[Tensor]: Loss components for outputs from a single \
  222. decoder layer.
  223. """
  224. num_imgs = cls_scores.size(0)
  225. cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
  226. mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
  227. (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
  228. avg_factor) = self.get_targets(cls_scores_list, mask_preds_list,
  229. batch_gt_instances, batch_img_metas)
  230. # shape (batch_size, num_queries)
  231. labels = torch.stack(labels_list, dim=0)
  232. # shape (batch_size, num_queries)
  233. label_weights = torch.stack(label_weights_list, dim=0)
  234. # shape (num_total_gts, h, w)
  235. mask_targets = torch.cat(mask_targets_list, dim=0)
  236. # shape (batch_size, num_queries)
  237. mask_weights = torch.stack(mask_weights_list, dim=0)
  238. # classfication loss
  239. # shape (batch_size * num_queries, )
  240. cls_scores = cls_scores.flatten(0, 1)
  241. labels = labels.flatten(0, 1)
  242. label_weights = label_weights.flatten(0, 1)
  243. class_weight = cls_scores.new_tensor(self.class_weight)
  244. loss_cls = self.loss_cls(
  245. cls_scores,
  246. labels,
  247. label_weights,
  248. avg_factor=class_weight[labels].sum())
  249. num_total_masks = reduce_mean(cls_scores.new_tensor([avg_factor]))
  250. num_total_masks = max(num_total_masks, 1)
  251. # extract positive ones
  252. # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
  253. mask_preds = mask_preds[mask_weights > 0]
  254. if mask_targets.shape[0] == 0:
  255. # zero match
  256. loss_dice = mask_preds.sum()
  257. loss_mask = mask_preds.sum()
  258. return loss_cls, loss_mask, loss_dice
  259. with torch.no_grad():
  260. points_coords = get_uncertain_point_coords_with_randomness(
  261. mask_preds.unsqueeze(1), None, self.num_points,
  262. self.oversample_ratio, self.importance_sample_ratio)
  263. # shape (num_total_gts, h, w) -> (num_total_gts, num_points)
  264. mask_point_targets = point_sample(
  265. mask_targets.unsqueeze(1).float(), points_coords).squeeze(1)
  266. # shape (num_queries, h, w) -> (num_queries, num_points)
  267. mask_point_preds = point_sample(
  268. mask_preds.unsqueeze(1), points_coords).squeeze(1)
  269. # dice loss
  270. loss_dice = self.loss_dice(
  271. mask_point_preds, mask_point_targets, avg_factor=num_total_masks)
  272. # mask loss
  273. # shape (num_queries, num_points) -> (num_queries * num_points, )
  274. mask_point_preds = mask_point_preds.reshape(-1)
  275. # shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
  276. mask_point_targets = mask_point_targets.reshape(-1)
  277. loss_mask = self.loss_mask(
  278. mask_point_preds,
  279. mask_point_targets,
  280. avg_factor=num_total_masks * self.num_points)
  281. return loss_cls, loss_mask, loss_dice
  282. def _forward_head(self, decoder_out: Tensor, mask_feature: Tensor,
  283. attn_mask_target_size: Tuple[int, int]) -> Tuple[Tensor]:
  284. """Forward for head part which is called after every decoder layer.
  285. Args:
  286. decoder_out (Tensor): in shape (batch_size, num_queries, c).
  287. mask_feature (Tensor): in shape (batch_size, c, h, w).
  288. attn_mask_target_size (tuple[int, int]): target attention
  289. mask size.
  290. Returns:
  291. tuple: A tuple contain three elements.
  292. - cls_pred (Tensor): Classification scores in shape \
  293. (batch_size, num_queries, cls_out_channels). \
  294. Note `cls_out_channels` should includes background.
  295. - mask_pred (Tensor): Mask scores in shape \
  296. (batch_size, num_queries,h, w).
  297. - attn_mask (Tensor): Attention mask in shape \
  298. (batch_size * num_heads, num_queries, h, w).
  299. """
  300. decoder_out = self.transformer_decoder.post_norm(decoder_out) # layernorm
  301. # shape (num_queries, batch_size, c)
  302. cls_pred = self.cls_embed(decoder_out) # class prediction
  303. # shape (num_queries, batch_size, c)
  304. mask_embed = self.mask_embed(decoder_out)
  305. # shape (batch_size, num_queries, h, w); effectively maps each query to an image region
  306. mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature)
  307. attn_mask = F.interpolate(
  308. mask_pred,
  309. attn_mask_target_size,
  310. mode='bilinear',
  311. align_corners=False) # downsample to the target attention-mask size
  312. # shape (num_queries, batch_size, h, w) ->
  313. # (batch_size * num_head, num_queries, h, w): repeated for every attention head
  314. attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat(
  315. (1, self.num_heads, 1, 1)).flatten(0, 1)
  316. attn_mask = attn_mask.sigmoid() < 0.5 # definition of the attention mask: True = ignored (background) position
  317. attn_mask = attn_mask.detach()
  318. return cls_pred, mask_pred, attn_mask
  319. def forward(self, x: List[Tensor],
  320. batch_data_samples: SampleList) -> Tuple[List[Tensor]]:
  321. """Forward function.
  322. Args:
  323. x (list[Tensor]): Multi scale Features from the
  324. upstream network, each is a 4D-tensor.
  325. batch_data_samples (List[:obj:`DetDataSample`]): The Data
  326. Samples. It usually includes information such as
  327. `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
  328. Returns:
  329. tuple[list[Tensor]]: A tuple contains two elements.
  330. - cls_pred_list (list[Tensor)]: Classification logits \
  331. for each decoder layer. Each is a 3D-tensor with shape \
  332. (batch_size, num_queries, cls_out_channels). \
  333. Note `cls_out_channels` should includes background.
  334. - mask_pred_list (list[Tensor]): Mask logits for each \
  335. decoder layer. Each with shape (batch_size, num_queries, \
  336. h, w).
  337. """
  338. batch_size = x[0].shape[0]
  339. mask_features, multi_scale_memorys = self.pixel_decoder(x)
  340. # multi_scale_memorys (from low resolution to high resolution)
  341. decoder_inputs = []
  342. decoder_positional_encodings = []
  343. for i in range(self.num_transformer_feat_level):
  344. decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) # decoder input
  345. # shape (batch_size, c, h, w) -> (batch_size, h*w, c)
  346. decoder_input = decoder_input.flatten(2).permute(0, 2, 1)
  347. level_embed = self.level_embed.weight[i].view(1, 1, -1) # level embedding
  348. decoder_input = decoder_input + level_embed
  349. # shape (batch_size, c, h, w) -> (batch_size, h*w, c)
  350. mask = decoder_input.new_zeros( # initialize an all-False padding mask
  351. (batch_size, ) + multi_scale_memorys[i].shape[-2:],
  352. dtype=torch.bool)
  353. decoder_positional_encoding = self.decoder_positional_encoding(
  354. mask) # the positional encoding has the same spatial size as the mask
  355. decoder_positional_encoding = decoder_positional_encoding.flatten(
  356. 2).permute(0, 2, 1)
  357. decoder_inputs.append(decoder_input)
  358. decoder_positional_encodings.append(decoder_positional_encoding)
  359. # shape (num_queries, c) -> (batch_size, num_queries, c)
  360. query_feat = self.query_feat.weight.unsqueeze(0).repeat( # query features
  361. (batch_size, 1, 1))
  362. query_embed = self.query_embed.weight.unsqueeze(0).repeat( # query positional embeddings
  363. (batch_size, 1, 1))
  364. cls_pred_list = []
  365. mask_pred_list = []
  366. # obtain the class prediction, mask prediction and attention mask
  367. cls_pred, mask_pred, attn_mask = self._forward_head(
  368. query_feat, mask_features, multi_scale_memorys[0].shape[-2:])
  369. cls_pred_list.append(cls_pred)
  370. mask_pred_list.append(mask_pred)
  371. for i in range(self.num_transformer_decoder_layers):
  372. level_idx = i % self.num_transformer_feat_level
  373. # if a mask is all True (all background), set it all False; otherwise cross attention would have nothing to attend to
  374. mask_sum = (attn_mask.sum(-1) != attn_mask.shape[-1]).unsqueeze(-1)
  375. attn_mask = attn_mask & mask_sum
  376. # cross_attn + self_attn
  377. layer = self.transformer_decoder.layers[i]
  378. query_feat = layer( # cross attn
  379. query=query_feat,
  380. key=decoder_inputs[level_idx],
  381. value=decoder_inputs[level_idx],
  382. query_pos=query_embed,
  383. key_pos=decoder_positional_encodings[level_idx],
  384. cross_attn_mask=attn_mask,
  385. query_key_padding_mask=None,
  386. # here we do not apply masking on padded region
  387. key_padding_mask=None)
  388. cls_pred, mask_pred, attn_mask = self._forward_head( # head: update cls_pred, mask_pred and attn_mask
  389. query_feat, mask_features, multi_scale_memorys[
  390. (i + 1) % self.num_transformer_feat_level].shape[-2:])
  391. cls_pred_list.append(cls_pred)
  392. mask_pred_list.append(mask_pred)
  393. return cls_pred_list, mask_pred_list

5. Label Assignment Strategy

        Labels are assigned with Hungarian bipartite matching. First, a cost matrix of shape num_queries x num_gts is built as the weighted sum of three cost terms: a classification cost, a mask cost and a dice cost. The classification cost is the negative probability that a query predicts each ground-truth label, the mask cost is a binary cross-entropy computed on sampled points, and the dice cost measures the overlap between predicted and ground-truth masks. The Hungarian algorithm is then run on this cost matrix to obtain a one-to-one matching.
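A minimal sketch of how such a cost matrix can be built and matched (toy data; the real cost terms live in mmdet's match-cost classes, and the weights 2.0/5.0/5.0 are the default Mask2Former config values):

    import torch
    import torch.nn.functional as F
    from scipy.optimize import linear_sum_assignment

    # toy setting: 100 queries, 5 gt instances, 80 classes, masks compared on 12544 sampled points
    num_queries, num_gts, num_classes, num_points = 100, 5, 80, 12544
    cls_prob = torch.randn(num_queries, num_classes + 1).softmax(-1)           # incl. background
    mask_prob = torch.rand(num_queries, num_points).clamp(1e-6, 1 - 1e-6)       # predicted mask probs
    gt_labels = torch.randint(0, num_classes, (num_gts,))
    gt_masks = (torch.rand(num_gts, num_points) > 0.5).float()

    # classification cost: negative predicted probability of each gt label
    cost_cls = -cls_prob[:, gt_labels]                                           # (num_queries, num_gts)

    # mask cost: point-wise binary cross entropy for every (query, gt) pair
    pos = F.binary_cross_entropy(mask_prob, torch.ones_like(mask_prob), reduction='none')
    neg = F.binary_cross_entropy(mask_prob, torch.zeros_like(mask_prob), reduction='none')
    cost_mask = (torch.einsum('qp,gp->qg', pos, gt_masks) +
                 torch.einsum('qp,gp->qg', neg, 1 - gt_masks)) / num_points

    # dice cost: 1 - soft dice overlap
    numerator = 2 * torch.einsum('qp,gp->qg', mask_prob, gt_masks)
    denominator = mask_prob.sum(-1)[:, None] + gt_masks.sum(-1)[None, :]
    cost_dice = 1 - (numerator + 1) / (denominator + 1)

    cost = 2.0 * cost_cls + 5.0 * cost_mask + 5.0 * cost_dice
    matched_queries, matched_gts = linear_sum_assignment(cost.numpy())
    # query matched_queries[i] is assigned gt matched_gts[i]; every other query becomes background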

  1. class HungarianAssigner(BaseAssigner):
  2. """Computes one-to-one matching between predictions and ground truth.
  3. This class computes an assignment between the targets and the predictions
  4. based on the costs. The costs are weighted sum of some components.
  5. For DETR the costs are weighted sum of classification cost, regression L1
  6. cost and regression iou cost. The targets don't include the no_object, so
  7. generally there are more predictions than targets. After the one-to-one
  8. matching, the un-matched are treated as backgrounds. Thus each query
  9. prediction will be assigned with `0` or a positive integer indicating the
  10. ground truth index:
  11. - 0: negative sample, no assigned gt
  12. - positive integer: positive sample, index (1-based) of assigned gt
  13. Args:
  14. match_costs (:obj:`ConfigDict` or dict or \
  15. List[Union[:obj:`ConfigDict`, dict]]): Match cost configs.
  16. """
  17. def __init__(
  18. self, match_costs: Union[List[Union[dict, ConfigDict]], dict,
  19. ConfigDict]
  20. ) -> None:
  21. if isinstance(match_costs, dict):
  22. match_costs = [match_costs]
  23. elif isinstance(match_costs, list):
  24. assert len(match_costs) > 0, \
  25. 'match_costs must not be a empty list.'
  26. self.match_costs = [
  27. TASK_UTILS.build(match_cost) for match_cost in match_costs
  28. ]
  29. def assign(self,
  30. pred_instances: InstanceData,
  31. gt_instances: InstanceData,
  32. img_meta: Optional[dict] = None,
  33. **kwargs) -> AssignResult:
  34. """Computes one-to-one matching based on the weighted costs.
  35. This method assign each query prediction to a ground truth or
  36. background. The `assigned_gt_inds` with -1 means don't care,
  37. 0 means negative sample, and positive number is the index (1-based)
  38. of assigned gt.
  39. The assignment is done in the following steps, the order matters.
  40. 1. assign every prediction to -1
  41. 2. compute the weighted costs
  42. 3. do Hungarian matching on CPU based on the costs
  43. 4. assign all to 0 (background) first, then for each matched pair
  44. between predictions and gts, treat this prediction as foreground
  45. and assign the corresponding gt index (plus 1) to it.
  46. Args:
  47. pred_instances (:obj:`InstanceData`): Instances of model
  48. predictions. It includes ``priors``, and the priors can
  49. be anchors or points, or the bboxes predicted by the
  50. previous stage, has shape (n, 4). The bboxes predicted by
  51. the current model or stage will be named ``bboxes``,
  52. ``labels``, and ``scores``, the same as the ``InstanceData``
  53. in other places. It may includes ``masks``, with shape
  54. (n, h, w) or (n, l).
  55. gt_instances (:obj:`InstanceData`): Ground truth of instance
  56. annotations. It usually includes ``bboxes``, with shape (k, 4),
  57. ``labels``, with shape (k, ) and ``masks``, with shape
  58. (k, h, w) or (k, l).
  59. img_meta (dict): Image information.
  60. Returns:
  61. :obj:`AssignResult`: The assigned result.
  62. """
  63. assert isinstance(gt_instances.labels, Tensor)
  64. num_gts, num_preds = len(gt_instances), len(pred_instances)
  65. gt_labels = gt_instances.labels
  66. device = gt_labels.device
  67. # 1. assign -1 by default
  68. assigned_gt_inds = torch.full((num_preds, ),
  69. -1,
  70. dtype=torch.long,
  71. device=device)
  72. assigned_labels = torch.full((num_preds, ),
  73. -1,
  74. dtype=torch.long,
  75. device=device)
  76. if num_gts == 0 or num_preds == 0:
  77. # No ground truth or boxes, return empty assignment
  78. if num_gts == 0:
  79. # No ground truth, assign all to background
  80. assigned_gt_inds[:] = 0
  81. return AssignResult(
  82. num_gts=num_gts,
  83. gt_inds=assigned_gt_inds,
  84. max_overlaps=None,
  85. labels=assigned_labels)
  86. # 2. compute weighted cost
  87. cost_list = [] # the classification cost is the negative predicted probability of each gt label
  88. for match_cost in self.match_costs: # classification cost, mask cost, dice cost (overlap)
  89. cost = match_cost(
  90. pred_instances=pred_instances,
  91. gt_instances=gt_instances,
  92. img_meta=img_meta)
  93. cost_list.append(cost)
  94. cost = torch.stack(cost_list).sum(dim=0)
  95. # 3. do Hungarian matching on CPU using linear_sum_assignment
  96. cost = cost.detach().cpu()
  97. if linear_sum_assignment is None:
  98. raise ImportError('Please run "pip install scipy" '
  99. 'to install scipy first.')
  100. matched_row_inds, matched_col_inds = linear_sum_assignment(cost) # bipartite matching on the num_queries x num_gts cost matrix
  101. matched_row_inds = torch.from_numpy(matched_row_inds).to(device)
  102. matched_col_inds = torch.from_numpy(matched_col_inds).to(device)
  103. # 4. assign backgrounds and foregrounds
  104. # assign all indices to backgrounds first
  105. assigned_gt_inds[:] = 0
  106. # assign foregrounds based on matching results
  107. assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
  108. assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
  109. return AssignResult( # assignment result object
  110. num_gts=num_gts,
  111. gt_inds=assigned_gt_inds,
  112. max_overlaps=None,
  113. labels=assigned_labels)

 Full code:

  1. class MaskFormerHead(AnchorFreeHead):
  2. """Implements the MaskFormer head.
  3. See `Per-Pixel Classification is Not All You Need for Semantic
  4. Segmentation <https://arxiv.org/pdf/2107.06278>`_ for details.
  5. Args:
  6. in_channels (list[int]): Number of channels in the input feature map.
  7. feat_channels (int): Number of channels for feature.
  8. out_channels (int): Number of channels for output.
  9. num_things_classes (int): Number of things.
  10. num_stuff_classes (int): Number of stuff.
  11. num_queries (int): Number of query in Transformer.
  12. pixel_decoder (:obj:`ConfigDict` or dict): Config for pixel
  13. decoder.
  14. enforce_decoder_input_project (bool): Whether to add a layer
  15. to change the embed_dim of transformer encoder in pixel decoder to
  16. the embed_dim of transformer decoder. Defaults to False.
  17. transformer_decoder (:obj:`ConfigDict` or dict): Config for
  18. transformer decoder.
  19. positional_encoding (:obj:`ConfigDict` or dict): Config for
  20. transformer decoder position encoding.
  21. loss_cls (:obj:`ConfigDict` or dict): Config of the classification
  22. loss. Defaults to `CrossEntropyLoss`.
  23. loss_mask (:obj:`ConfigDict` or dict): Config of the mask loss.
  24. Defaults to `FocalLoss`.
  25. loss_dice (:obj:`ConfigDict` or dict): Config of the dice loss.
  26. Defaults to `DiceLoss`.
  27. train_cfg (:obj:`ConfigDict` or dict, optional): Training config of
  28. MaskFormer head.
  29. test_cfg (:obj:`ConfigDict` or dict, optional): Testing config of
  30. MaskFormer head.
  31. init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
  32. dict], optional): Initialization config dict. Defaults to None.
  33. """
  34. def __init__(self,
  35. in_channels: List[int],
  36. feat_channels: int,
  37. out_channels: int,
  38. num_things_classes: int = 80,
  39. num_stuff_classes: int = 53,
  40. num_queries: int = 100,
  41. pixel_decoder: ConfigType = ...,
  42. enforce_decoder_input_project: bool = False,
  43. transformer_decoder: ConfigType = ...,
  44. positional_encoding: ConfigType = dict(
  45. num_feats=128, normalize=True),
  46. loss_cls: ConfigType = dict(
  47. type='CrossEntropyLoss',
  48. use_sigmoid=False,
  49. loss_weight=1.0,
  50. class_weight=[1.0] * 133 + [0.1]),
  51. loss_mask: ConfigType = dict(
  52. type='FocalLoss',
  53. use_sigmoid=True,
  54. gamma=2.0,
  55. alpha=0.25,
  56. loss_weight=20.0),
  57. loss_dice: ConfigType = dict(
  58. type='DiceLoss',
  59. use_sigmoid=True,
  60. activate=True,
  61. naive_dice=True,
  62. loss_weight=1.0),
  63. train_cfg: OptConfigType = None,
  64. test_cfg: OptConfigType = None,
  65. init_cfg: OptMultiConfig = None,
  66. **kwargs) -> None:
  67. super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
  68. self.num_things_classes = num_things_classes
  69. self.num_stuff_classes = num_stuff_classes
  70. self.num_classes = self.num_things_classes + self.num_stuff_classes
  71. self.num_queries = num_queries
  72. pixel_decoder.update(
  73. in_channels=in_channels,
  74. feat_channels=feat_channels,
  75. out_channels=out_channels)
  76. self.pixel_decoder = MODELS.build(pixel_decoder)
  77. self.transformer_decoder = DetrTransformerDecoder(
  78. **transformer_decoder)
  79. self.decoder_embed_dims = self.transformer_decoder.embed_dims
  80. if type(self.pixel_decoder) == PixelDecoder and (
  81. self.decoder_embed_dims != in_channels[-1]
  82. or enforce_decoder_input_project):
  83. self.decoder_input_proj = Conv2d(
  84. in_channels[-1], self.decoder_embed_dims, kernel_size=1)
  85. else:
  86. self.decoder_input_proj = nn.Identity()
  87. self.decoder_pe = SinePositionalEncoding(**positional_encoding)
  88. self.query_embed = nn.Embedding(self.num_queries, out_channels)
  89. self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
  90. self.mask_embed = nn.Sequential(
  91. nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
  92. nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True),
  93. nn.Linear(feat_channels, out_channels))
  94. self.test_cfg = test_cfg
  95. self.train_cfg = train_cfg
  96. if train_cfg:
  97. self.assigner = TASK_UTILS.build(train_cfg['assigner'])
  98. self.sampler = TASK_UTILS.build(
  99. train_cfg['sampler'], default_args=dict(context=self))
  100. self.class_weight = loss_cls.class_weight
  101. self.loss_cls = MODELS.build(loss_cls)
  102. self.loss_mask = MODELS.build(loss_mask)
  103. self.loss_dice = MODELS.build(loss_dice)
  104. def init_weights(self) -> None:
  105. if isinstance(self.decoder_input_proj, Conv2d):
  106. caffe2_xavier_init(self.decoder_input_proj, bias=0)
  107. self.pixel_decoder.init_weights()
  108. for p in self.transformer_decoder.parameters():
  109. if p.dim() > 1:
  110. nn.init.xavier_uniform_(p)
  111. def preprocess_gt(
  112. self, batch_gt_instances: InstanceList,
  113. batch_gt_semantic_segs: List[Optional[PixelData]]) -> InstanceList:
  114. """Preprocess the ground truth for all images.
  115. Args:
  116. batch_gt_instances (list[:obj:`InstanceData`]): Batch of
  117. gt_instance. It usually includes ``labels``, each is
  118. ground truth labels of each bbox, with shape (num_gts, )
  119. and ``masks``, each is ground truth masks of each instances
  120. of a image, shape (num_gts, h, w).
  121. gt_semantic_seg (list[Optional[PixelData]]): Ground truth of
  122. semantic segmentation, each with the shape (1, h, w).
  123. [0, num_thing_class - 1] means things,
  124. [num_thing_class, num_class-1] means stuff,
  125. 255 means VOID. It's None when training instance segmentation.
  126. Returns:
  127. list[obj:`InstanceData`]: each contains the following keys
  128. - labels (Tensor): Ground truth class indices\
  129. for a image, with shape (n, ), n is the sum of\
  130. number of stuff type and number of instance in a image.
  131. - masks (Tensor): Ground truth mask for a\
  132. image, with shape (n, h, w).
  133. """
  134. num_things_list = [self.num_things_classes] * len(batch_gt_instances)
  135. num_stuff_list = [self.num_stuff_classes] * len(batch_gt_instances)
  136. gt_labels_list = [
  137. gt_instances['labels'] for gt_instances in batch_gt_instances
  138. ]
  139. gt_masks_list = [
  140. gt_instances['masks'] for gt_instances in batch_gt_instances
  141. ]
  142. gt_semantic_segs = [
  143. None if gt_semantic_seg is None else gt_semantic_seg.sem_seg
  144. for gt_semantic_seg in batch_gt_semantic_segs
  145. ]
  146. targets = multi_apply(preprocess_panoptic_gt, gt_labels_list,
  147. gt_masks_list, gt_semantic_segs, num_things_list,
  148. num_stuff_list)
  149. labels, masks = targets
  150. batch_gt_instances = [
  151. InstanceData(labels=label, masks=mask)
  152. for label, mask in zip(labels, masks)
  153. ]
  154. return batch_gt_instances

    def get_targets(
        self,
        cls_scores_list: List[Tensor],
        mask_preds_list: List[Tensor],
        batch_gt_instances: InstanceList,
        batch_img_metas: List[dict],
        return_sampling_results: bool = False
    ) -> Tuple[List[Union[Tensor, int]]]:
        """Compute classification and mask targets for all images for a
        decoder layer.

        Args:
            cls_scores_list (list[Tensor]): Mask score logits from a single
                decoder layer for all images. Each with shape (num_queries,
                cls_out_channels).
            mask_preds_list (list[Tensor]): Mask logits from a single decoder
                layer for all images. Each with shape (num_queries, h, w).
            batch_gt_instances (list[obj:`InstanceData`]): each contains
                ``labels`` and ``masks``.
            batch_img_metas (list[dict]): List of image meta information.
            return_sampling_results (bool): Whether to return the sampling
                results. Defaults to False.

        Returns:
            tuple: a tuple containing the following targets.

                - labels_list (list[Tensor]): Labels of all images.
                    Each with shape (num_queries, ).
                - label_weights_list (list[Tensor]): Label weights of all
                    images. Each with shape (num_queries, ).
                - mask_targets_list (list[Tensor]): Mask targets of all
                    images. Each with shape (num_queries, h, w).
                - mask_weights_list (list[Tensor]): Mask weights of all
                    images. Each with shape (num_queries, ).
                - avg_factor (int): Average factor that is used to average
                    the loss. When using a sampling method, avg_factor is
                    usually the sum of positive and negative priors. When
                    using `MaskPseudoSampler`, `avg_factor` is usually equal
                    to the number of positive priors.

            additional_returns: This function enables user-defined returns
                from `self._get_targets_single`. These returns are currently
                refined to properties at each feature map (i.e. having an HxW
                dimension). The results will be concatenated at the end.
        """
        results = multi_apply(self._get_targets_single, cls_scores_list,
                              mask_preds_list, batch_gt_instances,
                              batch_img_metas)
        (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
         pos_inds_list, neg_inds_list, sampling_results_list) = results[:7]
        rest_results = list(results[7:])

        avg_factor = sum(
            [results.avg_factor for results in sampling_results_list])

        res = (labels_list, label_weights_list, mask_targets_list,
               mask_weights_list, avg_factor)
        if return_sampling_results:
            # trailing comma makes this a one-element tuple; concatenating a
            # bare list to a tuple would raise a TypeError
            res = res + (sampling_results_list, )

        return res + tuple(rest_results)

    def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
                            gt_instances: InstanceData,
                            img_meta: dict) -> Tuple[Tensor]:
        """Compute classification and mask targets for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for one
                image. Shape (num_queries, h, w).
            gt_instances (:obj:`InstanceData`): It contains ``labels`` and
                ``masks``.
            img_meta (dict): Image information.

        Returns:
            tuple: a tuple containing the following for one image.

                - labels (Tensor): Labels of each image.
                    shape (num_queries, ).
                - label_weights (Tensor): Label weights of each image.
                    shape (num_queries, ).
                - mask_targets (Tensor): Mask targets of each image.
                    shape (num_queries, h, w).
                - mask_weights (Tensor): Mask weights of each image.
                    shape (num_queries, ).
                - pos_inds (Tensor): Sampled positive indices for each image.
                - neg_inds (Tensor): Sampled negative indices for each image.
                - sampling_result (:obj:`SamplingResult`): Sampling results.
        """
        gt_masks = gt_instances.masks
        gt_labels = gt_instances.labels

        # downsample GT masks to the prediction resolution before matching
        target_shape = mask_pred.shape[-2:]
        if gt_masks.shape[0] > 0:
            gt_masks_downsampled = F.interpolate(
                gt_masks.unsqueeze(1).float(), target_shape,
                mode='nearest').squeeze(1).long()
        else:
            gt_masks_downsampled = gt_masks

        pred_instances = InstanceData(scores=cls_score, masks=mask_pred)
        downsampled_gt_instances = InstanceData(
            labels=gt_labels, masks=gt_masks_downsampled)
        # assign and sample (label assignment between queries and GT)
        assign_result = self.assigner.assign(
            pred_instances=pred_instances,
            gt_instances=downsampled_gt_instances,
            img_meta=img_meta)
        sampling_result = self.sampler.sample(
            assign_result=assign_result,
            pred_instances=pred_instances,
            gt_instances=gt_instances)
        pos_inds = sampling_result.pos_inds
        neg_inds = sampling_result.neg_inds

        # label target: unmatched queries default to the background class
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
        label_weights = gt_labels.new_ones(self.num_queries)

        # mask target: only matched (positive) queries get a mask weight of 1
        mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
        mask_weights = mask_pred.new_zeros((self.num_queries, ))
        mask_weights[pos_inds] = 1.0

        return (labels, label_weights, mask_targets, mask_weights, pos_inds,
                neg_inds, sampling_result)

    def loss_by_feat(self, all_cls_scores: Tensor, all_mask_preds: Tensor,
                     batch_gt_instances: List[InstanceData],
                     batch_img_metas: List[dict]) -> Dict[str, Tensor]:
        """Loss function.

        Args:
            all_cls_scores (Tensor): Classification scores for all decoder
                layers with shape (num_decoder, batch_size, num_queries,
                cls_out_channels). Note `cls_out_channels` should include
                background.
            all_mask_preds (Tensor): Mask scores for all decoder layers with
                shape (num_decoder, batch_size, num_queries, h, w).
            batch_gt_instances (list[obj:`InstanceData`]): each contains
                ``labels`` and ``masks``.
            batch_img_metas (list[dict]): List of image meta information.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        num_dec_layers = len(all_cls_scores)
        batch_gt_instances_list = [
            batch_gt_instances for _ in range(num_dec_layers)
        ]
        # every decoder layer is supervised with the same targets
        img_metas_list = [batch_img_metas for _ in range(num_dec_layers)]
        # compute the loss for each decoder layer
        losses_cls, losses_mask, losses_dice = multi_apply(
            self._loss_by_feat_single, all_cls_scores, all_mask_preds,
            batch_gt_instances_list, img_metas_list)

        loss_dict = dict()
        # loss from the last decoder layer
        loss_dict['loss_cls'] = losses_cls[-1]
        loss_dict['loss_mask'] = losses_mask[-1]
        loss_dict['loss_dice'] = losses_dice[-1]
        # loss from other decoder layers (auxiliary supervision)
        num_dec_layer = 0
        for loss_cls_i, loss_mask_i, loss_dice_i in zip(
                losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]):
            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
            loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i
            loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i
            num_dec_layer += 1
        return loss_dict

    def _loss_by_feat_single(self, cls_scores: Tensor, mask_preds: Tensor,
                             batch_gt_instances: List[InstanceData],
                             batch_img_metas: List[dict]) -> Tuple[Tensor]:
        """Loss function for outputs from a single decoder layer.

        Args:
            cls_scores (Tensor): Mask score logits from a single decoder layer
                for all images. Shape (batch_size, num_queries,
                cls_out_channels). Note `cls_out_channels` should include
                background.
            mask_preds (Tensor): Mask logits for a pixel decoder for all
                images. Shape (batch_size, num_queries, h, w).
            batch_gt_instances (list[obj:`InstanceData`]): each contains
                ``labels`` and ``masks``.
            batch_img_metas (list[dict]): List of image meta information.

        Returns:
            tuple[Tensor]: Loss components for outputs from a single decoder
                layer.
        """
        num_imgs = cls_scores.size(0)
        # split the batched predictions into per-image cls scores / mask preds
        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
        mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
        # assign targets to the queries
        (labels_list, label_weights_list, mask_targets_list, mask_weights_list,
         avg_factor) = self.get_targets(cls_scores_list, mask_preds_list,
                                        batch_gt_instances, batch_img_metas)
        # shape (batch_size, num_queries)
        labels = torch.stack(labels_list, dim=0)
        # shape (batch_size, num_queries)
        label_weights = torch.stack(label_weights_list, dim=0)
        # shape (num_total_gts, h, w)
        mask_targets = torch.cat(mask_targets_list, dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(mask_weights_list, dim=0)

        # classification loss, computed after the targets have been assigned
        # shape (batch_size * num_queries, )
        cls_scores = cls_scores.flatten(0, 1)
        labels = labels.flatten(0, 1)
        label_weights = label_weights.flatten(0, 1)

        class_weight = cls_scores.new_tensor(self.class_weight)
        loss_cls = self.loss_cls(
            cls_scores,
            labels,
            label_weights,
            avg_factor=class_weight[labels].sum())

        num_total_masks = reduce_mean(cls_scores.new_tensor([avg_factor]))
        num_total_masks = max(num_total_masks, 1)

        # extract positive ones (queries matched to an instance)
        # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
        mask_preds = mask_preds[mask_weights > 0]
        target_shape = mask_targets.shape[-2:]

        if mask_targets.shape[0] == 0:
            # zero match
            loss_dice = mask_preds.sum()
            loss_mask = mask_preds.sum()
            return loss_cls, loss_mask, loss_dice

        # upsample to shape of target
        # shape (num_total_gts, h, w)
        mask_preds = F.interpolate(
            mask_preds.unsqueeze(1),
            target_shape,
            mode='bilinear',
            align_corners=False).squeeze(1)

        # dice loss
        loss_dice = self.loss_dice(
            mask_preds, mask_targets, avg_factor=num_total_masks)

        # mask loss
        # FocalLoss supports input of shape (n, num_class)
        h, w = mask_preds.shape[-2:]
        # shape (num_total_gts, h, w) -> (num_total_gts * h * w, 1)
        mask_preds = mask_preds.reshape(-1, 1)
        # shape (num_total_gts, h, w) -> (num_total_gts * h * w)
        mask_targets = mask_targets.reshape(-1)
        # target is (1 - mask_targets) !!!
        loss_mask = self.loss_mask(
            mask_preds, 1 - mask_targets, avg_factor=num_total_masks * h * w)

        return loss_cls, loss_mask, loss_dice

    def forward(self, x: Tuple[Tensor],
                batch_data_samples: SampleList) -> Tuple[Tensor]:
        """Forward function.

        Args:
            x (tuple[Tensor]): Features from the upstream network, each
                is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            tuple[Tensor]: a tuple containing two elements.

                - all_cls_scores (Tensor): Classification scores for each
                    scale level. Each is a 4D-tensor with shape
                    (num_decoder, batch_size, num_queries, cls_out_channels).
                    Note `cls_out_channels` should include background.
                - all_mask_preds (Tensor): Mask scores for each decoder
                    layer. Each with shape (num_decoder, batch_size,
                    num_queries, h, w).
        """
        batch_img_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        batch_size = x[0].shape[0]
        # build a padding mask that is 0 inside the valid image region
        input_img_h, input_img_w = batch_img_metas[0]['batch_input_shape']
        padding_mask = x[-1].new_ones((batch_size, input_img_h, input_img_w),
                                      dtype=torch.float32)
        for i in range(batch_size):
            img_h, img_w = batch_img_metas[i]['img_shape']
            padding_mask[i, :img_h, :img_w] = 0
        padding_mask = F.interpolate(
            padding_mask.unsqueeze(1), size=x[-1].shape[-2:],
            mode='nearest').to(torch.bool).squeeze(1)

        # when backbone is swin, memory is the output of the last swin stage.
        # when backbone is r50, memory is the output of the transformer encoder.
        mask_features, memory = self.pixel_decoder(x, batch_img_metas)
        pos_embed = self.decoder_pe(padding_mask)
        memory = self.decoder_input_proj(memory)
        # shape (batch_size, c, h, w) -> (batch_size, h*w, c)
        memory = memory.flatten(2).permute(0, 2, 1)
        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
        # shape (batch_size, h * w)
        padding_mask = padding_mask.flatten(1)
        # shape = (num_queries, embed_dims)
        query_embed = self.query_embed.weight
        # shape = (batch_size, num_queries, embed_dims)
        query_embed = query_embed.unsqueeze(0).repeat(batch_size, 1, 1)
        target = torch.zeros_like(query_embed)
        # shape (num_decoder, num_queries, batch_size, embed_dims)
        out_dec = self.transformer_decoder(
            query=target,
            key=memory,
            value=memory,
            query_pos=query_embed,
            key_pos=pos_embed,
            key_padding_mask=padding_mask)

        # cls_scores
        all_cls_scores = self.cls_embed(out_dec)

        # mask_preds: dot product between the per-query mask embeddings and
        # the per-pixel embeddings from the pixel decoder
        mask_embed = self.mask_embed(out_dec)
        all_mask_preds = torch.einsum('lbqc,bchw->lbqhw', mask_embed,
                                      mask_features)

        return all_cls_scores, all_mask_preds

    def loss(
        self,
        x: Tuple[Tensor],
        batch_data_samples: SampleList,
    ) -> Dict[str, Tensor]:
        """Perform forward propagation and loss calculation of the panoptic
        head on the features of the upstream network.

        Args:
            x (tuple[Tensor]): Multi-level features from the upstream
                network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        batch_img_metas = []
        batch_gt_instances = []
        batch_gt_semantic_segs = []
        for data_sample in batch_data_samples:
            batch_img_metas.append(data_sample.metainfo)
            batch_gt_instances.append(data_sample.gt_instances)
            if 'gt_sem_seg' in data_sample:
                batch_gt_semantic_segs.append(data_sample.gt_sem_seg)
            else:
                batch_gt_semantic_segs.append(None)

        # forward
        all_cls_scores, all_mask_preds = self(x, batch_data_samples)

        # preprocess ground truth
        batch_gt_instances = self.preprocess_gt(batch_gt_instances,
                                                batch_gt_semantic_segs)

        # loss
        losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
                                   batch_gt_instances, batch_img_metas)

        return losses

    def predict(self, x: Tuple[Tensor],
                batch_data_samples: SampleList) -> Tuple[Tensor]:
        """Test without augmentation.

        Args:
            x (tuple[Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            batch_data_samples (List[:obj:`DetDataSample`]): The Data
                Samples. It usually includes information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.

        Returns:
            tuple[Tensor]: A tuple containing two tensors.

                - mask_cls_results (Tensor): Mask classification logits,
                    shape (batch_size, num_queries, cls_out_channels).
                    Note `cls_out_channels` should include background.
                - mask_pred_results (Tensor): Mask logits, shape
                    (batch_size, num_queries, h, w).
        """
        batch_img_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        all_cls_scores, all_mask_preds = self(x, batch_data_samples)
        # only the outputs of the last decoder layer are used at test time
        mask_cls_results = all_cls_scores[-1]
        mask_pred_results = all_mask_preds[-1]

        # upsample masks to the padded input resolution
        img_shape = batch_img_metas[0]['batch_input_shape']
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(img_shape[0], img_shape[1]),
            mode='bilinear',
            align_corners=False)

        return mask_cls_results, mask_pred_results
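
The heart of the mask prediction above is the single `torch.einsum('lbqc,bchw->lbqhw', ...)` call at the end of `forward`: each query's mask embedding is dot-multiplied with every per-pixel embedding produced by the pixel decoder, yielding one logit map per query for every decoder layer. Below is a minimal, self-contained sketch of just that step; the tensor sizes are made-up toy values for illustration, not the real config values.

import torch

# hypothetical sizes: 3 decoder layers, batch of 2, 5 queries, 256-dim embeddings
num_decoder, batch_size, num_queries, embed_dims = 3, 2, 5, 256
h, w = 32, 32  # spatial size of the per-pixel embedding map

mask_embed = torch.randn(num_decoder, batch_size, num_queries, embed_dims)
mask_features = torch.randn(batch_size, embed_dims, h, w)

# contract the channel dimension: (l, b, q, c) x (b, c, h, w) -> (l, b, q, h, w)
all_mask_preds = torch.einsum('lbqc,bchw->lbqhw', mask_embed, mask_features)
print(all_mask_preds.shape)  # torch.Size([3, 2, 5, 32, 32])

The einsum is equivalent to a batched matrix multiplication between the query embeddings and the flattened per-pixel features; writing it with explicit subscripts just keeps the decoder-layer and batch dimensions readable.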
