Let's start with SAM itself. If anything here is wrong, please point it out. Thanks!
SAM tackles image segmentation with the newer prompt-engineering + pretrained-foundation-model paradigm in order to achieve zero-shot transfer (the older paradigm being pretrain + finetune).
The whole SAM architecture splits into three main parts: the image encoder, the prompt encoder, and the mask decoder. Let's go through them one by one.
SAM's image encoder is an MAE-pretrained ViT; I won't introduce ViT itself here. The original image is rescaled (preserving aspect ratio, with padding) to 1024×1024, then a convolution with kernel size 16 and stride 16 patchifies it into a 64×64×768 tensor. This is flattened and fed through the transformer encoder, and the output is compressed by two convolution layers down to an embedding dimension of 256.
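To make the shape flow concrete, here is a minimal sketch assuming ViT-B dimensions (768-dim patch embeddings, 256-dim output). The dozen-plus transformer blocks and the normalization layers of the real encoder are omitted; only the patch embedding and the output "neck" are shown.

import torch
import torch.nn as nn

# Patchify: a 16x16 conv with stride 16 turns a 1024x1024 image into a 64x64 token grid.
patch_embed = nn.Conv2d(3, 768, kernel_size=16, stride=16)
# "Neck": two convolutions compress the channel dimension from 768 down to 256.
neck = nn.Sequential(
    nn.Conv2d(768, 256, kernel_size=1, bias=False),
    nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
)

x = torch.randn(1, 3, 1024, 1024)    # image rescaled and padded to 1024x1024
tokens = patch_embed(x)              # (1, 768, 64, 64)
# ... the ViT transformer blocks would run on this 64x64 token grid here ...
image_embedding = neck(tokens)       # (1, 256, 64, 64)
print(tokens.shape, image_embedding.shape)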
The image encoder is by far the most expensive part in both compute and memory; in Meta's official demo, the image embedding is actually computed on a cloud server. So any attempt at a lightweight SAM has to start by improving this part.
In Meta's official demo, you can give a single point as the prompt and SAM segments the corresponding object, as shown in the figure below.
You can also drag out a box, and SAM segments the region inside it, as shown in the figure below.
The paper additionally mentions text prompts. This is not shown in the demo; my understanding is that you supply a description of the region you want segmented, and SAM segments the corresponding region according to that description.
The three prompts mentioned above are classified in the paper as sparse prompts. A point or a box (a box being just its top-left and bottom-right corner points) is embedded as a positional encoding (the sine-and-cosine encoding from the transformer literature, which expresses where something sits and its relative order) plus a learnable embedding for the prompt type (this part is worth reading in the code; a rough sketch also follows below).
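As a rough illustration of the idea (not SAM's actual PromptEncoder, which uses a random Fourier-feature positional encoding), a single point prompt could be embedded by encoding its normalized (x, y) coordinates with sines and cosines and adding a learned embedding for the prompt type. All names below are made up for the sketch.

import math
import torch
import torch.nn as nn

embed_dim = 256

def coord_pe(xy: torch.Tensor, dim: int = embed_dim) -> torch.Tensor:
    """Sine/cosine encoding of a normalized (x, y) coordinate into a (dim,) vector."""
    n_freqs = dim // 4                                   # frequencies per coordinate
    freqs = 1.0 / (100.0 ** (torch.arange(n_freqs, dtype=torch.float32) / n_freqs))
    args = xy[:, None] * freqs[None, :] * 2 * math.pi    # (2, n_freqs)
    return torch.cat([args.sin(), args.cos()], dim=-1).reshape(-1)

# One learned vector per prompt type: foreground point, background point,
# box top-left corner, box bottom-right corner (an assumed convention).
type_embeddings = nn.Embedding(4, embed_dim)

point = torch.tensor([0.3, 0.7])      # normalized click position
label = torch.tensor(0)               # 0 = foreground point in this sketch
sparse_embedding = coord_pe(point) + type_embeddings(label)
print(sparse_embedding.shape)         # torch.Size([256])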
The text prompt is also a sparse prompt, but it obviously cannot be represented by a positional encoding. The encoder SAM pairs with text is the text encoder from the CLIP architecture; see the CLIP paper for the details.
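The released segment-anything code does not actually ship the text-prompt path, but producing such a text embedding with the openai/CLIP package would look roughly like this (the model variant chosen here is arbitrary):

import torch
import clip  # pip install git+https://github.com/openai/CLIP.git

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/32", device=device)

tokens = clip.tokenize(["the handle of the scissors"]).to(device)
with torch.no_grad():
    text_embedding = model.encode_text(tokens)   # (1, 512) for ViT-B/32
print(text_embedding.shape)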
The remaining prompt type is a mask, which is a dense prompt: it is downsampled by a small convolutional network and then added element-wise to the image embedding (yes, a plain 1+1=2 addition; it all feels a bit like black magic anyway).
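A minimal sketch of that dense path, with assumed sizes: a 256×256 low-resolution mask is brought down by strided convolutions to the same 64×64×256 grid as the image embedding and then simply added to it (SAM's real mask_downscaling also interleaves normalization and GELU between the convolutions).

import torch
import torch.nn as nn

mask_downscaling = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=2, stride=2),     # 256 -> 128
    nn.Conv2d(4, 16, kernel_size=2, stride=2),    # 128 -> 64
    nn.Conv2d(16, 256, kernel_size=1),            # lift channels to the embedding dim
)

image_embedding = torch.randn(1, 256, 64, 64)     # output of the image encoder
mask_prompt = torch.randn(1, 1, 256, 256)         # low-resolution input mask
dense_embedding = mask_downscaling(mask_prompt)   # (1, 256, 64, 64)
fused = image_embedding + dense_embedding         # plain element-wise addition
print(fused.shape)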
The figure below is the mask decoder structure given in the paper.
I suspect most people are like me: at first glance it is bewildering, with all those arrows, and the paper says very little about it. So let's walk through it from left to right.
The image embedding and prompt embeddings are just the outputs of the encoders described above. The output tokens have not come up yet, but anyone who has read the ViT paper will recognize the trick: ViT does classification by prepending a cls token to the image embeddings, and after several layers of self-attention the output at that cls token's position gives the predicted class. The same idea is used here. SAM does segmentation instead, but it outputs more than one mask, as shown in the figure below.
This multi-mask output is mainly aimed at point prompts. Take the scissors example from the paper: if I click on the scissors handle, the region I want could be any one of the three options shown, i.e. the whole object, a part, or a sub-part. So how does the model decide what to finally return? That is what the output tokens are for: each output mask has its own output token, and an IoU prediction head picks whichever of the three masks it rates highest (this IoU prediction head is a learnable branch of the model, trained against the ground truth during training).
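At inference time the selection step itself is trivial; here is a sketch with made-up numbers (the decoder returns the candidate masks together with one predicted IoU score per mask):

import torch

masks = torch.randn(1, 3, 256, 256)            # three candidates: whole / part / subpart
iou_pred = torch.tensor([[0.62, 0.88, 0.71]])  # predicted quality score for each mask

best = iou_pred.argmax(dim=1)                            # index the model trusts most
best_mask = masks[torch.arange(masks.shape[0]), best]    # (1, 256, 256)
print(best.item(), best_mask.shape)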
def forward(
    self,
    image_embeddings: torch.Tensor,
    image_pe: torch.Tensor,
    sparse_prompt_embeddings: torch.Tensor,
    dense_prompt_embeddings: torch.Tensor,
    multimask_output: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Predict masks given image and prompt embeddings.

    Arguments:
      image_embeddings (torch.Tensor): the embeddings from the image encoder
      image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
      sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
      dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
      multimask_output (bool): Whether to return multiple masks or a single mask.

    Returns:
      torch.Tensor: batched predicted masks
      torch.Tensor: batched predictions of mask quality
    """
    masks, iou_pred = self.predict_masks(
        image_embeddings=image_embeddings,
        image_pe=image_pe,
        sparse_prompt_embeddings=sparse_prompt_embeddings,
        dense_prompt_embeddings=dense_prompt_embeddings,
    )

    # Select the correct mask or masks for output
    if multimask_output:
        mask_slice = slice(1, None)
    else:
        mask_slice = slice(0, 1)
    masks = masks[:, mask_slice, :, :]
    iou_pred = iou_pred[:, mask_slice]

    # Prepare output
    return masks, iou_pred

def predict_masks(
    self,
    image_embeddings: torch.Tensor,
    image_pe: torch.Tensor,
    sparse_prompt_embeddings: torch.Tensor,
    dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Predicts masks. See 'forward' for more details."""
    # Concatenate output tokens
    # output_tokens = the IoU token plus the mask tokens, concatenated together
    output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
    # Add a batch dimension and expand to match sparse_prompt_embeddings, which is
    # B x N_prompt_tokens x C, so there is one copy of the output tokens per prompt set
    output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
    # Concatenate the expanded output_tokens with sparse_prompt_embeddings
    tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

    # Expand per-image data in batch direction to be per-mask
    # i.e. replicate the image embedding once per set of tokens; how many sets there
    # are depends on the prompts that were given
    src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
    # Add the dense (mask) prompt embeddings onto the image embeddings
    src = src + dense_prompt_embeddings
    # Prepare the image positional encoding, repeated the same way
    pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
    b, c, h, w = src.shape

    # Run the transformer
    # Note: this "transformer" also contains the token-to-image attention in the
    # bottom-right corner of the figure (more on that below)
    hs, src = self.transformer(src, pos_src, tokens)
    # Split the two kinds of output tokens back out of hs
    iou_token_out = hs[:, 0, :]
    mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

    # Upscale mask embeddings and predict masks using the mask tokens
    # (the trapezoid-shaped upsampling block in the figure)
    src = src.transpose(1, 2).view(b, c, h, w)
    upscaled_embedding = self.output_upscaling(src)
    hyper_in_list: List[torch.Tensor] = []
    for i in range(self.num_mask_tokens):
        hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
    hyper_in = torch.stack(hyper_in_list, dim=1)
    b, c, h, w = upscaled_embedding.shape
    masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

    # Generate mask quality predictions
    iou_pred = self.iou_prediction_head(iou_token_out)

    return masks, iou_pred
Now we can pick up the part we set aside above. The transformer in the code is not just a plain transformer: it contains both the dark-boxed block on the left of the mask decoder figure and the token-to-image attention in the bottom-right corner. That is why, once the image embedding, the output tokens, and the prompt tokens have been prepared, a single transformer call is all it takes before the result goes straight to the 2x upscaling convolutions and to splitting off the output tokens.
The paper mentions that the mask decoder uses a TwoWayTransformer. In the code, the place where MaskDecoder is built can be found under segment-anything/segment_anything/build_sam.py, as shown below.
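The relevant construction looks roughly like this (quoted from memory of build_sam.py; parameter values may differ slightly between versions, so check the repository for the exact code):

from segment_anything.modeling import MaskDecoder, TwoWayTransformer

prompt_embed_dim = 256
mask_decoder = MaskDecoder(
    num_multimask_outputs=3,
    transformer=TwoWayTransformer(
        depth=2,
        embedding_dim=prompt_embed_dim,
        mlp_dim=2048,
        num_heads=8,
    ),
    transformer_dim=prompt_embed_dim,
    iou_head_depth=3,
    iou_head_hidden_dim=256,
)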
As you can see, the transformer passed in is an instance of the TwoWayTransformer class, and depth=2 corresponds to the "×2" in the structure diagram.
The definition of TwoWayTransformer is under segment-anything/segment_anything/modeling/transformer.py, and TwoWayAttentionBlock can be found in the same file. As mentioned above, the transformer covers both the dark-boxed block and the token-to-image attention in the bottom-right corner, so TwoWayAttentionBlock is purely the dark-boxed part.
Let's look at the TwoWayTransformer code as a whole first.
class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        # Everything between this comment and the "end" comment below builds the
        # structure in the figure, with depth = 2 -- start
        for i in range(depth):
            self.layers.append(
                # the dark-boxed block in the figure
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        # and this is clearly the token-to-image attention in the bottom-right corner
        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)
        # that covers the structural layout -- end

    # The inputs below are the results of the two encoders described earlier; for
    # point and box prompts they are really both just points -- one point for the
    # former, two (the corners) for the latter
    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        # The image embedding is 4-dimensional; flatten(2) merges everything from the
        # third dimension onward (i.e. height and width) into one dimension, and
        # permute then rearranges BxCxHW into BxHWxC
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding  # the PE is re-added before every attention step rather than baked in once
        k = keys + image_pe            # likewise: image_embedding plus its positional encoding
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)  # norm_final_attn is just a LayerNorm

        return queries, keys
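A quick toy forward pass makes the shapes concrete (assuming the segment-anything package is installed; the sizes are SAM's defaults, 256-dim embeddings on a 64×64 grid, with 7 tokens standing in for the output tokens plus prompt tokens):

import torch
from segment_anything.modeling.transformer import TwoWayTransformer

transformer = TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8)

image_embedding = torch.randn(1, 256, 64, 64)   # B x C x H x W
image_pe = torch.randn(1, 256, 64, 64)          # positional encoding, same shape
point_embedding = torch.randn(1, 7, 256)        # B x N_tokens x C

queries, keys = transformer(image_embedding, image_pe, point_embedding)
print(queries.shape, keys.shape)                # (1, 7, 256) and (1, 4096, 256)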
Good. Having gone through the overall structure, let's now look at the dark-boxed part on its own, i.e. TwoWayAttentionBlock.
class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys
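The same kind of toy shape check works for a single block (again assuming segment-anything is installed; the token counts below are arbitrary). Both outputs keep their input shapes, which is exactly the "two-way" idea: the sparse tokens and the image tokens each get updated by attending to the other.

import torch
from segment_anything.modeling.transformer import TwoWayAttentionBlock

block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8, skip_first_layer_pe=True)

queries = torch.randn(1, 7, 256)      # sparse tokens (output tokens + prompt tokens)
keys = torch.randn(1, 4096, 256)      # flattened 64x64 image embedding
q_pe = torch.randn(1, 7, 256)
k_pe = torch.randn(1, 4096, 256)

queries, keys = block(queries=queries, keys=keys, query_pe=q_pe, key_pe=k_pe)
print(queries.shape, keys.shape)      # (1, 7, 256) and (1, 4096, 256)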
That more or less wraps up SAM's model architecture, but one core piece is still missing before SAM can really be understood: the design of the data. The paper devotes a great deal of space to it, and it is essential, because many design choices in the model are tied to how the data was designed and collected. I'll leave this as a placeholder for now (2024-03-27) and hope to come back and fill it in.