The SAM model consists of three main components. This post focuses on the mask decoder, walking through it in detail alongside the code; corrections and suggestions are welcome.
The mask decoder efficiently maps the image embeddings and the prompt embeddings to an output mask. To combine these two inputs, the authors, taking inspiration from Transformer architectures, modified a standard Transformer decoder into the mask decoder described below.
The following code can be read alongside the corresponding figure in the paper:
def __init__(
    self,
    *,
    transformer_dim: int,
    transformer: nn.Module,
    num_multimask_outputs: int = 3,
    activation: Type[nn.Module] = nn.GELU,
    iou_head_depth: int = 3,
    iou_head_hidden_dim: int = 256,
) -> None:
    super().__init__()
    self.transformer_dim = transformer_dim
    self.transformer = transformer

    self.num_multimask_outputs = num_multimask_outputs

    self.iou_token = nn.Embedding(1, transformer_dim)
    self.num_mask_tokens = num_multimask_outputs + 1
    self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

    # ---upscaled---
    # ---4x upsampling of the mask embeddings---
    self.output_upscaling = nn.Sequential(
        nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
        LayerNorm2d(transformer_dim // 4),
        activation(),
        nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
        activation(),
    )
    # ---upscaled end---

    # ---MLP---
    # ---one MLP per mask token---
    self.output_hypernetworks_mlps = nn.ModuleList(
        [
            MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
            for i in range(self.num_mask_tokens)
        ]
    )
    # ---one MLP per mask token end---

    # ---MLP for the IoU prediction---
    self.iou_prediction_head = MLP(
        transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
    )
    # ---MLP for the IoU prediction end---
    # ---MLP end---
First, look at the constructor arguments:
transformer_dim: the channel dimension of the transformer.
transformer: the transformer module to use (here, the TwoWayTransformer described below).
num_multimask_outputs: the number of masks the decoder predicts (used to disambiguate ambiguous prompts); the default in the SAM paper is 3.
activation: the activation function used when upscaling the mask embeddings.
iou_head_depth: the depth of the MLP used to predict the mask quality (IoU) score.
iou_head_hidden_dim: the hidden dimension of the MLP used to predict the mask quality (IoU) score.
Because an extra iou_token is introduced at this stage to estimate the quality of each predicted mask, one is added to the number of mask tokens:
self.num_mask_tokens = num_multimask_outputs + 1
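To make the sizes concrete, here is a minimal, self-contained sketch that builds the decoder with SAM's default hyperparameters and runs it on dummy tensors. It assumes the official segment_anything package (whose real builder is sam_model_registry) and the default image-embedding size of 256x64x64; the prompt tensors below are placeholders, not outputs of a real prompt encoder.

import torch
from segment_anything.modeling import MaskDecoder, TwoWayTransformer

# Sketch: decoder with SAM's default hyperparameters
decoder = MaskDecoder(
    transformer_dim=256,
    transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
    num_multimask_outputs=3,
)

# Dummy inputs with the shapes the image/prompt encoders would produce for one image
image_embeddings = torch.randn(1, 256, 64, 64)  # from the image encoder
image_pe = torch.randn(1, 256, 64, 64)          # dense positional encoding of the image
sparse_prompts = torch.randn(1, 2, 256)         # placeholder sparse prompt tokens (e.g. a point)
dense_prompts = torch.randn(1, 256, 64, 64)     # placeholder dense prompt embedding

masks, iou_pred = decoder(
    image_embeddings=image_embeddings,
    image_pe=image_pe,
    sparse_prompt_embeddings=sparse_prompts,
    dense_prompt_embeddings=dense_prompts,
    multimask_output=True,
)
print(masks.shape, iou_pred.shape)  # torch.Size([1, 3, 256, 256]) torch.Size([1, 3])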
def forward(
    self,
    image_embeddings: torch.Tensor,          # image embedding produced by the image encoder
    image_pe: torch.Tensor,                  # positional encoding of the image
    sparse_prompt_embeddings: torch.Tensor,  # sparse prompts (points, boxes) from the prompt encoder
    dense_prompt_embeddings: torch.Tensor,   # dense prompts (masks) from the prompt encoder
    multimask_output: bool,                  # whether to return multiple masks, to resolve ambiguous prompts
) -> Tuple[torch.Tensor, torch.Tensor]:

    masks, iou_pred = self.predict_masks(
        image_embeddings=image_embeddings,
        image_pe=image_pe,
        sparse_prompt_embeddings=sparse_prompt_embeddings,
        dense_prompt_embeddings=dense_prompt_embeddings,
    )

    # Slice masks and iou_pred according to multimask_output
    if multimask_output:
        mask_slice = slice(1, None)  # True: keep the three multi-mask outputs (indices 1..3)
    else:
        mask_slice = slice(0, 1)     # False: keep only the single-mask output at index 0
    masks = masks[:, mask_slice, :, :]
    iou_pred = iou_pred[:, mask_slice]

    # Return the masks and their predicted IoU scores
    return masks, iou_pred
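A small stand-alone sketch of just this slicing logic on dummy tensors (the 256x256 resolution is SAM's low-resolution mask output; the values are illustrative only): predict_masks returns num_mask_tokens = 4 masks per prompt, one "single-mask" output at index 0 plus three ambiguity-resolving outputs.

import torch

masks = torch.randn(2, 4, 256, 256)   # B x num_mask_tokens x H x W (dummy values)
iou_pred = torch.randn(2, 4)

multimask_output = True
mask_slice = slice(1, None) if multimask_output else slice(0, 1)
print(masks[:, mask_slice].shape)     # torch.Size([2, 3, 256, 256])
print(iou_pred[:, mask_slice].shape)  # torch.Size([2, 3])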
def predict_masks(
    self,
    image_embeddings: torch.Tensor,
    image_pe: torch.Tensor,
    sparse_prompt_embeddings: torch.Tensor,
    dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Predicts masks. See 'forward' for more details."""
    # Concatenate output tokens
    output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
    output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
    tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)  # output_tokens + prompt tokens

    # Expand per-image data in batch direction to be per-mask
    # When segmenting with boxes, n boxes expand the batch to batchsize * n,
    # so the image embedding must be repeated along the batch dimension.
    if image_embeddings.shape[0] != tokens.shape[0]:
        # torch.repeat_interleave repeats the elements of a tensor along a given dimension:
        # image_embeddings is the tensor to repeat, tokens.shape[0] the number of repeats.
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    else:
        src = image_embeddings
    src = src + dense_prompt_embeddings  # image embedding + dense_prompt_embeddings (mask prompt)
    pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
    b, c, h, w = src.shape

    # Run the transformer
    # hs holds the transformer output for the tokens, src the updated image embedding
    hs, src = self.transformer(src, pos_src, tokens)
    # hs has shape B x N_tokens x C: dim 0 indexes samples in the batch, dim 1 the token
    # sequence, dim 2 the token features. hs[:, 0, :] is the IoU token output, and
    # hs[:, 1 : (1 + self.num_mask_tokens), :] are the mask token outputs, so slicing
    # cleanly separates the different kinds of tokens.
    iou_token_out = hs[:, 0, :]
    mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

    # Upscale mask embeddings and predict masks using the mask tokens
    src = src.transpose(1, 2).view(b, c, h, w)
    upscaled_embedding = self.output_upscaling(src)
    # Each mask token is run through its own MLP; the outputs act as per-mask
    # filters over the upscaled image embedding when generating the final masks.
    hyper_in_list: List[torch.Tensor] = []
    # ---MLP---
    for i in range(self.num_mask_tokens):
        hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
    # Stack the per-token outputs into a single tensor of shape B x num_mask_tokens x (C // 8)
    hyper_in = torch.stack(hyper_in_list, dim=1)
    # ---MLP end---

    b, c, h, w = upscaled_embedding.shape
    # (B x num_mask_tokens x C) @ (B x C x HW) -> B x num_mask_tokens x HW -> B x num_mask_tokens x H x W
    masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

    # ---MLP---
    # Generate mask quality predictions
    iou_pred = self.iou_prediction_head(iou_token_out)
    # ---MLP end---

    return masks, iou_pred
The first part of this code, "Concatenate output tokens", simply prepends the learned output tokens (the IoU token and the mask tokens) to the prompt tokens. For an illustration, see the figure in this post: 【图像分割】【深度学习】SAM官方Pytorch代码-Mask decoder模块MaskDeco网络解析_sam decoder-CSDN博客
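To make the token layout explicit, here is a stand-alone sketch of that concatenation with SAM's default sizes (transformer_dim=256, num_mask_tokens=4); the batch size, number of sparse prompt tokens and the freshly initialized embeddings are arbitrary and only illustrate the shapes.

import torch
import torch.nn as nn

transformer_dim, num_mask_tokens = 256, 4
iou_token = nn.Embedding(1, transformer_dim)
mask_tokens = nn.Embedding(num_mask_tokens, transformer_dim)
sparse_prompt_embeddings = torch.randn(2, 3, transformer_dim)  # B=2 prompts, 3 sparse tokens each

output_tokens = torch.cat([iou_token.weight, mask_tokens.weight], dim=0)                      # 5 x 256
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)   # 2 x 5 x 256
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)                          # 2 x 8 x 256
print(tokens.shape)  # token order: [iou_token, 4 mask tokens, prompt tokens]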
The transformer data flow is shown in the figure: reading from the bottom up, the output tokens and prompt tokens first pass through a self-attention layer, and then cross-attention is applied in both directions, token-to-image and image-to-token. In the first cross-attention (token to image), the tokens act as q while the image embedding provides k and v. In the second cross-attention (image to token), the image embedding acts as q while the tokens provide k and v. [See the TwoWayAttentionBlock code below for the exact flow.]
class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,  # downsampling rate used inside the attention layers
    ) -> None:
        """
        A transformer decoder that attends to an input image using queries whose
        positional embedding is supplied. It is built from several transformer
        blocks, each containing two attention directions. The inputs are the image
        embedding, the image positional encoding and the point (token) embedding;
        the outputs are the processed point embedding and the processed image embedding.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),  # only the first block (i == 0) skips the PE in its self-attention
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,  # this is tokens = output_tokens + prompt_tokens
    ) -> Tuple[Tensor, Tensor]:
        """
        Forward pass:
        (1) flatten the image embedding and its positional encoding into token sequences,
        (2) run the point (token) embedding and the image embedding through the stack of
            TwoWayAttentionBlocks,
        (3) apply a final token-to-image cross-attention followed by a layer norm.

        Args:
          image_embedding (torch.Tensor): image embedding, shape B x embedding_dim x h x w.
          image_pe (torch.Tensor): positional encoding of the image, same shape as image_embedding.
          point_embedding (torch.Tensor): embedding of the query points, shape B x N_points x embedding_dim.

        Returns:
          Tuple[torch.Tensor, torch.Tensor]: the processed point_embedding and the processed image_embedding.
        """
        # Flatten image embedding to B x N_image_tokens x C
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)  # positional encoding matching the image embedding

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,  # in the first block, queries and query_pe are identical
                key_pe=image_pe,
            )

        # Final cross-attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys
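As a quick shape check, a sketch of calling the two-way transformer on dummy tensors; it assumes the official segment_anything package and SAM's default configuration (depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048), with an arbitrary number of tokens.

import torch
from segment_anything.modeling import TwoWayTransformer

transformer = TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8)

image_embedding = torch.randn(1, 256, 64, 64)  # B x C x H x W
image_pe = torch.randn(1, 256, 64, 64)
tokens = torch.randn(1, 7, 256)                # output tokens + prompt tokens

queries, keys = transformer(image_embedding, image_pe, tokens)
print(queries.shape)  # torch.Size([1, 7, 256])    processed tokens
print(keys.shape)     # torch.Size([1, 4096, 256]) processed image embedding (64*64 image tokens)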
class TwoWayAttentionBlock(nn.Module):
    # TwoWayAttentionBlock = self-attention + two cross-attentions + MLP, each followed by a LayerNorm
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers:
        (1) self-attention of sparse inputs,
        (2) cross attention of sparse inputs to dense inputs,
        (3) mlp block on sparse inputs,
        (4) cross attention of dense inputs to sparse inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:

        # Self-attention block on the tokens.
        # In the first block queries == query_pe, so the PE is skipped there.
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # First cross-attention block: tokens attending to the image embedding.
        # q, k and v no longer come from the same sequence: queries + query_pe act as q,
        # while k and v both come from the image embedding (keys).
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Second cross-attention block: image embedding attending to the tokens.
        # Here the image embedding acts as q, and the tokens provide k and v.
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys
class Attention(nn.Module):
    """
    An attention layer that allows downsampling the embedding size after
    projecting q, k and v.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head (channels per head)

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # Get output
        out = attn @ v
        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
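Finally, a short sketch of the downsampled attention on dummy tensors (assuming the Attention class from the official segment_anything package): with embedding_dim=256 and downsample_rate=2, q, k and v are projected to an internal dimension of 128 (8 heads with 16 channels per head) before the output is projected back to 256.

import torch
from segment_anything.modeling.transformer import Attention

attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)

q = torch.randn(1, 7, 256)     # e.g. tokens as queries
k = torch.randn(1, 4096, 256)  # e.g. image embedding as keys
v = torch.randn(1, 4096, 256)  # and as values

out = attn(q=q, k=k, v=v)
print(out.shape)  # torch.Size([1, 7, 256]) -> one output per query token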