当前位置:   article > 正文

【自用】SAM模型论文笔记与复现代码(segment-anything-model)_分割一切代码复现

分割一切代码复现

总模型结构

一个prompt encoder,对提示进行编码,image encoder图像编码,生成embedding, 最后融合2个encoder,再接一个轻量的mask decoder,输出最后的mask。

模型结构示意图:

流程图

模型的结构如上图所示. prompt会经过prompt encoder, 图像会经过image encoder。然后将两部分embedding经过一个轻量化mask decoder得到融合后的特征。encoder部分使用的都是已有模型,decoder使用transformer。

image encoder

利用MAE(Masked AutoEncoder)预训练的ViT模型,对每张图片只处理一次,且在prompt encoder之前进行。输入(c,h,w)的图像,对图像进行缩放,按照长边缩放成1024,短边不够就填充,得到(c,1024,1024)的图像,经过image encoder,得到对图像16倍下采样的feature,大小为(256,64,64)。

prompt encoder

prompt encoder结构图:

分为两类:稀疏与密集

稀疏:
  • point:使用position encodings
  • box:使用position encodings
  • text:使用CLIP作为encoder
密集:
  • mask:使用卷积作为encoder

mask decoder

  • prompt self-attention
  • cross-attention(从prompt到image和从image到prompt)

valid mask(模型输出)

  • 解决混淆的输入: 对于一个prompt,模型会输出3个mask,实际上也可以输出更多的分割结果,3个可以看作一个物体的整体、部分、子部分,基本能满足大多数情况。使用IOU的方式,排序mask。在反向传播时,参与计算的只有loss最小的mask相关的参数.
  • 高效: 这里主要指的是prompt encodermask decoder。在web浏览器上,CPU计算只用约50ms

SAM模型复现

环境:

python 3.8.10
pytorch 1.11.0
cuda 11.3

环境安装

git clone https://github.com/facebookresearch/segment-anything
pip install opencv-python matplotlib
pip install -e .
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth #下载SAM_VIT-H模型
  • 1
  • 2
  • 3
  • 4

定义用于可视化的工具函数

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    
    
def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40

可视化原图片

image = cv2.imread('R.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(14,14))
plt.imshow(image)
plt.axis('on')
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

原图片:

加载SAM模型

import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

predictor.set_image(image)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

点作为prompt

单点
input_point = np.array([[430, 605]])
input_label = np.array([1])
plt.figure(figsize=(14,14))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

使用SAM模型进行分割,并输出模型分割出的3个mask

masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True, #`multimask_output=True`表示是否输出三个mask结果
)

for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(14,14))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()  
  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

多点(使用先前单点输出的mask作为mask prompt)
仅前景点
input_point = np.array([[430, 605],[520, 650]])
input_label = np.array([1, 1]) #1代表前景点(绿色),0代表后景点(红色)

mask_input = logits[np.argmax(scores), :, :]  #选择先前分数最高的mask

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)
plt.figure(figsize=(14,14))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show() 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

前景点+后景点
input_point = np.array([[430, 605],[520, 650], [520,500]])
input_label = np.array([1, 1, 0])  #1代表前景点(绿色),0代表后景点(红色)

mask_input = logits[np.argmax(scores), :, :]  
masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)
plt.figure(figsize=(14,14))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show() 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

矩形框作为prompt

单个矩形框
input_box = np.array([730, 105, 1030, 315])

masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box[None, :],
    multimask_output=False,
)

plt.figure(figsize=(17, 17))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('on')
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

多个矩形框(需要使用transform.apply_boxes_torch方法进行转换)
input_boxes = torch.tensor([
    [730, 105, 1030, 315],
    [970, 155, 1025, 250]
], device=predictor.device)

transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)

plt.figure(figsize=(17, 17))
plt.imshow(image)
for mask in masks:
    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
    show_box(box.cpu().numpy(), plt.gca())
plt.axis('on')
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

自动分割

from segment_anything import SamAutomaticMaskGenerator
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)
print(len(masks))
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('on')
plt.show() 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

输出mask数量
178

开始调参

mask_generator_2 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100, 
)
masks_2 = mask_generator_2.generate(image)
print(len(masks_2))
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks_2)
plt.axis('on')
plt.show() 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

输出mask数量
335

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/677795
推荐阅读
相关标签
  

闽ICP备14008679号