
A Summary of SAM's MaskDecoder (Personal Notes)

[Figure: SAM model overview]

Preface

The SAM model consists of three main components. This article describes and summarizes one of them, the mask decoder, in detail and with the code as a guide; corrections and suggestions are welcome.

The mask decoder module efficiently maps the image embeddings and prompt embeddings to an output mask. To combine these two inputs, the authors take the standard Transformer decoder and, inspired by the Transformer architecture, modify it into the Mask Decoder described in this article.

Mask Decoder Definition

The following code is easier to follow alongside the mask decoder figure in the paper:

def __init__(
    self,
    *,
    transformer_dim: int,
    transformer: nn.Module,
    num_multimask_outputs: int = 3,
    activation: Type[nn.Module] = nn.GELU,
    iou_head_depth: int = 3,
    iou_head_hidden_dim: int = 256,
) -> None:
    super().__init__()
    self.transformer_dim = transformer_dim
    self.transformer = transformer

    self.num_multimask_outputs = num_multimask_outputs

    self.iou_token = nn.Embedding(1, transformer_dim)
    self.num_mask_tokens = num_multimask_outputs + 1
    self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

    # --- 4x upscaling of the image embedding ---
    self.output_upscaling = nn.Sequential(
        nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
        LayerNorm2d(transformer_dim // 4),
        activation(),
        nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
        activation(),
    )
    # --- upscaling end ---

    # --- one MLP per mask token ---
    self.output_hypernetworks_mlps = nn.ModuleList(
        [
            MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
            for i in range(self.num_mask_tokens)
        ]
    )
    # --- MLP for the IoU prediction ---
    self.iou_prediction_head = MLP(
        transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
    )

First, the constructor arguments:

transformer_dim: the channel dimension of the transformer.

transformer: the transformer module that is passed in (the TwoWayTransformer described below).

num_multimask_outputs: the number of masks the mask decoder outputs (used to resolve ambiguous prompts); the default in the SAM paper is 3.

activation: the activation function used in the mask decoder's upscaling path.

iou_head_depth: the depth of the MLP used to predict the IoU quality of each mask.

iou_head_hidden_dim: the hidden dimension of that MLP.

In the line below, the decoder allocates one more mask token than num_multimask_outputs: the extra token produces the single-mask output that is returned when multimask_output=False, while the other three tokens produce the multi-mask outputs. (The iou_token used to score mask quality is a separate, single-entry embedding.) A construction sketch using the repo defaults follows below.

self.num_mask_tokens = num_multimask_outputs + 1
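
For orientation, the sketch below shows how the decoder is typically constructed. The classes and hyper-parameter values follow the defaults in the official segment_anything repository (build_sam.py); treat the exact numbers and import path as assumptions for illustration rather than something this article defines.

# Construction sketch (assumed imports and default values from the official repo).
from segment_anything.modeling import MaskDecoder, TwoWayTransformer

prompt_embed_dim = 256  # the transformer_dim used throughout SAM
mask_decoder = MaskDecoder(
    num_multimask_outputs=3,  # three ambiguity-resolving masks
    transformer=TwoWayTransformer(
        depth=2,
        embedding_dim=prompt_embed_dim,
        mlp_dim=2048,
        num_heads=8,
    ),
    transformer_dim=prompt_embed_dim,
    iou_head_depth=3,
    iou_head_hidden_dim=256,
)
# num_mask_tokens = 3 + 1 = 4: one single-mask token plus three multi-mask tokens,
# with a separate iou_token for quality prediction.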

Mask Decoder Forward Pass

def forward(
    self,
    image_embeddings: torch.Tensor,          # image embedding produced by the image encoder
    image_pe: torch.Tensor,                   # positional embedding of the image
    sparse_prompt_embeddings: torch.Tensor,   # sparse prompts (points/boxes) from the prompt encoder
    dense_prompt_embeddings: torch.Tensor,    # dense prompts (masks) from the prompt encoder
    multimask_output: bool,                   # whether to return multiple masks (ambiguity resolution)
) -> Tuple[torch.Tensor, torch.Tensor]:
    masks, iou_pred = self.predict_masks(
        image_embeddings=image_embeddings,
        image_pe=image_pe,
        sparse_prompt_embeddings=sparse_prompt_embeddings,
        dense_prompt_embeddings=dense_prompt_embeddings,
    )

    # Select masks (and their IoU scores) according to multimask_output
    if multimask_output:
        mask_slice = slice(1, None)  # True: keep masks 1..3, the multi-mask outputs
    else:
        mask_slice = slice(0, 1)     # False: keep only mask 0, the single-mask output
    masks = masks[:, mask_slice, :, :]
    iou_pred = iou_pred[:, mask_slice]

    # Return the masks and their predicted IoU scores
    return masks, iou_pred
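
A small, self-contained sketch of the slicing behaviour on dummy tensors. The 256x256 mask size assumes the default 64x64 image embedding and the 4x upscaling; all numbers here are illustrative.

# Slicing behaviour on dummy tensors (illustrative shapes only).
import torch

masks = torch.randn(2, 4, 256, 256)   # what predict_masks returns: 1 single-mask + 3 multi-mask channels
iou_pred = torch.randn(2, 4)

multimask_output = True
mask_slice = slice(1, None) if multimask_output else slice(0, 1)
print(masks[:, mask_slice].shape)     # torch.Size([2, 3, 256, 256])
print(iou_pred[:, mask_slice].shape)  # torch.Size([2, 3])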

predict_masks in the Mask Decoder

def predict_masks(
    self,
    image_embeddings: torch.Tensor,
    image_pe: torch.Tensor,
    sparse_prompt_embeddings: torch.Tensor,
    dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Predicts masks. See 'forward' for more details."""
    # Concatenate output tokens: iou_token + mask_tokens, prepended to the prompt tokens
    output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
    output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
    tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)  # output_tokens + prompt tokens

    # Expand per-image data in the batch direction to be per-mask.
    # With box prompts, n boxes per image turn the token batch into batch_size * n.
    if image_embeddings.shape[0] != tokens.shape[0]:
        # torch.repeat_interleave repeats image_embeddings along dim 0, tokens.shape[0] times
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    else:
        src = image_embeddings
    src = src + dense_prompt_embeddings  # image embedding + dense prompt embeddings (mask prompt)
    pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
    b, c, h, w = src.shape

    # Run the transformer.
    # hs is the transformer output for the tokens; src is the processed image embedding it returns.
    hs, src = self.transformer(src, pos_src, tokens)
    # hs has shape B x N_tokens x C: hs[:, 0, :] is the output of the iou_token,
    # and hs[:, 1 : 1 + num_mask_tokens, :] are the outputs of the mask tokens.
    iou_token_out = hs[:, 0, :]
    mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

    # Upscale mask embeddings and predict masks using the mask tokens
    src = src.transpose(1, 2).view(b, c, h, w)
    upscaled_embedding = self.output_upscaling(src)
    # Each mask token goes through its own MLP; the outputs act as per-mask
    # "filters" that are applied to the upscaled image embedding below.
    hyper_in_list: List[torch.Tensor] = []
    for i in range(self.num_mask_tokens):
        hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
    # Stack the per-token outputs into a single tensor: B x num_mask_tokens x (C // 8)
    hyper_in = torch.stack(hyper_in_list, dim=1)

    b, c, h, w = upscaled_embedding.shape
    masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

    # Generate mask quality (IoU) predictions from the iou_token output
    iou_pred = self.iou_prediction_head(iou_token_out)

    return masks, iou_pred
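
The final matrix multiplication is easiest to read as a per-mask dot product between each token-derived "filter" and every spatial location of the upscaled embedding. Below is a standalone shape sketch; the dimensions (32-dimensional filters, a 32 x 256 x 256 upscaled map) mirror the SAM defaults and are used only for illustration.

# Shape sketch of the hypernetwork-style mask prediction (illustrative dimensions).
import torch

b, num_mask_tokens = 1, 4
hyper_in = torch.randn(b, num_mask_tokens, 32)      # one 32-d "filter" per mask token (transformer_dim // 8)
upscaled_embedding = torch.randn(b, 32, 256, 256)   # stands in for the output of output_upscaling

bb, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(bb, c, h * w)).view(bb, -1, h, w)
print(masks.shape)  # torch.Size([1, 4, 256, 256])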

The first part of this code, "Concatenate output tokens", simply concatenates the output tokens (iou_token + mask_tokens) with the prompt tokens. A helpful diagram of this step can be found in this post: 【图像分割】【深度学习】SAM官方Pytorch代码-Mask decoder模块MaskDeco网络解析_sam decoder-CSDN博客.
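
A minimal shape sketch of that concatenation (illustrative values: transformer_dim = 256, num_mask_tokens = 4, and two sparse prompt tokens):

# Token concatenation with dummy weights (illustrative shapes only).
import torch

iou_token_weight = torch.randn(1, 256)              # stands in for self.iou_token.weight
mask_tokens_weight = torch.randn(4, 256)            # stands in for self.mask_tokens.weight
sparse_prompt_embeddings = torch.randn(1, 2, 256)   # B x N_prompt x C

output_tokens = torch.cat([iou_token_weight, mask_tokens_weight], dim=0)  # (5, 256)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
print(tokens.shape)  # torch.Size([1, 7, 256]): [iou, mask0..mask3, prompt tokens]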

Transformer

The data flow through the transformer is shown in the figure from the paper: reading it from the bottom up, the output tokens and prompt tokens first pass through a self-attention layer, and then cross-attention is applied in both directions, token-to-image and image-to-token. In the first cross-attention (token to image) the tokens act as q while the image embedding provides k and v; in the second cross-attention (image to token) the image embedding acts as q while the tokens provide k and v. (See the TwoWayAttentionBlock code for the full flow.)

class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,  # downsampling rate used inside the attention layers
    ) -> None:
        """
        A transformer decoder that attends to an input image using queries with
        supplied positional embeddings. It is made up of `depth` transformer
        blocks, each containing two attention modules. The inputs are the image
        embedding, the image positional embedding, and the point (token)
        embedding; the outputs are the processed point embedding and the
        processed image embedding.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),  # only the first block skips the PE in its self-attention
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,  # the tokens = output_tokens + prompt_tokens
    ) -> Tuple[Tensor, Tensor]:
        """
        Flattens the image embedding, runs the stack of TwoWayAttentionBlocks
        over the (token, image) pair, and finishes with a final token-to-image
        attention followed by a layer norm.

        Args:
          image_embedding (torch.Tensor): image embedding of shape B x embedding_dim x h x w.
          image_pe (torch.Tensor): positional encoding of the image, same shape as image_embedding.
          point_embedding (torch.Tensor): embedding of the query points, shape B x N_points x embedding_dim.

        Returns:
          Tuple[torch.Tensor, torch.Tensor]: the processed point_embedding and the processed image_embedding.
        """
        # Flatten image embedding to B x N_image_tokens x C
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)  # positional embedding matching the image embedding

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,  # in the first block, queries == query_pe
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys
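
A quick shape check of the whole two-way transformer on dummy tensors. This assumes the official segment_anything package is installed so that TwoWayTransformer can be imported; the 64x64 grid matches SAM's default image embedding and is used here only for illustration.

# Shape check (assumed import path: the official segment_anything package).
import torch
from segment_anything.modeling import TwoWayTransformer

twt = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
image_embedding = torch.randn(1, 256, 64, 64)  # B x C x H x W
image_pe = torch.randn(1, 256, 64, 64)         # same shape as the image embedding
tokens = torch.randn(1, 7, 256)                # e.g. 1 iou + 4 mask + 2 prompt tokens

queries, keys = twt(image_embedding, image_pe, tokens)
print(queries.shape)  # torch.Size([1, 7, 256])     -- processed tokens
print(keys.shape)     # torch.Size([1, 4096, 256])  -- processed image tokens (64*64)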

TwoWayAttentionBlock

        

class TwoWayAttentionBlock(nn.Module):
    # TwoWayAttentionBlock = self-attention + two cross-attentions + MLP, each followed by a LayerNorm
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers:
        (1) self-attention of sparse inputs,
        (2) cross attention of sparse inputs to dense inputs,
        (3) mlp block on sparse inputs,
        (4) cross attention of dense inputs to sparse inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self-attention block.
        # In the first block queries == query_pe, so the PE is skipped there.
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # First cross-attention block: tokens attending to the image embedding.
        # q, k, v no longer come from the same sequence: queries + query_pe act as q,
        # while k and v are provided by the image embedding (keys).
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Second cross-attention block: image embedding attending to the tokens.
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys
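
Running one block in isolation makes the input and output shapes explicit. Again this is a sketch that assumes the official segment_anything package (where TwoWayAttentionBlock lives in segment_anything.modeling.transformer); the shapes are illustrative.

# One TwoWayAttentionBlock on dummy tensors (assumed import path).
import torch
from segment_anything.modeling.transformer import TwoWayAttentionBlock

block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048, skip_first_layer_pe=True)
queries = torch.randn(1, 7, 256)    # sparse tokens (output tokens + prompt tokens)
keys = torch.randn(1, 4096, 256)    # flattened 64x64 image embedding
query_pe = torch.randn(1, 7, 256)
key_pe = torch.randn(1, 4096, 256)

queries, keys = block(queries=queries, keys=keys, query_pe=query_pe, key_pe=key_pe)
print(queries.shape, keys.shape)  # torch.Size([1, 7, 256]) torch.Size([1, 4096, 256])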

Attention

class Attention(nn.Module):
    """
    An attention layer that allows the embedding size to be downscaled
    to a smaller internal dimension.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head (channels per head)

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # Get output
        out = attn @ v
        out = self._recombine_heads(out)
        out = self.out_proj(out)
        return out
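
A quick sanity check of the downsampled attention, using the Attention class defined above (it additionally needs import math, import torch, and from torch import nn, Tensor in scope); the shapes are illustrative.

# Sanity check of Attention with downsample_rate=2 (internal_dim becomes 128).
import torch

attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
q = torch.randn(2, 5, 256)      # B x N_q x C, e.g. 5 sparse tokens
k = torch.randn(2, 4096, 256)   # B x N_k x C, e.g. 64*64 image tokens
v = torch.randn(2, 4096, 256)

out = attn(q, k, v)
print(out.shape)  # torch.Size([2, 5, 256]) -- projected back up to embedding_dim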
