The SAM model consists of three main modules: an image encoder built from a standard ViT, a prompt encoder, and a mask decoder. Specifically:
Structurally, SAM's encoder is simply a stack of transformer blocks followed by a neck that adjusts the dimension of the output embedding. Meta released three models, vit_h, vit_l and vit_b; they differ only in the patch embedding dimension, the number of transformer blocks, the number of heads per block, and the indices of the blocks that use global attention:
Model | Patch embedding dim | Transformer heads | Transformer blocks | Global attention block indices |
---|---|---|---|---|
vit_h | 1280 | 16 | 32 | [7, 15, 23, 31] |
vit_l | 1024 | 16 | 24 | [5, 11, 17, 23] |
vit_b | 768 | 12 | 12 | [2, 5, 8, 11] |
Network input size: 1024x1024
Patch size: 16
Image embedding dimension: 256
Window size: 14
The input image is first padded to a square along its longer side and then resized to 1024x1024.
The 1024x1024x3 image then passes through a 2D convolution with a 16x16 kernel, stride 16, and an output channel count equal to the patch embedding dimension. Taking vit_b as an example, the patch embedding dimension is 768, so after the convolution the image becomes a 768x64x64 feature map, which is then permuted to 64x64x768.
On top of this feature map an absolute positional embedding is added: a set of learnable parameters with the same shape as the feature map (64x64x768), usually zero-initialized.
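A minimal sketch of this stage (the class and attribute names here are illustrative, not the exact ones in image_encoder.py; shapes follow the vit_b numbers above):
import torch
import torch.nn as nn

# Sketch of SAM-style patch embedding + absolute positional embedding.
class PatchEmbedSketch(nn.Module):
    def __init__(self, img_size=1024, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        # 16x16 conv with stride 16: one output "pixel" per patch
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        side = img_size // patch_size  # 64
        # learnable absolute positional embedding, same size as the feature map, zero-initialized
        self.pos_embed = nn.Parameter(torch.zeros(1, side, side, embed_dim))

    def forward(self, x):                 # x: Bx3x1024x1024
        x = self.proj(x)                  # Bx768x64x64
        x = x.permute(0, 2, 3, 1)         # Bx64x64x768
        return x + self.pos_embed         # add the absolute positional embedding

x = torch.randn(1, 3, 1024, 1024)
print(PatchEmbedSketch()(x).shape)        # torch.Size([1, 64, 64, 768])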
For blocks that do not use global attention, the feature map from the previous step is padded and split into 14x14 windows. The procedure is:
Input feature map size: 1x64x64x768
Window size: 14x14
The smallest size divisible by the window is 1x70x70x768, so the map is zero-padded at the bottom-right and then split into 25x14x14x768.
After attention, the 1x70x70x768 feature map is converted back to 1x64x64x768 by a slicing operation, i.e., keeping the top-left 64x64 region and discarding the padded rows and columns at the bottom-right. A sketch of both steps follows.
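A hedged sketch of the padding/partition round-trip (the repo has window_partition/window_unpartition helpers; this is a simplified rewrite that only tracks shapes):
import torch
import torch.nn.functional as F

def window_partition_sketch(x, window_size=14):
    B, H, W, C = x.shape                                   # 1x64x64x768
    pad_h = (window_size - H % window_size) % window_size  # 6
    pad_w = (window_size - W % window_size) % window_size  # 6
    x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))               # pad bottom-right -> 1x70x70x768
    Hp, Wp = H + pad_h, W + pad_w
    x = x.reshape(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
    return windows, (Hp, Wp)                               # 25x14x14x768

def window_unpartition_sketch(windows, window_size, pad_hw, hw):
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // ((Hp // window_size) * (Wp // window_size))
    x = windows.reshape(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
    return x[:, :H, :W, :]                                 # drop the padded bottom-right region

x = torch.randn(1, 64, 64, 768)
w, pad_hw = window_partition_sketch(x)
print(w.shape)                                             # torch.Size([25, 14, 14, 768])
print(window_unpartition_sketch(w, 14, pad_hw, (64, 64)).shape)  # torch.Size([1, 64, 64, 768])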
The relative positional encoding lives inside the attention module and is applied to the attention matrix after q*k has been computed:
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1)
# add the decomposed relative positional encoding
if self.use_rel_pos:
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x)
return x
The relative positional encoding is used in blocks without global attention, i.e., those whose feature map is split into 14x14 windows. A set of learnable parameters is created for each of the h and w directions, with shape (2 * 14 - 1, 64).
For blocks with global attention, where the feature map is not split, the learnable parameters created for the h and w directions have shape (2 * 64 - 1, 64):
# here input_size[0] = input_size[1] = 14
# head dimension in multi-head attention: head_dim = block output dim 768 / number of heads 12 = 64
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
Why 2 * input_size[0] - 1? Along one axis the relative offset between any two positions ranges from -(H - 1) to H - 1, so only 2*H - 1 embedding vectors are needed.
Suppose q and k both have size 5; the relative-position index is generated as: (torch.arange(5)[:, None] - torch.arange(5)[None, :]) + (5 - 1)
The result is:
tensor([[4, 3, 2, 1, 0],
[5, 4, 3, 2, 1],
[6, 5, 4, 3, 2],
[7, 6, 5, 4, 3],
[8, 7, 6, 5, 4]])
As the output above shows, the relative-position index starts at 0 in one corner of the feature map and grows with the distance from that corner. The index is then used to gather the corresponding embedding vectors from the relative-position parameter matrix generated in the previous step.
For blocks without global attention, the input feature map is 25x14x14x768, so the index matrix is 14x14. Gathering with it yields a 14x14x64 positional-encoding matrix; doing the same for the h and w directions gives two 14x14x64 relative positional encoding matrices Rh and Rw, where Rh is built from rel_pos_h and Rw from rel_pos_w, so the two matrices differ.
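A simplified sketch of this gather step for the case where q and k have the same size (the repo's get_rel_pos also interpolates rel_pos when the sizes differ; that part is omitted here):
import torch

def get_rel_pos_sketch(q_size, k_size, rel_pos):
    # relative-position index, shifted to start at 0 (see the 5x5 example above)
    coords = (torch.arange(q_size)[:, None] - torch.arange(k_size)[None, :]) + (k_size - 1)
    return rel_pos[coords.long()]        # (q_size, k_size, head_dim)

rel_pos_h = torch.zeros(2 * 14 - 1, 64)  # learnable in the real model
Rh = get_rel_pos_sketch(14, 14, rel_pos_h)
print(Rh.shape)                          # torch.Size([14, 14, 64])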
At this point the query matrix has size 300x196x64. It is reshaped back to 300x14x14x64 and multiplied with Rh and Rw respectively, producing positional-encoding terms of size 300x14x14x14. The corresponding code:
def add_decomposed_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int], ) -> torch.Tensor:
""" Calculate decomposed Relative Positional Embeddings
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
# q: 300x196x64
# attn: 300x196x196
q_h, q_w = q_size
k_h, k_w = k_size
# Rh: 14x14x64
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
# r_q: 300x14x14x64
r_q = q.reshape(B, q_h, q_w, dim)
# rel_h: 300x14x14x14
# equivalent to:
# rel_h = torch.matmul(r_q, Rh.transpose(1, 2))
# rel_w = torch.matmul(r_q.transpose(1, 2), Rw.transpose(1, 2)).transpose(1, 2)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
# add the relative positional encoding into attn, then reshape back to 300x196x196
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w)
return attn
The attention matrix is 300x196x196; it is reshaped to 300x14x14x14x14, the relative positional encodings are added by broadcasting onto the last two dimensions, and the result is reshaped back to the original size.
The corresponding line of code is:
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w)
The neck consists of two convolution layers, a 1x1 conv from 768 to 256 channels and a 3x3 conv from 256 to 256 channels, each followed by a layer norm. The final image embedding has size 1x256x64x64.
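A rough sketch of the neck, showing only the channel and shape changes (the repo uses its own LayerNorm2d and bias-free convs; GroupNorm is used here purely as a stand-in for the norm layers):
import torch
import torch.nn as nn

neck = nn.Sequential(
    nn.Conv2d(768, 256, kernel_size=1, bias=False),              # 1x768x64x64 -> 1x256x64x64
    nn.GroupNorm(1, 256),                                        # stand-in for LayerNorm2d
    nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),   # keeps 1x256x64x64
    nn.GroupNorm(1, 256),
)
print(neck(torch.randn(1, 768, 64, 64)).shape)                   # torch.Size([1, 256, 64, 64])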
The prompt encoder returns sparse embeddings for input points and boxes, and dense embeddings for an input mask.
point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
self.point_embeddings = nn.ModuleList(point_embeddings)
The 4 stands for pos/neg + 2 box corners, i.e. the add/remove points in the demo plus the top-left and bottom-right corners of the box:
0: neg, the removal point in the demo
1: pos, the addition point in the demo
2: the top-left corner of the box
3: the bottom-right corner of the box
self.not_a_point_embed = nn.Embedding(1, embed_dim)
if pad:
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1)
At this point points has size Nx2x2 and labels has size Nx2.
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
"""Positionally encode points that are normalized to [0,1]."""
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coords = 2 * coords - 1
coords = coords @ self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
# outputs d_1 x ... x d_n x C shape
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight # entries with label -1 get not_a_point_embed added
point_embedding[labels == 0] += self.point_embeddings[0].weight # neg points get point_embeddings[0]
point_embedding[labels == 1] += self.point_embeddings[1].weight # pos points get point_embeddings[1]
The complete point-embedding flow is:
def _embed_points(
self,
points: torch.Tensor,
labels: torch.Tensor,
pad: bool, ) -> torch.Tensor:
"""Embeds point prompts."""
points = points + 0.5
# Shift to center of pixel
# if no box is provided, points are padded with a zero point to Nx2x2 and labels are padded with [-1] to Nx2
if pad:
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1)
# points are multiplied by a 2x128 random Gaussian matrix, passed through sin and cos, and the two results are concatenated
# point_embedding: Nx1x256 or Nx2x256
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight
point_embedding[labels == 1] += self.point_embeddings[1].weight
return point_embedding
The result is then concatenated into sparse_embeddings:
# with points only, sparse_embeddings is Nx2x256; if a box is also given, it is Nx1x256 (points are not padded in that case)
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
With points only, sparse_embeddings at this point has size Nx2x256.
If a box is also provided, it has size Nx1x256, since the points are not padded and the box corners are appended afterwards.
A bbox is represented by 2 points; its encoding steps are:
Step 1: reshape the box coordinates to Nx2x2;
Step 2: apply the same positional encoding as for points, giving corner_embedding;
Step 3: add the learnable corner embeddings created earlier.
The resulting corner_embedding has size Nx2x256.
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel
# same operation as for points: the 4 coordinates are reshaped to Nx2x2
coords = boxes.reshape(-1, 2, 2)
# returns an Nx2x256 embedding
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
# then add the corner point embeddings
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
return corner_embedding
The final box embedding has size Nx2x256.
Concatenating the point embedding and the corner embedding gives the sparse embedding. If no mask prompt is provided, the dense embedding is built from a learnable no_mask_embed expanded to the image-embedding size:
self.no_mask_embed = nn.Embedding(1, embed_dim) # embed_dim=256
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1])
An input mask first goes through a conv with 1 input channel, a 2x2 kernel, stride 2 and 4 output channels, followed by a layer norm;
then a conv with 4 input channels, a 2x2 kernel, stride 2 and 16 output channels, followed by a layer norm;
and finally a conv with 16 input channels, a 1x1 kernel and 256 output channels.
The resulting mask_embedding, of size Nx256x64x64, is output as the dense embedding.
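A shape-only sketch of this mask-downscaling branch (the real module also interleaves activations and LayerNorm2d between the convs, which are omitted here):
import torch
import torch.nn as nn

mask_downscaling = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=2, stride=2),      # Nx1x256x256 -> Nx4x128x128
    nn.Conv2d(4, 16, kernel_size=2, stride=2),     # -> Nx16x64x64
    nn.Conv2d(16, 256, kernel_size=1),             # -> Nx256x64x64
)
print(mask_downscaling(torch.randn(1, 1, 256, 256)).shape)  # torch.Size([1, 256, 64, 64])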
The mask decoder initializes several learnable parameters:
Learnable mask tokens: 4x256
# num_mask_tokens = 3 + 1 = 4, transformer_dim = 256
# produces a 4x256 matrix
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
Learnable iou token: 1x256
self.iou_token = nn.Embedding(1, transformer_dim)
image_pe: a positional encoding with the same size as the image embedding, 256x64x64; see prompt_encoder.py: PositionEmbeddingRandom.get_dense_pe().
Each of the 64x64 grid coordinates is normalized, multiplied by the 2x128 random Gaussian matrix, passed through sin and cos, and the results are concatenated, giving an output of size 256x64x64 that matches the image embedding.
class PositionEmbeddingRandom(nn.Module):
"""
Positional encoding using random spatial frequencies.
"""
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
super().__init__()
if scale is None or scale <= 0.0:
scale = 1.0
# build a 2x128 random matrix as the positional-encoding Gaussian matrix
self.register_buffer(
"positional_encoding_gaussian_matrix",
scale * torch.randn((2, num_pos_feats)),
)
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
"""Positionally encode points that are normalized to [0,1]."""
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
coords = 2 * coords - 1
# matrix multiply: 64x64x2 @ 2x128 -> 64x64x128
coords = coords @ self.positional_encoding_gaussian_matrix
coords = 2 * np.pi * coords
# outputs d_1 x ... x d_n x C shape
# concatenate along the last dimension: 64x64x256
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
"""Generate positional encoding for a grid of the specified size."""
h, w = size
device: Any = self.positional_encoding_gaussian_matrix.device
# build a 64x64 all-ones matrix
grid = torch.ones((h, w), device=device, dtype=torch.float32)
# cumulative sums along rows and columns
y_embed = grid.cumsum(dim=0) - 0.5
x_embed = grid.cumsum(dim=1) - 0.5
# normalize the cumulative sums
y_embed = y_embed / h
x_embed = x_embed / w
# stack into 64x64x2; after encoding the result is 64x64x256
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
# final output: 256x64x64
return pe.permute(2, 0, 1) # C x H x W
Note:
sparse embedding: the result of merging point and bbox prompts, generally NxXx256
iou token: a learnable parameter of size 1x256
mask tokens: learnable parameters of size 4x256
The iou token and mask tokens are first concatenated into a 5x256 matrix, which is expanded to the batch dimension of the sparse embedding, Nx5x256;
the expanded matrix is then concatenated with the sparse embedding to give tokens of size Nx(5+X)x256:
# see mask_decoder.py -> predict_masks
# concatenate iou_token and the mask tokens into 5x256 output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
# expand to the number of sparse prompt encodings: Nx5x256
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
# concatenate with the sparse embeddings; with points only (Nx2x256) the result is Nx(5+2)x256
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
Note:
image embedding: the output of the image encoder, size 1x256x64x64
dense prompt: the result of the mask embedding, size Nx256x64x64
The image embedding is expanded along the batch dimension and added directly to the dense prompt, giving a combined embedding of size Nx256x64x64:
# expand the image embedding (1x256x64x64) to the dense-prompt batch size: Nx256x64x64
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
# add the expanded image embedding directly to the dense prompt: Nx256x64x64
src = src + dense_prompt_embeddings
Note: image_pe encodes every position of the feature map in the same way points are encoded.
# expand the 256x64x64 positional encoding to Nx256x64x64
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts masks. See 'forward' for more details."""
# Concatenate output tokens
# concatenate iou_token and the mask tokens into 5x256 output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
# expand to the number of sparse prompt encodings: Nx5x256
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
# concatenate with the sparse embeddings; with points only (Nx2x256) the result is Nx(5+2)x256
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
# expand the image embedding (1x256x64x64) to the dense-prompt batch size: Nx256x64x64
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
# add the expanded image embedding directly to the dense prompt: Nx256x64x64
src = src + dense_prompt_embeddings
# expand the 256x64x64 positional encoding to Nx256x64x64
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
# Run the transformer: a TwoWayTransformer is used here; its inputs are worth spelling out
# src: image_embedding + dense_prompt (mask), Nx256x64x64
# pos_src: positional encoding, Nx256x64x64
# tokens: iou_token + mask_tokens + sparse_prompt (point/bbox), Nx(5+x)x256
hs, src = self.transformer(src, pos_src, tokens)
# post-processing
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
[A note on attention]: among q, k and v, k and v always have the same length, and the length of the attention output is determined by q.
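A tiny shape check of that note (no learned projections, just the raw attention arithmetic):
import torch

B, Lq, Lk, C = 2, 7, 4096, 256
q, k, v = torch.randn(B, Lq, C), torch.randn(B, Lk, C), torch.randn(B, Lk, C)
attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)   # B x Lq x Lk
out = attn @ v                                     # B x Lq x C: the output length follows q
print(out.shape)                                   # torch.Size([2, 7, 256])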
The "TwoWay" part: the tokens pass through two rounds; in the first round the point_embedding self-attends directly, and in each later round the query output of the previous round is added back before attention.
The transformer stacks two TwoWayAttentionBlock layers; a condensed sketch of one block follows.
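A condensed, hedged sketch of what one such block does, with nn.MultiheadAttention standing in for the repo's own Attention module and simplified attribute names:
import torch
import torch.nn as nn

class TwoWayBlockSketch(nn.Module):
    def __init__(self, dim=256, heads=8):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.token_to_image = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 2048), nn.ReLU(), nn.Linear(2048, dim))
        self.image_to_token = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.n1, self.n2, self.n3, self.n4 = (nn.LayerNorm(dim) for _ in range(4))

    def forward(self, queries, keys, query_pe, key_pe):
        # 1) prompt tokens self-attend
        q = queries + query_pe
        queries = self.n1(queries + self.self_attn(q, q, queries)[0])
        # 2) prompt tokens attend to image tokens
        q, k = queries + query_pe, keys + key_pe
        queries = self.n2(queries + self.token_to_image(q, k, keys)[0])
        # 3) MLP on the prompt tokens
        queries = self.n3(queries + self.mlp(queries))
        # 4) image tokens attend back to prompt tokens (the "two-way" part)
        q, k = queries + query_pe, keys + key_pe
        keys = self.n4(keys + self.image_to_token(k, q, queries)[0])
        return queries, keys

blk = TwoWayBlockSketch()
tokens, image = torch.randn(1, 7, 256), torch.randn(1, 4096, 256)
q_out, k_out = blk(tokens, image, torch.zeros_like(tokens), torch.zeros_like(image))
print(q_out.shape, k_out.shape)   # torch.Size([1, 7, 256]) torch.Size([1, 4096, 256])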
The overall forward pass is:
def forward(
self,
image_embedding: Tensor,
image_pe: Tensor,
point_embedding: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
image_embedding (torch.Tensor): image to attend to. Should be shape
B x embedding_dim x h x w for any h and w.
image_pe (torch.Tensor): the positional encoding to add to the image. Must
have the same shape as image_embedding.
point_embedding (torch.Tensor): the embedding to add to the query points.
Must have shape B x N_points x embedding_dim for any N_points.
Returns:
torch.Tensor: the processed point_embedding
torch.Tensor: the processed image_embedding
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
# image_embedding: Nx256x64x64
bs, c, h, w = image_embedding.shape
# Nx256x4096 ---> Nx4096x256
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
# Nx256x4096 ---> Nx4096x256
image_pe = image_pe.flatten(2).permute(0, 2, 1)
# Prepare queries
queries = point_embedding # Nx(5+x)x256
keys = image_embedding # Nx4096x256
# Apply transformer blocks and final layernorm
# the combined tokens (sparse prompt plus iou/mask tokens) serve as the queries, the image embedding as the keys
# they pass through the 2 TwoWayAttentionBlock layers
for layer in self.layers:
queries, keys = layer(
queries=queries,
keys=keys,
query_pe=point_embedding,
key_pe=image_pe,
)
# Apply the final attention layer from the points to the image
# the queries coming out of the TwoWayAttentionBlocks get the combined tokens added again as queries
q = queries + point_embedding
# the keys coming out of the TwoWayAttentionBlocks get the positional encoding added again as keys
k = keys + image_pe
# final attention layer; note that v is still the keys without the positional encoding
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
# add the attention output to the queries and normalize before returning
queries = queries + attn_out
queries = self.norm_final_attn(queries)
return queries, keys
The TwoWayTransformer returns:
hs: Nx(5+x)x256
src: Nx4096x256
Post-processing then proceeds as follows:
take the iou token and mask token outputs out of hs;
reshape src from Nx4096x256 to Nx256x64x64;
upscale src with 2 transposed convolutions to Nx32x256x256;
feed the 4 mask tokens into 4 separate 3-layer MLPs (one per token), giving hyper_in of size Nx4x32;
matrix-multiply hyper_in with the reshaped src to get an Nx4x256x256 tensor: these are the output masks;
feed the iou token output into a 3-layer MLP to get iou_pred of size Nx4: this is the predicted IoU.
The N in the final output depends on the number of prompts passed in; the corresponding code is the tail of predict_masks shown above.
The output mask has 4 channels:
Channel 0: the single-mask output (unused when multiple masks are requested)
Channel 1: whole
Channel 2: part
Channel 3: subpart
A multimask_output flag therefore controls whether the single mask or all three whole/part/subpart masks are returned.
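A hedged sketch of that channel selection (mirroring the slicing logic in the mask decoder's forward):
import torch

masks = torch.randn(1, 4, 256, 256)      # Nx4xHxW from predict_masks
iou_pred = torch.randn(1, 4)

multimask_output = True
mask_slice = slice(1, None) if multimask_output else slice(0, 1)
print(masks[:, mask_slice].shape)        # torch.Size([1, 3, 256, 256]) when True, [1, 1, 256, 256] when False
print(iou_pred[:, mask_slice].shape)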
The code is in automatic_mask_generator.py:
class SamAutomaticMaskGenerator:
def __init__(
self,
model: Sam, # the SAM model
points_per_side: Optional[int] = 32, # points sampled per side, default 32, for 32x32 points in total
points_per_batch: int = 64, # number of points processed at once, default 64; larger is faster but uses more GPU memory
pred_iou_thresh: float = 0.88, # IoU threshold, default 0.88
stability_score_thresh: float = 0.95,
stability_score_offset: float = 1.0,
box_nms_thresh: float = 0.7,
crop_n_layers: int = 0, # number of crop layers
crop_nms_thresh: float = 0.7,
crop_overlap_ratio: float = 512 / 1500, # overlap ratio between image crops
crop_n_points_downscale_factor: int = 1,
point_grids: Optional[List[np.ndarray]] = None, # list of point grids, directly tied to points_per_side
min_mask_region_area: int = 0,
output_mode: str = "binary_mask",
) -> None:
Default point sampling per image: based on points_per_side (default 32), a 32x32 grid point_grids is generated, giving 1024 coordinates, all normalized to [0, 1].
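A hedged sketch of this grid construction (amg.py has a build_point_grid helper; this reproduces the idea: points at cell-centre offsets, normalized to [0, 1]):
import numpy as np

def build_point_grid_sketch(n_per_side: int) -> np.ndarray:
    offset = 1 / (2 * n_per_side)
    coords = np.linspace(offset, 1 - offset, n_per_side)
    xs, ys = np.meshgrid(coords, coords)
    return np.stack([xs.reshape(-1), ys.reshape(-1)], axis=-1)   # (n^2, 2)

print(build_point_grid_sketch(32).shape)   # (1024, 2), all values in (0, 1)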
crop_n_layers: the number of crop layers; layer n cuts the image into (2^n)^2 crop_img pieces, i.e. 4 crops in the first layer, 16 in the second, and so on.
Step 1: crop the original image; with the default crop_n_layers = 0, the whole image is used as a single region.
Step 2: pad the image, resize it to 1024x1024, and feed it into the image encoder to produce the image embedding.
Step 3: generate 32 evenly spaced positions along each of the width and height, giving 1024 coordinate points.
Step 4: feed 64 points per batch, iterating 1024 / 64 = 16 times, to generate masks and iou_pred.
Step 5: post-process the results:
filter masks by the IoU threshold (default 0.88);
compute the stability score for the remaining masks, stability = (pixels with mask > 1) / (pixels with mask > -1), as sketched after this list;
filter masks by the stability score, threshold default 0.95;
threshold the remaining masks at 0 to get binary masks and compute their bounding boxes;
discard masks whose bounding boxes reach the crop boundary;
map the masks from the cropped image back to the original image size;
convert the masks to RLE to save memory: the mask is flattened, and a pair like (3, 3) means that starting at element 3, the next 3 elements are 1.
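A hedged sketch of the stability score from Step 5 (amg.py computes it from the raw mask logits with threshold 0 and offset 1, which gives the >1 / >-1 ratio described above):
import torch

def stability_score_sketch(mask_logits: torch.Tensor, threshold: float = 0.0, offset: float = 1.0) -> torch.Tensor:
    intersections = (mask_logits > (threshold + offset)).sum(dim=(-1, -2)).float()
    unions = (mask_logits > (threshold - offset)).sum(dim=(-1, -2)).float()
    return intersections / unions

logits = torch.randn(2, 256, 256) * 3
print(stability_score_sketch(logits))   # one score per mask, in [0, 1]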
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
# Generate masks
# generate the masks
mask_data = self._generate_masks(image)
# Filter small disconnected regions and holes in masks
# filter out small regions and holes
if self.min_mask_region_area > 0:
mask_data = self.postprocess_small_regions(
mask_data,
self.min_mask_region_area,
max(self.box_nms_thresh, self.crop_nms_thresh),
)
# Encode masks
# output mask format; the default is a binary mask
if self.output_mode == "coco_rle":
mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
elif self.output_mode == "binary_mask":
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
else:
mask_data["segmentations"] = mask_data["rles"]
# Write mask records
# assemble the results for output
curr_anns = []
for idx in range(len(mask_data["segmentations"])):
ann = {
"segmentation": mask_data"segmentations",
"area": area_from_rle(mask_data"rles"),
"bbox": box_xyxy_to_xywh(mask_data"boxes").tolist(),
"predicted_iou": mask_data"iou_preds".item(),
"point_coords": [mask_data"points".tolist()],
"stability_score": mask_data"stability_score".item(),
"crop_box": box_xyxy_to_xywh(mask_data"crop_boxes").tolist(),
}
curr_anns.append(ann)
return curr_anns
def _generate_masks(self, image: np.ndarray) -> MaskData:
orig_size = image.shape[:2]
# since the default crop_n_layers is 0, the returned crop_boxes cover the whole image and layer_idxs only contains 0
crop_boxes, layer_idxs = generate_crop_boxes(
orig_size, self.crop_n_layers, self.crop_overlap_ratio
)
# Iterate over image crops
# feed each cropped region into the network
data = MaskData()
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
data.cat(crop_data)
# Remove duplicate masks between crops
# gather the results from all crops and filter them with NMS
if len(crop_boxes) > 1:
# Prefer masks from smaller crops
scores = 1 / box_area(data["crop_boxes"])
scores = scores.to(data["boxes"].device)
keep_by_nms = batched_nms(
data["boxes"].float(),
scores,
torch.zeros_like(data["boxes"][:, 0]), # categories
iou_threshold=self.crop_nms_thresh,
)
data.filter(keep_by_nms)
data.to_numpy()
return data
def _process_crop(
self,
image: np.ndarray,
crop_box: List[int],
crop_layer_idx: int,
orig_size: Tuple[int, ...],
) -> MaskData:
# Crop the image and calculate embeddings
# crop the image
x0, y0, x1, y1 = crop_box
cropped_im = image[y0:y1, x0:x1, :]
cropped_im_size = cropped_im.shape[:2]
# convert the cropped image into the 1024x1024 input the encoder expects:
# 1. resize so the longest side is 1024, keeping the aspect ratio
# 2. rearrange HxWxC -> 1xCxHxW
# 3. normalize the image
# 4. zero-pad to 1024x1024
# 5. run the encoder to compute the image_embedding
self.predictor.set_image(cropped_im)
# Get points for this crop
# multiplying the 1024x2 grid by the crop size gives the grid on the cropped image
points_scale = np.array(cropped_im_size)[None, ::-1]
points_for_image = self.point_grids[crop_layer_idx] * points_scale
# Generate masks for this crop in batches
data = MaskData()
# 1024 points are fed per image, points_per_batch (64) at a time, so 1024 / 64 = 16 iterations
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
# 1. map the points onto the 1024x1024 image
# 2. set every point's label to 1; the labels have size 64x1
# 3. run the decoder to compute the masks:
# 3.1 the prompt encoder first: with points only, it returns the sparse embedding and the dense embedding (no_mask_embed)
# 3.2 these go into the mask decoder to produce the output
# 3.3 slice the output: indices 1-3 when multiple masks are requested, index 0 otherwise
# 3.4 post-process the masks: resize back to 1024x1024, drop the padded region, then resize to the original image
# 4. a series of post-processing steps:
# 4.1 filter masks by the IoU threshold (default 0.88)
# 4.2 compute the stability score = (pixels with mask > 1) / (pixels with mask > -1)
# 4.3 filter masks by the stability score, threshold default 0.95
# 4.4 threshold the masks at 0 to get binary masks and compute their bounding boxes
# 4.5 discard masks whose bounding boxes reach the crop boundary
# 4.6 map the crop-level masks back to the original image size
# 4.7 convert the masks to RLE to save memory: the mask is flattened, and (3, 3) means 3 ones starting at element 3
batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
data.cat(batch_data)
del batch_data
self.predictor.reset_image()
# Remove duplicates within this crop.
# NMS on the bounding boxes to remove duplicate results
keep_by_nms = batched_nms(
data["boxes"].float(),
data["iou_preds"],
torch.zeros_like(data["boxes"][:, 0]), # categories
iou_threshold=self.box_nms_thresh,
)
data.filter(keep_by_nms)
# Return to the original image frame
# map the bboxes and points back to original-image coordinates
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
data["points"] = uncrop_points(data["points"], crop_box)
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
return data