当前位置:   article > 正文

SAM(分割一切模型)的简单调用_昇腾gpu sam分割

昇腾gpu sam分割

        Segment Anything Model (SAM):Meta AI 的一种新 AI 模型,只需单击一下即可“剪切”任何图像中的任何对象.

目录

一、前期准备(安装、搭建相关环境)

二、查看官方对其物体分割

三、实践:鼠标单击、画框对物体分割

四、将分割的内容单独剪切出来

五、分割图像的背景融合

六、目前的SAM衍生的标注工具

        由于本人目前的学习需要,了解了一下,sam相关的简单使用并进行了简单的总结,主要包括:指定点对其物体分割、指定框对其物体分割。

一、前期准备(安装、搭建相关环境)

第一步:进入GitHub下载整个项目:https://github.com/facebookresearch/segment-anything

第二步:在其项目环境下,安装安装 PyTorch 和 TorchVision 依赖项。(在硬件条件允许的条件下,建议安装支持 CUDA 的 PyTorch 和 TorchVision,速度更快),下载 PyTorch 可参考此处

第三步:确保安装了以下各个内容

pip install opencv-python pycocotools matplotlib onnxruntime onnx

第四步:下载模型检查点,可根据自己的需求选择,目前本人选择的是默认的vit_h.

以上便是准备工作的全部内容。正式开始!

二、查看官方对其物体分割

以下为官方提供内容(单点和单框)的整个过程详解,了解全部案例可访问:https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb

必要导入和帮助程序函数

  1. import numpy as np
  2. import torch
  3. import matplotlib.pyplot as plt
  4. import cv2
  5. # 分割相关
  6. def show_mask(mask, ax, random_color=False):
  7. if random_color:
  8. color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
  9. else:
  10. color = np.array([30/255, 144/255, 255/255, 0.6])
  11. h, w = mask.shape[-2:]
  12. mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
  13. ax.imshow(mask_image)
  14. # 仅显示点相关
  15. def show_points(coords, labels, ax, marker_size=375):
  16. pos_points = coords[labels==1]
  17. neg_points = coords[labels==0]
  18. ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
  19. ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
  20. #仅显示框相关
  21. def show_box(box, ax):
  22. x0, y0 = box[0], box[1]
  23. w, h = box[2] - box[0], box[3] - box[1]
  24. ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
  25. #加载待处理图片
  26. image = cv2.imread('图片路径')
  27. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

首先,加载 SAM 模型和预测变量。更改以下路径以指向 SAM 检查点。

  1. import sys
  2. from segment_anything import sam_model_registry, SamPredictor
  3. sam_checkpoint = "sam_vit_h_4b8939.pth"
  4. model_type = "vit_h"
  5. device = "cuda"#如果使用gpu
  6. sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
  7. sam.to(device=device)
  8. predictor = SamPredictor(sam)

通过调用 SamPredictor.set_image 处理图像以生成图像嵌入。SamPredictor 会记住这种嵌入,并将其用于后续的掩码预测。

predictor.set_image(image)

在其上选择一个点。点以 (x,y) 格式输入到模型中,并带有标签 1(前景点)或 0(背景点)。可输入多个点。

  1. # 单个点时
  2. input_point = np.array([[500,375]]) # 为要分割的指定点
  3. input_label = np.array([1]) # 为分割对象的性质(背景|前景)
  4. #为单个框
  5. input_box = np.array([425, 600, 700, 875])
  1. #只是为了显示点,与分割无关
  2. plt.figure(figsize=(10,10))
  3. plt.imshow(image)
  4. show_points(input_point, input_label, plt.gca())
  5. plt.axis('on')
  6. plt.show()

使用 SamPredictor.predict 进行预测。该模型返回掩码、这些掩码的质量预测以及可传递给下一次预测迭代的低分辨率掩码日志。

  1. #点
  2. masks, scores, logits = predictor.predict(
  3. point_coords=input_point,
  4. point_labels=input_label,
  5. multimask_output=True,
  6. )
  7. # 框
  8. masks, _, _ = predictor.predict(
  9. point_coords=None,
  10. point_labels=None,
  11. box=input_box[None, :],
  12. multimask_output=False,
  13. )

(使用 multimask_output=True(默认设置),SAM 输出 3 个掩码,其中分数给出了模型自己对这些掩码质量的估计。此设置适用于不明确的输入提示,并帮助模型消除与提示一致的不同对象的歧义。如果为 False,它将返回单个掩码。对于单点等模棱两可的提示,即使只需要一个掩码,也建议使用 multimask_output=True;可以通过选择分数最高的一个来选择最佳的单个掩模。这通常会产生更好的掩码。)——————so,后续本人使用的一直是False.

  1. for i, (mask, score) in enumerate(zip(masks, scores)):
  2. plt.figure(figsize=(10,10))
  3. plt.imshow(image)
  4. show_mask(mask, plt.gca())
  5. show_points(input_point, input_label, plt.gca())
  6. show_box(input_box, plt.gca())
  7. plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
  8. plt.axis('off')
  9. plt.show()

三、实践:鼠标单击、画框对物体分割

        本人使用opencv来完成鼠标的监听:cv2.setMouseCallback(),点击鼠标后的回调函数.

  1. cv2.setMouseCallback(windowName, onMouse [, userdata])
  2. 相关参数说明如下:
  3. windowName:窗口的名字
  4. onMouse:鼠标响应函数,回调函数
  5. userdata:传给回调函数的参数

        本人使用的回调函数是: on_mouse,以on_mouse为例介绍回调函数.

  1. on_mouse(event, x, y, flags, param)
  2. 相关参数说明如下:
  3. event 是 CV_EVENT_* 变量之一
  4. x 和 y 是鼠标在图像坐标系的坐标(不是窗口坐标系)
  5. flags 是 CV_EVENT_FLAG 的组合
  6. param 是用户定义的传递到 setMouseCallback 函数调用的参数

                以下是event、flags的可选值:

         

        单击画框实践代码:

  1. import cv2
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from segment_anything import SamPredictor, sam_model_registry
  5. sam = sam_model_registry["vit_h"](checkpoint='sam_vit_h_4b8939.pth')
  6. # sam.to(device="cuda") # 使用gpu
  7. predictor = SamPredictor(sam)
  8. def show_mask(mask, ax, random_color=False):
  9. # 掩膜部分
  10. if random_color:
  11. color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
  12. else:
  13. color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
  14. h, w = mask.shape[-2:]
  15. mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
  16. ax.imshow(mask_image)
  17. def on_mouse(event,x,y,flags,param):
  18. global img, point1
  19. img3 = img.copy()
  20. if event == cv2.EVENT_LBUTTONDOWN:#左键点击
  21. input_point = np.array([[x, y]])
  22. input_label = np.array([1])
  23. img3 = cv2.cvtColor(img3, cv2.COLOR_BGR2RGB)
  24. predictor.set_image(img3)
  25. masks, scores, logits = predictor.predict(
  26. point_coords=input_point,
  27. point_labels=input_label,
  28. multimask_output=False,
  29. )
  30. # 保存
  31. plt.imshow(img3)
  32. show_mask(masks, plt.gca())
  33. plt.savefig("1.png")
  34. if __name__ == '__main__':
  35. path = "图片路径"
  36. img = cv2.imread(path)
  37. cv2.namedWindow('image')
  38. cv2.setMouseCallback('image',on_mouse)
  39. cv2.imshow('image', img)
  40. cv2.waitKey(0)
  41. cv2.destroyAllWindows()

        画框实践代码:

  1. import cv2
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from segment_anything import SamPredictor, sam_model_registry
  5. sam = sam_model_registry["vit_h"](checkpoint='sam_vit_h_4b8939.pth')
  6. # sam.to(device="cuda") # gpu
  7. predictor = SamPredictor(sam)
  8. def show_mask(mask, ax, random_color=False):
  9. # 掩膜部分
  10. if random_color:
  11. color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
  12. else:
  13. color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
  14. h, w = mask.shape[-2:]
  15. mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
  16. ax.imshow(mask_image)
  17. def on_mouse(event,x,y,flags,param):
  18. global img, point1, point2
  19. img2 = img.copy()
  20. img3 = img.copy()
  21. if event == cv2.EVENT_LBUTTONDOWN: # 左键点击
  22. point1 = (x, y)
  23. cv2.circle(img2, point1, 10, (0, 255, 0), 5)
  24. cv2.imshow('image', img2)
  25. elif event == cv2.EVENT_MOUSEMOVE and (flags & cv2.EVENT_FLAG_LBUTTON): # 移动鼠标,左键拖拽
  26. cv2.rectangle(img2, point1, (x, y), (255, 0, 0), 15) # 需要确定的就是矩形的两个点(左上角与右下角),颜色红色,线的类型(不设置就默认)。
  27. cv2.imshow('image', img2)
  28. elif event == cv2.EVENT_LBUTTONUP: # 左键释放
  29. point2 = (x, y)
  30. cv2.rectangle(img2, point1, point2, (0, 0, 255), 5) # 需要确定的就是矩形的两个点(左上角与右下角),颜色蓝色,线的类型(不设置就默认)。
  31. cv2.imshow('image', img2)
  32. min_x = min(point1[0], point2[0])
  33. min_y = min(point1[1], point2[1])
  34. width = abs(point1[0] - point2[0])
  35. height = abs(point1[1] - point2[1])
  36. roi = [min_x, min_y, min_x + width, min_y + height]
  37. roi = np.array(roi).astype(dtype=int).tolist()
  38. image = cv2.cvtColor(img3, cv2.COLOR_BGR2RGB)
  39. predictor.set_image(image)
  40. # 选区
  41. input_box = np.array(roi)
  42. masks, _, _ = predictor.predict(
  43. point_coords=None,
  44. point_labels=None,
  45. box=input_box[None, :],
  46. multimask_output=False,
  47. )
  48. # 保存
  49. plt.imshow(img3)
  50. show_mask(masks, plt.gca())
  51. plt.savefig("1.png")
  52. if __name__ == '__main__':
  53. path = "图片路径"
  54. img = cv2.imread(path)
  55. cv2.namedWindow('image')
  56. cv2.setMouseCallback('image',on_mouse)
  57. cv2.imshow('image', img)
  58. cv2.waitKey(0)
  59. cv2.destroyAllWindows()

 注:若想将图片以cv的方式处理,需要进行以下的转换

  1. # 处理show_mask内结果
  2. mask_image *= 250.0
  3. mask2 = mask_image.astype(np.uint8)

四、将分割的内容单独剪切出来

本人最终需要将上述分割图片进行粘贴,以此进行了图像的处理以便于进行图像融合

1、灰度、二值图像

  1. image_gray = cv2.cvtColor(mask2, cv2.COLOR_BGR2GRAY) # 以灰度方法读取图像
  2. ret1, th1 = cv2.threshold(image_gray, 0, 255, cv2.THRESH_OTSU) # 方法选择为THRESH_OTSU

2、将黑色非背景区域使用掩膜去除

  1. # 按位与运算
  2. pic = cv2.bitwise_and(img3, img3, mask=th1)

3、按其掩膜区域进行剪裁

streamline_pic = pic[roi[1]:roi[3], roi[0]:roi[2]]

五、分割图像的背景融合

 图像融合,本人使用cv2.seamlessClone:

  1. cv2.seamlessClone(src, dst, mask, center, flags)
  2. 相关参数说明如下:
  3. src:目标影像
  4. dst:背景图像
  5. mask:目标影像上的mask,表示目标影像上那些区域是感兴趣区域。
  6. center:目标影像的中心在背景图像上的坐标!注意是目标影像的中心!
  7. flags:选择融合的方式:
  8. NORMAL_CLONE(不保留dst 图像的texture细节)
  9. MIXED_CLONE(保留dest图像的texture 细节)
  10. MONOCHROME_TRANSFER( 不保留src图像的颜色细节,只有src图像的质地,颜色和目标图像一样,可以用来进行皮肤质地填充)

具体使用:

  1. img = cv2.imread(path_back) # 大图(背景)
  2. image = cv2.imread(path_stick) # 小图(前景)
  3. h_img, w_img = img.shape[:2]
  4. h_image, w_image = image.shape[:2]
  5. mask = 255 * np.ones(image.shape, image.dtype)
  6. w = random.randrange(int(0.5 * w_image), int(w_img - 0.5 * w_image - 1))
  7. h = random.randrange(int(0.5 * h_image), int(h_img - 0.5 * h_image - 1))
  8. center = (w, h)
  9. normal_clone = cv2.seamlessClone(image, img, mask, center, cv2.MIXED_CLONE)
  10. cv2.imwrite("2.jpg", normal_clone)

六、目前的SAM衍生的标注工具

可参考以下文章完成:https://blog.csdn.net/weixin_45977690/article/details/130088039文章浏览阅读1.5w次,点赞51次,收藏238次。详细教程,使用Segment Anything(SAM)模型当工具进行自己数据的自动标注。https://blog.csdn.net/weixin_45977690/article/details/130088039

利用SAM实现自动标注-CSDN博客文章浏览阅读943次。(2)检测图像的文件(可调整后面的图片高/宽):python helpers/generate_onnx.py --checkpoint-path sam_vit_h_4b8939.pth --onnx-model-path ./sam_onnx.onnx --orig-im-size 720 1280。/segment-anything/dataset/ -a …(3)运行完会有对应的sam_onnx.onnx文件,将其移到SAM工具主文件夹中:cp sam_onnx.onnx …/SAM-Tool/https://blog.csdn.net/qq_37249793/article/details/131956211本人是自己实践编写了自己的工具,未实践上述方法,以上便是全部内容,感谢您可以看到这里,手动鞠躬~

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号