当前位置:   article > 正文

【大模型系列】一文看懂SAM大模型_sam模型

sam模型

参考资料:

SAM模型大致上分成3个模块,一个标准的vit构成的image encoder、一个prompt encoder和一个mask decoder。其中:

  • Image encoder: 用于输出image embedding;
  • prompt encoder:用于接收point、box、txt的编码信息,并且与image embedding组合到一起送入mask decoder中;
  • mask decoder:将上述两个encoder的编码信息转化为mask输出。

1 Image Encoder的结构

从结构上看,sam的encoder部分就是堆叠transformer的block结构,最后再跟一个neck,调整输出embedding的维度。Meta开源了三个模型,分别是vit_h, vit_l和vit_b,这三个模型的区别仅仅在于内部patch embedding维度、transformer的block的个数以及每个block中head的数量和全局attention的index:

模型patch embedding维度transformer head数量transformer block层数global attention 的block的index
vit_h12801632[7, 15, 23, 31]
vit_l10241624[5, 11, 17, 23]
vit_b7681212[2, 5, 8, 11]

网络输入尺寸:1024x1024,
图片分path的尺寸:16,
image embedding的长度:256,
windows size:14。

1.1 图片分patch

原图进入网络之后,按照最大边长补充成方形,再resize到1024x1024。

1024x1024x3的图片输入进入网络后,首先使用一个16x16,stride=16,输出channel数为patch embedding维度的二维卷积。以vit_b为例,patch embedding的维度是768,因此经过卷积之后,图片变成了768x64x64的feature map,再调整维度就变成64x64x768。

在该feature map基础上,会再加一个绝对位置编码(absolute positional embedding),所谓绝对位置编码是指生成一组与feature map同样大小(64x64x768)的可学习参数,初始化时一般为0。

1.2 attention block

在这里插入图片描述

1.2.1 window partition

针对非global attention的block,会将上一小节输出的feature map进行补边,再拆分成14x14的网格。流程如下:
输入的特征图大小为:1x64x64x768
窗口的大小为:14x14

得到最小可整除特征图大小为1x70x70x768,因此采用0来padding,padding方式为右下角填充,再将特征图拆分为25x14x14x768。
在这里插入图片描述

1.2.2 window unpartition

针对非global attention的block,将attention层输出的特征图1x70x70x768转化为1x64x64x768的特征图,实际上是通过切片操作得到的,即取右上角特征图。
在这里插入图片描述

1.2.3 relative partition embedding

相对位置编码出现在attention模块中,用于在q*k计算完成之后,对于attention矩阵进行操作:

def forward(self, x: torch.Tensor) -> torch.Tensor:        
B, H, W, _ = x.shape        
# qkv with shape (3, B, nHead, H * W, C)        
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)        
# q, k, v with shape (B * nHead, H * W, C)        
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)        
attn = (q * self.scale) @ k.transpose(-2, -1)

# 添加相对位置编码        
if self.use_rel_pos:            
    attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))        
    attn = attn.softmax(dim=-1)        
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)        
x = self.proj(x)        
return x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 生成一组可学习的位置编码向量矩阵

相对位置编码针对于非global attention的block,即那些需要将特征图拆分成14x14大小的block。在h,w两个方向上生成的一组可学习的参数,维度为(2 * 14 - 1, 64)。

针对于需要global attention的block,即不拆分特征图,在h,w两个方向上生成的一组可学习的参数,维度为(2 * 64- 1, 64)

# 其中input_size[0] = input_size[1] = 14
# multi-head attention中的head维度:head_dim = block输出的维度768 / head的数量12 = 64
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
  • 1
  • 2
  • 3
  • 4

为啥是2 * input_size[0] - 1呢?因为矩阵中最远的距离就是对角线元素之间的曼哈顿距离,所以只需生成2*H-1个向量即可。

  • 根据特征图的大小,生成相对坐标的index

假设q,k的size为5,则生成的位置编码为:(torch.arange(5)[:, None] - torch.arange(5)[None, :]) + (5 -1)

效果如下:

tensor([[4, 3, 2, 1, 0],

​ [5, 4, 3, 2, 1],

​ [6, 5, 4, 3, 2],

​ [7, 6, 5, 4, 3],

​ [8, 7, 6, 5, 4]])

从图中可以看出,相对位置编码index是从特征图的某个角开始设置为0,距离该角越远,index越大。再使用该index,从上一步生成的位置编码向量矩阵中取出不同index下的编码向量。

  • 基于query矩阵计算最终的相对位置编码

针对非global attention的block,输入的特征图大小为25x14x14x768,所以生成的index矩阵大小为14x14,再用index矩阵取其对应的位置编码向量,得到的就是一个14x14x64的位置编码矩阵,针对h,w两个方向的做同样的操作,得到2个14x14x64的相对位置编码矩阵Rh与Rw。其中Rh基于rel_pos_h生成,Rw基于rel_pos_w生成,两个矩阵不一样。

此时计算出来的query矩阵大小为300x196x64,将其还原到300x14x14x64,再分别与Rh和Rw做矩阵乘法,最终得到的就是位置编码,大小为300x14x14x14,对应代码:

def add_decomposed_rel_pos(        
    attn: torch.Tensor,        
    q: torch.Tensor,        
    rel_pos_h: torch.Tensor,        
    rel_pos_w: torch.Tensor,        
    q_size: Tuple[int, int],        
    k_size: Tuple[int, int], ) -> torch.Tensor:        
    """ Calculate decomposed Relative Positional Embeddings    
    Args:                
        attn (Tensor): attention map.                
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).                
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.                
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.                
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).                
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).        

    Returns:                
        attn (Tensor): attention map with added relative positional embeddings.        
    """    
    # q: 300x196x64    
    # atten:300x196x196    
    q_h, q_w = q_size        
    k_h, k_w = k_size
    
    # Rh: 14x14x64        
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)        
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)        
    B, _, dim = q.shape
    
    # r_q: 300x14x14x64        
    r_q = q.reshape(B, q_h, q_w, dim) 
    
    # rel_h: 300x14x14x14    
    # 等价于:   
    # rel_h = torch.matmul(r_q, Rh.transpose(1, 2))
    # rel_w = torch.matmul(r_q.transpose(1, 2), Rw.transpose(1, 2)).transpose(1, 2)    
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)        
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
    
    # 将相对位置编码加在atten里面,再resize回300x192x196    
    attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w) 
    return attn
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 直接将计算好的相对位置编码加到attention矩阵上

attention矩阵为300x196x196,reshape成300x14x14x14x14,再使用矩阵加法,将相对位置编码分别加到倒数2个维度上,再reshape回原来的大小。

对应的代码操作为:

attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w)    
  • 1

1.3 neck

neck部分由两个卷积层组成,分别是256x768x1x1和256x256x3x3,最后输出的image imbedding的尺寸是1x256x64x64。

2 Prompt encoder

根据输入的point和boxs返回sparse embedding, 根据mask返回dense embeddings。

2.1 point embedding

  • step1:首先生成一组可学习的向量point embedding,大小为:4x1x256:
point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
self.point_embeddings = nn.ModuleList(point_embeddings)
  • 1
  • 2

4代表了表示pos/neg + 2 box corners,即demo里面的添加点和消除点、以及box框的左上角和右下角;

0:neg,对应demo中的消除点

1:pos,对应demo中的添加点

2:代表box左上角点

3:代表box右下角点

  • step2:再生成一组可学习的向量not_a_point_embed,大小为1x256,用于表示该位置不是一个点
self.not_a_point_embed = nn.Embedding(1, embed_dim)
  • 1
  • step3:如果传入的prompt里面没有bbox,则补充一个【0,0】点到每个point后面,其对应的label为-1
if pad:            
    padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)            
    padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)            
    points = torch.cat([points, padding_point], dim=1)            
    labels = torch.cat([labels, padding_label], dim=1)
  • 1
  • 2
  • 3
  • 4
  • 5

此时point大小为Nx2x2,label为Nx2

  • step4:如果传入的还有bbox,此时的point大小为Nx1x2,label为Nx1
  • step5:再根据point计算point embedding,其流程如下:
    • 横纵坐标先归一化,即都除以输入的尺寸(1024, 1024);
    • 再将point矩阵与一个随机高斯矩阵(2x128)矩阵相乘得到Nxax128的矩阵coord,其中(a=2表示只有point,a=1表示还有box作为prompt输入);
    • 再分别对coord计算sin和cos,拼接矩阵得到最终的point embedding(Nxax256)
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • step:6再根据label,给point embedding加上之前生成的可学习的embeding向量
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)    
point_embedding[labels == -1] = 0.0        
point_embedding[labels == -1] += self.not_a_point_embed.weight      # 对应label为-1的,加上not_a_point_embed
point_embedding[labels == 0] += self.point_embeddings[0].weight     # neg点加上point_embeddings[0]
point_embedding[labels == 1] += self.point_embeddings[1].weight     # pos点加上point_embeddings[1]
  • 1
  • 2
  • 3
  • 4
  • 5

完整point embedding的流程如下:

def _embed_points(
    self,        
    points: torch.Tensor,        
    labels: torch.Tensor,        
    pad: bool,     ) -> torch.Tensor:        
    """Embeds point prompts."""        

    points = points + 0.5  
    # Shift to center of pixel
    # 如果没有输入的box的话,会将points的长度用0补充形成Nx2x2,label用【-1】补充成Nx2
    if pad:            
        padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)            
        padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)            
        points = torch.cat([points, padding_point], dim=1)            
        labels = torch.cat([labels, padding_label], dim=1)

    # 将points与一个2x128的随机高斯矩阵相乘再通过进行sin、cos运算,两者的运算结果拼接得到
    # point_embedding: Nx1x256 或者 Nx2x256
    point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)    
    point_embedding[labels == -1] = 0.0        
    point_embedding[labels == -1] += self.not_a_point_embed.weight
    point_embedding[labels == 0] += self.point_embeddings[0].weight
    point_embedding[labels == 1] += self.point_embeddings[1].weight
    return point_embedding
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

返回之后需要与sparse embedding进行拼接:

# 如果只有point,那么sparse_embeddings的size是Nx2x256,如果还有box则是Nx1x256
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
  • 1
  • 2

如果只有point,当前sparse_embeddings 的大小为Nx2x256

如果还有box,当前sparse_embeddings 的大小为Nx1x256

2.2 box embedding

bbox一般有2个点,其编码步骤如下:

step1: 所以回先resize为Nx2x2;

step2: 再使用point embedding编码的方式,得到corner_embedding,

step3: 再加上之前生成的可学习的embeding向量;

最后输出的corner_embedding大小为Nx2x256。

def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:        
    """Embeds box prompts."""        
    boxes = boxes + 0.5  # Shift to center of pixel
    # 操作与points类似,讲4个点resize成Nx2x2
    coords = boxes.reshape(-1, 2, 2)

    # 返回Nx2x256的embedding
    corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)

    # 再加上点的embedding        
    corner_embedding[:, 0, :] += self.point_embeddings[2].weight        
    corner_embedding[:, 1, :] += self.point_embeddings[3].weight        
    return corner_embedding
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

最后输出的box的embedding的尺寸是Nx2x256。

合并(concat)point embedding和corner embedding,可以得到sparse embedding:

  • 全都没有:sparse embedding(1x0x256)
  • 如果只有point:sparse embedding(Nx2x256)
  • 如果只有box:sparse embedding(Nx2x256)
  • piont、box都有:sparse embedding(Nx3x256)

2.3 mask embedding

  • 如果没有配置mask,有一个长度为256的可学习向量,表示没有mask embedding,再将其拓展为1x256x64x64
self.no_mask_embed = nn.Embedding(1, embed_dim)      # embed_dim=256
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1])
  • 1
  • 2
  • 3
  • **如果有配置mask:**已知输入的mask是Nx1x256x256,经过3层卷积,最后得到与image embedding一样的size:

mask会先进入一个1x2x2x4的卷积,stride=2;LN;

然后再进入一个4x2x2x16的卷积,stride=2;LN;

最后再进入一个16x1x1x256的卷积;

得到最后的mask_embedding的size为Nx256x64x64

最终mask embeding作为dense embedding输出,大小为Nx256x64x64。

3 Mask decoder

初始化几个可学习的参数:

可学习的mask tokens:4x256

# num_mask_tokens = 3 + 1 = 4, transformer_dim = 256
# 输出一个4x256的矩阵
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
  • 1
  • 2
  • 3

可学习的iou tokens:1x256

self.iou_token = nn.Embedding(1, transformer_dim)
  • 1

image_pe: 跟image embedding一样大的位置编码256x64x64 ,见prompt_encoder.py:PositionalEmbeddingRandom.get_dense_pe()

就是将64x64个坐标点归一化之后,与随机高斯矩阵相乘(2x128),再将结果分别进行sin和cos,最后再拼到一起,输出的大小为256x64x64,与image_embedding大小基本一致了。

class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """
    def init(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().init()
        if scale is None or scale <= 0.0:
            scale = 1.0
        # 构建一个2x128的随机矩阵作为位置编码高斯矩阵
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1

        # 矩阵乘法:64x64xx2 @ 2x128 ---> 64x64x128
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords

        # outputs d_1 x ... x d_n x C shape
        # cat, 最后一个维度上拼接:64x64x256
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
        """Generate positional encoding for a grid of the specified size."""
        h, w = size
        device: Any = self.positional_encoding_gaussian_matrix.device

        # 构造一个64x64的全1矩阵
        grid = torch.ones((h, w), device=device, dtype=torch.float32)

        # 行、列累加
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5

        # 行列累加结果归一化
        y_embed = y_embed / h
        x_embed = x_embed / w

        # 行列拼接:64x64x2,编码后的结果是64x64x256
        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))

        # 最后输出256x64x64
        return pe.permute(2, 0, 1)  # C x H x W
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48

3.1 预测mask的流程

  • sparse embedding、iou token、mask token合并成一个tokens,作为point_embeddings

需要注意的是:

sparse embedding: point、bbox prompt合并后的产物,一般为NxXx256

iou token: 可学习参数,大小为1x256

mask token: 可学习参数,大小为4x256

首先将iou token和mask token 拼接得到一个5x256的矩阵,再将其拓展到与sparse embedding一个维度Nx5x256;

再将拓展后的矩阵与sparse embedding拼接得到tokens,其大小Nx(5+X)x256

# 代码见:mask_decoder.py -> predict_masks
# 拼接iou_token和mask token得到i: 5x256 的tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)

# 拓展成稀疏prompt编码的个数:Nx5x256
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)

# 再与稀疏矩阵拼接,假设稀疏矩阵只有point为Nx2x256,拼接之后则为Nx(5+2)x256
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • image embedding与dense prompt直接相加得到Nx256x64x64的矩阵,命名为src,作为image_embedding

需要注意的是:

image embedding: 是image encoder的输出,大小为为1x256x64x64

dense prompt: 是mask embedding的产物,大小为Nx256x64x64

image embedding拓展维度之后直接与dense prompt相加,得到image_embedding,大小为Nx256x64x64

# 将image embedding(1x256x64x64)拓展成稠密prompt的维度:Nx256x64x64
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)

# 将拓展后的image embedding直接与稠密prompt相加:Nx256x64x64
src = src + dense_prompt_embeddings
  • 1
  • 2
  • 3
  • 4
  • 5
  • image_pe位置编码也拓展成Nx256x64x64的矩阵,命名为pos_src

需要注意的是:image_pe相当于特征图中每个位置进行了与point类似的编码操作

# 将256x64x64的位置编码,拓展成Nx256x64x64
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
  • 1
  • 2
  • 将这三个送入TwoWayTransformer中,返回的结果后处理后就能得到最终的mask信息。
def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens
        # 拼接iou_token和mask token得到i: 5x256 的tokens
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)

        # 拓展成稀疏prompt编码的个数:Nx5x256
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)

        # 再与稀疏矩阵拼接,假设稀疏矩阵只有point为Nx2x256,拼接之后则为Nx(5+2)x256
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        # 将image embedding(1x256x64x64)拓展成稠密prompt的维度:Nx256x64x64
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)

        # 将拓展后的image embedding直接与稠密prompt相加:Nx256x64x64
        src = src + dense_prompt_embeddings

        # 将256x64x64的位置编码,拓展成Nx256x64x64
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer:这里使用的TwoWayTransformer,有必要对输入再说明一下
        # src:image_bedding + dense_prompt(mask),Nx256x64x64
        # pos_src: 位置编码,Nx256x64x64
        # tokens: iou_tokens + mask_tokens + sparse_prompt(point/bbox),Nx(5+x)x256
        hs, src = self.transformer(src, pos_src, tokens)


        # 后处理
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(self.output_hypernetworks_mlpsi)
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)
        return masks, iou_pred
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53

3.2 TwoWayTransformer

【关于transformer的一个心得】:q、k、v中,k、v一定具有相同的size,最后输出的attention的size是由q来决定的。

参数:

  • depth:2,表示attention block只有2个
  • embedding_dim: 256
  • mlp_dim: 2048
  • num_heads: 8

所谓的TwoWay:两轮次循环,第一次point_embedding自注意,第二次则加上上一轮输出的query再进行attention。

两层TwoWayAttentionBlock:

  • 第一层:q = q_pe = point_embedding,k = image_embedding, k_pe = image_pe
  • 第二层:q = 第一层输出q, q_pe = point_embedding,k = 第一层输出k, k_pe = image_pe

整个流程如下:
在这里插入图片描述

  • 先将image_embedding转换shape:Nx256x64x64 —> Nx4096x256 (注意此时的image embedding是encoder的输出+dense prompt)
  • 再将位置编码也调整shape:Nx256x4096 —> Nx4096x256
  • 将image embedding( Nx4096x256)作为key,point embedding(Nx(5+x)x256)作为querise,送入2层TwoWayAttentionBlock,需要注意的是point embedding是point、bbox和iou、mask embedding拼接结果
  • TwoWayAttentionBlock返回 key和querise
  • q = querise + point_embedding, k = key + image_pe,v = key将其输入到final_attn_token_to_image中
  • 最后输出的结果是queries = queries + attn_out,再经过norm
def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:qur

        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.
            
        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """

        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        # image_embedding: Nx256x64x64
        bs, c, h, w = image_embedding.shape

        # Nx256x4096 ---> Nx4096x256
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)

        # Nx256x4096 ---> Nx4096x256
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding       # Nx(5+x)x256
        keys = image_embedding          # Nx4096x256

        # Apply transformer blocks and final layernorm
        # 将稀疏prompt和iou、mask的tokens组合tokens作为querise,image embedding作为key
        # 进入2层的TwoWayAttentionBlock
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )
        # Apply the final attention layer from the points to the image
        # 出TwoWayAttentionBlock的querise继续加上组合tokens作为querise
        q = queries + point_embedding

        # 出TwoWayAttentionBlock的key再加上位置编码作为key
        k = keys + image_pe

        # 进入最后的attention层,但是v还是没有加上位置编码的key
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)

        # querise + attention的输出并norm后输出
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)
        return queries, keys
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59

3.3 后处理

TwoWayTransformer返回的结果为:

hs: Nx(5+x)x256

src: Nx4096x256

  • 取tokens

    • 取第一个位置为iou_token : hs[:, 0, :] —> Nx1x256
    • 取1~5为位子为mask_token: hs[:, 1 : (1 + 4), :] —> Nx4x256
  • reshape src: Nx4096x256 —> Nx256x64x64

  • 通过2层转置卷积,将src变成Nx32x256x256

  • 将4个mask token分别送入4个独立的3层全连接网络中(每个channel进入不同的FC),最终得到hyper_in:Nx4x32

  • 将hyper_in 矩阵乘 reshap src得到Nx4x256x256的矩阵,这就是输出的mask

  • 将tokens送入3层全连接网络,最后得到iou_pred,大小为Nx1x4,这就输出的IoU_pred!

流程示意图如下,最后输出的N与传入的prompt的数量有关
在这里插入图片描述

hs, src = self.transformer(src, pos_src, tokens)
# 后处理
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
    hyper_in_list.append(self.output_hypernetworks_mlpsi)

hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

输出的mask有4个通道:

通道0:不用

通道1:whole

通道2:part

通道3:subpart

所以一般有一个multimask_output字段来控制是只输出whole,还是全部都输出。

4 全图分割

代码详见:automatic_mask_generator.py

class SamAutomaticMaskGenerator:
    def init(
        self,
        model: Sam,                                        # sam模型
        points_per_side: Optional[int] = 32,               # 每个边需要采样的点数默认为32,最后的总点数为32x32
        points_per_batch: int = 64,                        # 模型可以同时处理的点数,默认64,数字越大越快GPU越高
        pred_iou_thresh: float = 0.88,                     # iou阈值,默认0.88
        stability_score_thresh: float = 0.95,             
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,                      
        crop_n_layers: int = 0,                            # 裁剪的层数                     
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,            # 裁剪图片间的重叠情况
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,    # 点的网格列表,与points_per_side直接相关
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
    ) -> None:
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 每张图片默认撒点方式: 首先根据每个边需要采样的点数,默认时32,生成32x32的网格point_grids,所以最后输出的就是1024个坐标,且都归一化到0-1之间。

  • crop_n_layers: 裁剪的层数,每层裁剪的crop_img个数为2^(n+1)个,即第一层裁剪4个,第二层16个,依次类推.

    • point_grids也可以根据crop_n_layers每层进行递减,通过crop_n_points_downscale_factor控制,默认设置为1表示所有的裁剪图crop_img也均匀采样1024个点

Step1: 原图裁剪,一般crop_n_layers设置为0,即送全图区域:
在这里插入图片描述

Step2: 将图片补边,再resize到1024x1024,送入Image Encoder中生成image embedding;

Step3: 图片宽高方向各均匀生成32个位置,组成1024个坐标点;

Step4: 每次送入64个坐标点,迭代1次,生成mask及iou_pred;

Step5: 结果后处理

  1. 根据iou阈值(默认0.88)过滤mask

  2. 对过滤后计算calculate_stability_score稳定性分值=(mask > 1的数量) / (mask > -1的数量)

  3. 根据calculate_stability_score过滤mask,阈值默认为0.95

  4. 对过滤后的mask取阈值0,得到掩膜,根据掩模计算外界矩形框

  5. 过滤外界矩形框达到crop边界的对应的mask

  6. 将截取图片crop_img的mask,映射到原图尺寸上

  7. 再将mask转化为rle编码,用于节省内存,mask拉平,(3,3)表示第3个元素开始,后面3个都是1
    在这里插入图片描述

4.1 完整流程

def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
        # Generate masks
        # 生成mask
        mask_data = self._generate_masks(image)

        # Filter small disconnected regions and holes in masks
        # 过滤小的区域或者空洞
        if self.min_mask_region_area > 0:
            mask_data = self.postprocess_small_regions(
                mask_data,
                self.min_mask_region_area,
                max(self.box_nms_thresh, self.crop_nms_thresh),
            )

        # Encode masks
        # 输出的mask格式,默认输出binary_mask二值掩码
        if self.output_mode == "coco_rle":
            mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
        elif self.output_mode == "binary_mask":
            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
        else:
            mask_data["segmentations"] = mask_data["rles"]

        # Write mask records
        # 将结果整理输出
        curr_anns = []
        for idx in range(len(mask_data["segmentations"])):
            ann = {
                "segmentation": mask_data"segmentations",
                "area": area_from_rle(mask_data"rles"),
                "bbox": box_xyxy_to_xywh(mask_data"boxes").tolist(),
                "predicted_iou": mask_data"iou_preds".item(),
                "point_coords": [mask_data"points".tolist()],
                "stability_score": mask_data"stability_score".item(),
                "crop_box": box_xyxy_to_xywh(mask_data"crop_boxes").tolist(),
            }
            curr_anns.append(ann)

        return curr_anns
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39

4.2 生成masks

def _generate_masks(self, image: np.ndarray) -> MaskData:
        orig_size = image.shape[:2]
        #由于默认的crop_n_layers为0,所以返回的crop_box为全图,layer_idxs只有0
        crop_boxes, layer_idxs = generate_crop_boxes(
            orig_size, self.crop_n_layers, self.crop_overlap_ratio
        )

        # Iterate over image crops
        # 将每一个抠图区域送入网络中
        data = MaskData()
        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
            data.cat(crop_data)

        # Remove duplicate masks between crops
        # 将所有crop图片的结果汇总到一起进行NMS过滤
        if len(crop_boxes) > 1:
            # Prefer masks from smaller crops
            scores = 1 / box_area(data["crop_boxes"])
            scores = scores.to(data["boxes"].device)
            keep_by_nms = batched_nms(
                data["boxes"].float(),
                scores,
                torch.zeros_like(data"boxes"),  # categories
                iou_threshold=self.crop_nms_thresh,
            )
            data.filter(keep_by_nms)
        data.to_numpy()
        return data
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

4.3 处理每一个crop_img

def _process_crop(
        self,
        image: np.ndarray,
        crop_box: List[int],
        crop_layer_idx: int,
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        # Crop the image and calculate embeddings
        # 裁剪图片
        x0, y0, x1, y1 = crop_box
        cropped_im = image[y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[:2]

        # 将裁剪后的图片送入网络,变成了encoder所需的1024*1024
        # 1、先按照比例将图片最大的边resize到1024
        # 2、调整位置HxWxC ---> 1xCxHxW
        # 3、再将图片归一化
        # 4、用0padding,使其成为1024x1024的size
        # 5、送入encoder计算image_embedding
        self.predictor.set_image(cropped_im)

        # Get points for this crop
        # 将1024x2的矩阵乘以裁剪图片的大小,就得到了在裁剪图片上的grid
        points_scale = np.array(cropped_im_size)[None, ::-1]
        points_for_image = self.point_grids[crop_layer_idx] * points_scale

        # Generate masks for this crop in batches
        data = MaskData()

        # 每张图送入网络1024个点,每次同时计算points_per_batch(64个),因此需要迭代1024 / 64 = 16次
        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
            # 1、将坐标点映射到1024x1024的图片上
            # 2、每个点的label设置为1,label的size就是64x1
            # 3、送入decoder计算mask
            #    3.1 先送入prompt encoder,由于只有point,得到sparse embedding和dense embedding(no_mask_embed)
            #    3.2 将其送入mask decoder得到最后输出的
            #    3.3 切片输出,如果需要输出多个mask,取index 1 ~ 3, 如果只输出一个mask index取0
            #    3.4 对mask做后处理:先resize回1024x1024,取出非padding部分再resize回原图

            # 4、一系列的后处理
            #    4.1 根据iou阈值(默认0.88)过滤mask
            #    4.2 对过滤后计算calculate_stability_score稳定性分值=(mask > 1的数量) / (mask > -1的数量)
            #    4.3 根据calculate_stability_score过滤mask,阈值默认为0.95
            #    4.4 对过滤后的mask取阈值0,得到掩膜,根据掩模计算外界矩形框
            #    4.5 过滤外界矩形框达到crop边界的对应的mask
            #    4.6 将截取图片crop_img的mask,映射到原图尺寸上
            #    4.7 再将mask转化为rle编码,用于节省内存,mask拉平,(3,3)表示第3个元素开始,后面3个都是1

            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
            data.cat(batch_data)
            del batch_data

        self.predictor.reset_image()

        # Remove duplicates within this crop.
        # 根据bbox,使用NMS过滤重复的结果
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros_like(data"boxes"),  # categories
            iou_threshold=self.box_nms_thresh,
        )

        data.filter(keep_by_nms)
        # Return to the original image frame
        # bbox和point映射回原图坐标
        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["points"] = uncrop_points(data["points"], crop_box)
        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
        return data
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/667122
推荐阅读
相关标签
  

闽ICP备14008679号