赞
踩
注:作者使用的是下面的第一种安装方法。
# Option 1: install segment-anything directly from GitHub with pip.
pip install git+https://github.com/facebookresearch/segment-anything.git
# Option 2: clone the repository and install it in editable mode.
# NOTE: the git@github.com: URL clones over SSH and needs a GitHub SSH key;
# https://github.com/facebookresearch/segment-anything.git works without one.
git clone git@github.com:facebookresearch/segment-anything.git
cd segment-anything; pip install -e .
链接:https://pan.baidu.com/s/1T-OPGiWVdM18uedR9V8imQ?pwd=0vhi
提取码:0vhi
# Extra packages: opencv-python provides cv2 and matplotlib provides pyplot, both
# used by the code below; pycocotools, onnxruntime and onnx are not imported by it.
pip install opencv-python pycocotools matplotlib onnxruntime onnx
注:点击下面的链接下载对应的模型权重,作者使用的是第一个(vit_h);也可以从作者的网盘下载。
- `default` 或 `vit_h`:ViT-H SAM 模型。
- `vit_l`:ViT-L SAM 模型。
- `vit_b`:ViT-B SAM 模型。
链接:https://pan.baidu.com/s/1Wp4_3eWJ9jhNnLeE3awpSA?pwd=5c1o
提取码:5c1o
注:需要根据自己的实际情况修改 `sam_checkpoint` 和 `image = cv2.imread('images/truck.jpg')` 中的路径。
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2


def show_mask(mask, ax, random_color=False):
    """Overlay a mask on ``ax`` as a semi-transparent colour layer.

    Args:
        mask: Array whose last two dimensions are (H, W). It is reshaped to
            (H, W, 1), so any leading dimensions must have size 1.
        ax: Matplotlib axes to draw on.
        random_color: If True, use a random RGB colour; otherwise a fixed
            blue. The alpha is 0.6 in both cases.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    # Broadcast the (H, W, 1) mask against the (1, 1, 4) RGBA colour to get an
    # (H, W, 4) image that is fully transparent wherever the mask is 0.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax`` as star markers.

    Args:
        coords: (N, 2) array of (x, y) pixel coordinates.
        labels: (N,) array of point labels. Points labelled 1 are drawn as
            positive (green) points, points labelled 0 as negative (red) points.
        ax: Matplotlib axes to draw on.
        marker_size: Scatter marker size.
    """
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)


def show_box(box, ax):
    """Draw the outline of an (x0, y0, x1, y1) box on ``ax`` in green."""
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    # facecolor with alpha 0 leaves the box interior unfilled.
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green',
                               facecolor=(0, 0, 0, 0), lw=2))
# Load the example image. cv2 reads images in BGR channel order, so convert to
# RGB before displaying it with matplotlib.
image = cv2.imread('images/truck.jpg')
if image is None:
    # cv2.imread does not raise for a missing/unreadable file; it returns None,
    # which would otherwise fail later with an opaque cvtColor error.
    raise FileNotFoundError("could not read image: images/truck.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(10, 10))
plt.imshow(image)
plt.axis('on')
plt.show()
import sys

import torch

# Allow importing segment_anything from the parent directory (e.g. when running
# from inside a source checkout without installing it).
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

# Adjust the checkpoint path to where the weights were downloaded; model_type
# should correspond to that checkpoint (vit_h / vit_l / vit_b).
sam_checkpoint = "models/sam_vit_h_4b8939.pth"
model_type = "vit_h"
# Fall back to CPU when CUDA is unavailable instead of failing in sam.to().
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)
# Register the image with the predictor; the predict() calls below do not pass
# the image again.
predictor.set_image(image)
# A single positive prompt point (label 1) at pixel (x=500, y=375).
input_point = np.array([[500, 375]])
input_label = np.array([1])
# Visualize the prompt point on the image before running prediction.
plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.axis('on')
plt.show()
# Predict masks for the point prompt. With multimask_output=True the predictor
# returns several candidate masks plus a quality score for each; logits are
# kept so the best one can be fed back as mask_input in later calls.
masks, scores, logits = predictor.predict(
point_coords=input_point,
point_labels=input_label,
multimask_output=True,
)
当 `multimask_output=True`(默认设置)时,SAM 输出 3 个掩码,其中 `scores` 给出了模型自己对这些掩码质量的估计。该设置适用于有歧义的输入提示,可以帮助模型区分与同一提示相符的不同对象。当其为 `False` 时,模型只返回单个掩码。对于有歧义的提示(例如单个点),即使只需要一个掩码,也建议使用 `multimask_output=True`,然后从 `scores` 中选出得分最高的掩码,这样通常能得到更好的掩码。
masks.shape # (3, 1200, 1800)
# Show each candidate mask, with the prompt point and its predicted score.
# (The loop body had lost its indentation, which is an IndentationError.)
for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()
# Use multimask_output=False to request a single mask. Two positive points
# describe the object, and the best-scoring logits from the earlier multi-mask
# prediction are passed back in as mask_input.
input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 1])

mask_input = logits[np.argmax(scores), :, :]  # choose the model's best mask

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],  # add a leading dimension of size 1
    multimask_output=False,
)
masks.shape

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
# Same as above, but the second point is a negative prompt (label 0). scores
# and logits still come from the first multi-mask prediction.
input_point = np.array([[500, 375], [1125, 625]])
input_label = np.array([1, 0])

mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
# Box prompt in (x0, y0, x1, y1) pixel coordinates, with no point prompts.
input_box = np.array([425, 600, 700, 875])
masks, _, _ = predictor.predict(
point_coords=None,
point_labels=None,
box=input_box[None, :],  # add a leading dimension: shape (1, 4)
multimask_output=False,
)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())  # multimask_output=False -> a single mask
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()
# Combine the box prompt with a negative point (label 0) that lies inside it.
input_box = np.array([425, 600, 700, 875])
input_point = np.array([[575, 750]])
input_label = np.array([0])

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=input_box,
    multimask_output=False,
)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
# SamPredictor can take several input prompts for the same image through its
# predict_torch method. That method assumes the prompts are already torch
# tensors that have already been transformed to the model's input frame.
# For example, suppose we have several box outputs from an object detector.
input_boxes = torch.tensor([
    [75, 275, 1725, 850],
    [425, 600, 700, 875],
    [1375, 550, 1650, 800],
    [1240, 675, 1400, 750],
], device=predictor.device)
# SamPredictor stores the transform it needs in its `transform` field for easy
# access; use it to map the boxes into the model's input frame.
transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)
masks.shape  # torch.Size([4, 1, 1200, 1800])

plt.figure(figsize=(10, 10))
plt.imshow(image)
# Masks and boxes are torch tensors here; move them to CPU numpy for plotting.
for mask in masks:
    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
    show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()
# Two images, each with box prompts, for running the SAM model directly on a batch.
image1 = image  # truck.jpg from above
image1_boxes = torch.tensor([
    [75, 275, 1725, 850],
    [425, 600, 700, 875],
    [1375, 550, 1650, 800],
    [1240, 675, 1400, 750],
], device=sam.device)

image2 = cv2.imread('images/groceries.jpg')
if image2 is None:
    # cv2.imread returns None instead of raising when the file can't be read.
    raise FileNotFoundError("could not read image: images/groceries.jpg")
image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
image2_boxes = torch.tensor([
    [450, 170, 520, 350],
    [350, 190, 450, 350],
    [500, 170, 580, 350],
    [580, 170, 640, 350],
], device=sam.device)
图像和提示都以 PyTorch 张量的形式输入,并且这些张量已经变换到正确的坐标帧。输入被打包成一个列表,每个元素对应一张图像,是一个可以包含以下键的字典:
- `image`:CHW 格式的 PyTorch 张量形式的输入图像。
- `original_size`:图像在变换为 SAM 输入之前的原始大小,格式为 (H, W)。
- `point_coords`:批量的点提示坐标。
- `point_labels`:批量点提示对应的标签。
- `boxes`:批量的输入框。
- `mask_inputs`:批量的输入掩码。
如果某种提示不存在,可以省略对应的键。
from segment_anything.utils.transforms import ResizeLongestSide

# Transform used to rescale images and boxes to the image encoder's input size
# (sam.image_encoder.img_size).
resize_transform = ResizeLongestSide(sam.image_encoder.img_size)


def prepare_image(image, transform, device):
    """Apply ``transform`` to an HWC image and return it as a CHW tensor.

    Args:
        image: HWC image array (here, an RGB image loaded with cv2).
        transform: Object with an ``apply_image`` method, e.g. ResizeLongestSide.
        device: Target device. Either a torch device / device string, or any
            object exposing a ``.device`` attribute, such as the SAM model
            itself (which is what the calls below pass).

    Returns:
        A contiguous CHW tensor on the target device.
    """
    image = transform.apply_image(image)
    # Accept the model (has ``.device``) as well as a plain device or string.
    target_device = getattr(device, "device", device)
    image = torch.as_tensor(image, device=target_device)
    return image.permute(2, 0, 1).contiguous()


# One dict per image; boxes are mapped with the same transform as the image.
batched_input = [
    {
        'image': prepare_image(image1, resize_transform, sam),
        'boxes': resize_transform.apply_boxes_torch(image1_boxes, image1.shape[:2]),
        'original_size': image1.shape[:2],
    },
    {
        'image': prepare_image(image2, resize_transform, sam),
        'boxes': resize_transform.apply_boxes_torch(image2_boxes, image2.shape[:2]),
        'original_size': image2.shape[:2],
    },
]
# Run the SAM model directly on the batch; it returns one output dict per image.
batched_output = sam(batched_input, multimask_output=False)
batched_output[0].keys()

fig, ax = plt.subplots(1, 2, figsize=(20, 20))
# Draw each image with its predicted masks (random colours) and its input
# boxes. Masks and boxes are torch tensors, so move them to CPU numpy first.
for axis, img, output, boxes in zip(
    ax,
    (image1, image2),
    batched_output,
    (image1_boxes, image2_boxes),
):
    axis.imshow(img)
    for mask in output['masks']:
        show_mask(mask.cpu().numpy(), axis, random_color=True)
    for box in boxes:
        show_box(box.cpu().numpy(), axis)
    axis.axis('off')

plt.tight_layout()
plt.show()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。