当前位置:   article > 正文

【Python&语义分割】Segment Anything(SAM)模型详细使用教程+代码解释(一)_segment anything 教程

segment anything 教程

1 Segment Anything介绍

1.1 概况

        Meta AI 公司的 Segment Anything 模型是一项革命性的技术,该模型能够根据文本指令或图像识别,实现对任意物体的识别和分割。这一模型的推出,将极大地推动计算机视觉领域的发展,并使得图像分割技术进一步普及化。

        论文地址:https://arxiv.org/abs/2304.02643

        项目地址:Segment Anything

1.2 使用方法

        具体使用方法上,Segment Anything 提供了简单易用的接口,用户只需要通过提示,即可进行物体识别和分割操作。例如在图片处理中,用户可以通过 Hover & Click 或 Box 等方式来选取物体。值得一提的是,SAM 还支持通过上传自己的图片进行物体分割操作,提取物体用时仅需数秒。

        总的来说,Meta AI 的 Segment Anything 模型为我们提供了一种全新的物体识别和分割方式,其强大的泛化能力和广泛的应用前景将极大地推动计算机视觉领域的发展。未来,我们期待看到更多基于 Segment Anything 的创新应用,以及在科学图像分析、照片编辑等领域的广泛应用。

2 代码复现+讲解

2.1 用于生成显示掩膜的函数(初始化)

        里面包含三个封装好的函数,一个是生成掩膜(分割的轮廓)的函数,一个是显示标记点(自己选择需要分割的目标)的函数,一个是显示标记框(需要分割的目标)的函数。

  1. import cv2
  2. import sys
  3. import torch
  4. import numpy as np
  5. from datetime import datetime
  6. import matplotlib.pyplot as plt
  7. from segment_anything import sam_model_registry, SamPredictor
  8. def show_mask(mask, ax, random_color=False):
  9. if random_color:
  10. color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
  11. else:
  12. color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
  13. h, w = mask.shape[-2:]
  14. mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
  15. ax.imshow(mask_image)
  16. def show_points(coords, labels, ax, marker_size=375):
  17. pos_points = coords[labels == 1]
  18. neg_points = coords[labels == 0]
  19. ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
  20. linewidth=1.25)
  21. ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
  22. linewidth=1.25)
  23. def show_box(box, ax):
  24. x0, y0 = box[0], box[1]
  25. w, h = box[2] - box[0], box[3] - box[1]
  26. ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

2.2 模型预加载

        这里包含的代码是打开图片,转换图片格式以及加载模型。注意这里的模型要和你定义的模型类型保持一致(官网给出了三种模型)。模型比较大,我已经将模型以及Segment Anything的包下载至网盘中了,需要的可以在我之前发布的SAM模型安装教程的文章2.2.2小节中下载:【Python&语义分割】Segment Anything(SAM)模型介绍&安装教程

  1. image = cv2.imread(r'B:/truck.jpg') # 读取的图像以NumPy数组的形式存储在变量image中
  2. print("[%s]正在转换图片格式......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
  3. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 将图像从BGR颜色空间转换为RGB颜色空间,还原图片色彩(图像处理库所认同的格式)
  4. print("[%s]正在初始化模型参数......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
  5. # plt.figure(figsize=(10, 10)) # 创建一个新的图形窗口,设置其大小为10x10英寸
  6. # plt.imshow(image) # 使用imshow函数在创建的图形窗口中显示图像
  7. # plt.axis('on') # 开启图像坐标轴,使得图像下的像素坐标可以显示出来
  8. # plt.show() # 显示已经创建的图形窗口和其中的内容
  9. sys.path.append("..") # 将当前路径上一级目录添加到sys.path列表,这里模型使用绝对路径所以这行没啥用
  10. sam_checkpoint = "G:/Neat Download Manager/Misc/sam_vit_b_01ec64.pth" # 定义模型路径
  11. model_type = "vit_b" # 定义模型类型
  12. device = "cuda" # "cpu" or "cuda"
  13. sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
  14. sam.to(device=device) # 定义模型参数
  15. predictor = SamPredictor(sam) # 调用预测模型
  16. predictor.set_image(image)
  17. # 调用`SamPredictor.set_image`来处理图像以产生一个图像嵌入。`SamPredictor`会记住这个嵌入,并将其用于随后的掩码预测

2.3 单点输入mask,分割一个目标

        这里的input_point指你想分割的兴趣点(图片坐标),这里的input_label代表目标的标签,如果你想要分割它input_label的值就为1,如果想排除它则值为0。

  1. # --------------------------------------单点输入--------------------------------------
  2. print("【单点分割阶段】")
  3. print("[%s]正在分割图片......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
  4. input_point = np.array([[250, 187]])
  5. # 单点 prompt 输入格式为(x, y)和并表示出点所带有的标签1(前景点)或0(背景点)。
  6. input_label = np.array([1]) # 点所对应的标签
  7. plt.figure(figsize=(10, 10))
  8. plt.imshow(image)
  9. show_points(input_point, input_label, plt.gca())
  10. plt.axis('on')
  11. plt.show()
  12. masks, scores, logit = predictor.predict(
  13. point_coords=input_point,
  14. point_labels=input_label,
  15. multimask_output=True, # 为False时,它将返回一个掩码
  16. )
  17. # print(masks.shape) # (3, 2160, 3840)波段,高,宽
  18. for i, (mask, score) in enumerate(zip(masks, scores)):
  19. # 三个置信度不同的图
  20. plt.figure(figsize=(10, 10))
  21. plt.imshow(image)
  22. show_mask(mask, plt.gca())
  23. show_points(input_point, input_label, plt.gca())
  24. plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
  25. plt.axis('off')
  26. plt.show()

        单点输入时,会出现三种不同置信度结果的图,可以自己选择。

2.4 多点输入masks,分割一/多个目标

        这里的目标点可以同时输入多个,不同的lable可以控制不同的分割效果。如果label均为1,则将两个点分割成同一目标(单个输入点不明确,需要让模型返回了与其一致的多个对象)。如果label一个为1,一个为0则分割一个,排除一个。下面第一张图是label均为1的效果,第二张图为一个1,一个0的效果。此外还可以将先前迭代的掩码(logits值)提供给模型以帮助预测。

  1. # --------------------------------------多点输入--------------------------------------
  2. print("【多点分割阶段】")
  3. print("[%s]正在分割图片......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
  4. input_point = np.array([[250, 184], [562, 322]])
  5. input_label = np.array([1, 0]) # input_label = np.array([1, 0])负点区域,用来排除该点
  6. mask_input = logit[np.argmax(scores), :, :] # Choose the model's best mask
  7. # 将先前迭代的掩码logit值提供给模型以帮助预测
  8. masks, _, _ = predictor.predict(
  9. point_coords=input_point,
  10. point_labels=input_label,
  11. mask_input=mask_input[None, :, :],
  12. multimask_output=False,
  13. )
  14. # print(masks.shape) # output: (1, 600, 900)
  15. plt.figure(figsize=(10,10))
  16. plt.imshow(image)
  17. show_mask(masks, plt.gca())
  18. show_points(input_point, input_label, plt.gca())
  19. plt.axis('off')
  20. plt.show()

2.5 矩形输入mask,分割一个目标

        SAM支持将xyxy格式(左上和右下角坐标)的矩形作为输入,将框内的主体目标识别出来。

  1. # --------------------------------------矩形输入--------------------------------------
  2. print("【矩形分割阶段】")
  3. print("[%s]正在分割图片......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
  4. input_box = np.array([212, 300, 350, 437])
  5. masks, _, _ = predictor.predict(
  6. point_coords=None,
  7. point_labels=None,
  8. box=input_box[None, :],
  9. multimask_output=False,
  10. )
  11. plt.figure(figsize=(10, 10))
  12. plt.imshow(image)
  13. show_mask(masks[0], plt.gca())
  14. show_box(input_box, plt.gca())
  15. plt.axis('off')
  16. plt.show()

2.6 矩形+点同时输入masks,分割一个目标

        点和矩形可以同时输入,只需定义这两种类型的label即可。在这里,这可以用来只选择卡车的轮胎(将车轴部分设置为负点),而不是整个车轮。需要注意的是矩形的label只能为1(正类)。

  1. # --------------------------------------点&矩形输入--------------------------------------
  2. print("【单点&矩形分割阶段】")
  3. print("[%s]正在分割图片......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
  4. input_box = np.array([215, 310, 350, 430]) # 只能默认框住正类
  5. input_point = np.array([[287, 375]])
  6. input_label = np.array([0]) # 将点设置为负点
  7. masks, _, _ = predictor.predict(
  8. point_coords=input_point,
  9. point_labels=input_label,
  10. box=input_box,
  11. multimask_output=False,
  12. )
  13. plt.figure(figsize=(10, 10))
  14. plt.imshow(image)
  15. show_mask(masks[0], plt.gca())
  16. show_box(input_box, plt.gca())
  17. show_points(input_point, input_label, plt.gca())
  18. plt.axis('off')
  19. plt.show()

2.7 多个矩形输入masks,分割多个目标

        SamPredictor函数可以使用predict_tarch方法对同一图像输入多个提示(点、矩形)。该方法假设输入点已经是tensor张量,且boxes信息与image size相符合(已有来自对象检测器的输出结果)。

        SamPredictor函数(也可以使用segment_anything.utils.transforms)可以将矩形信息编码为特征向量(以实现对多个矩形的支持,transformed_boxes),然后预测masks。

  1. # --------------------------------------多矩形输入--------------------------------------
  2. print("【多矩形分割阶段】")
  3. print("[%s]正在分割图片......" % datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
  4. input_boxes = torch.tensor([
  5. [75, 275, 1725, 850],
  6. [425, 600, 700, 875],
  7. [1375, 550, 1650, 800],
  8. [1240, 675, 1400, 750],
  9. ], device=predictor.device) # 假设为目标检测的预测结果
  10. input_boxes = input_boxes/2
  11. transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
  12. masks, _, _ = predictor.predict_torch(
  13. point_coords=None,
  14. point_labels=None,
  15. boxes=transformed_boxes,
  16. multimask_output=False,
  17. )
  18. # print(masks.shape) # batch_size,num_predicted_masks_per_input,H,W ------>[4, 1, 600, 900]
  19. plt.figure(figsize=(10, 10))
  20. plt.imshow(image)
  21. for mask in masks:
  22. show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
  23. for box in input_boxes:
  24. show_box(box.cpu().numpy(), plt.gca())
  25. plt.axis('off')
  26. plt.show()

3 总结

3.1 不足之处

        以上代码来源于官方的demo,自己修改了一部分。官方的源码只能简单的进行点/矩形输入,每次分割前都需要手动确定目标的图片坐标(x,y)。如果分割做成交互式的会更好,例如我点击图片中的某个点,它就分某个目标。

        另外官方的demo并没有保存图片的函数,如果3S工作者或者其他有需要的领域,可能需要保存分割后的mask就需要自己开发。我这里指的是单独保存mask,掩膜叠加底图显示的保存了也没啥用=。=

3.2 改进

        官方还有一个全局分割的demo我还没有分享,那个我已经加入了保存mask的代码,所以就没跟这篇文章一起分享,后面会分享给大家。此外我还编了一个单点输入mask的交互式操作的代码,后面都会分享给大家。

        总的来说,Segment Anything是真的强,官方的模型不夸张的说真的可以坐到分割万物。我自己拿高分辨率的遥感影像也试了试,建筑、树木、道路都分的还不错。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/92406
推荐阅读
相关标签
  

闽ICP备14008679号