```bash
# Create a virtual environment
conda create -n sam python=3.8
# Activate the environment
conda activate sam
# Clone the code
git clone git@github.com:facebookresearch/segment-anything.git
# Install
cd segment-anything; pip install -e .
# Install commonly used dependencies
pip install torch torchvision opencv-python pycocotools matplotlib onnxruntime onnx
```
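After installing, a quick sanity check can confirm that PyTorch, CUDA, and the `segment_anything` package are all visible from the new environment. This is a minimal sketch; a GPU is assumed only because the examples below load the model with `device="cuda"`.

```python
# Minimal environment check: the SAM examples below assume a working CUDA GPU.
import torch
import segment_anything

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("segment-anything installed at:", segment_anything.__file__)
```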
Download a model checkpoint and place it in the `models` folder; this example uses ViT-H.
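If you prefer to fetch the checkpoint from Python rather than a browser, a sketch like the one below works; the URL is the ViT-H checkpoint link published in the segment-anything README, and the `./models` directory is simply this article's convention.

```python
# Download the ViT-H SAM checkpoint into ./models (a multi-gigabyte file).
from pathlib import Path
from urllib.request import urlretrieve

url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
target = Path("./models") / "sam_vit_h_4b8939.pth"
target.parent.mkdir(parents=True, exist_ok=True)
if not target.exists():
    urlretrieve(url, target)
```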
SAM's prompts can be points, boxes, text, or masks; the released code demonstrates point, box, and mask prompts, while text prompting is only explored in the paper.
The input image is 'onepiece.jpg'; the segmentation result is shown in the figure below.

Code:
```python
# coding=utf-8
import numpy as np
import matplotlib.pyplot as plt
import cv2
from pathlib import Path

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor


def show_anns(anns):
    """Overlay all generated masks on the current axes with random colors."""
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    img = np.ones((sorted_anns[0]['segmentation'].shape[0],
                   sorted_anns[0]['segmentation'].shape[1], 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)


def process_img(img_path):
    """Load an image with OpenCV and convert it from BGR to RGB."""
    image = cv2.imread(img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image


def entire_img(img_path):
    """Generate masks for the whole image and save the visualization."""
    image = process_img(img_path)
    sam = sam_model_registry["vit_h"](checkpoint="./models/sam_vit_h_4b8939.pth")
    sam.to(device="cuda")
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)
    plt.figure(figsize=(20, 20))
    plt.imshow(image)
    show_anns(masks)
    plt.axis('off')
    plt.savefig(str(Path(img_path).name))
    # predictor = SamPredictor(sam)


def main():
    img_path = './notebooks/images/onepiece.jpg'
    entire_img(img_path)


if __name__ == "__main__":
    main()
```
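Each element that `SamAutomaticMaskGenerator.generate` returns is a dictionary describing one mask, with fields such as `segmentation`, `area`, `bbox`, `predicted_iou`, and `stability_score` (as documented in the segment-anything repo). The helper below is a small sketch for inspecting the largest masks; it assumes `masks` was produced as in `entire_img` above.

```python
def summarize_masks(masks, top_k=5):
    """Print area, XYWH bounding box and predicted IoU for the largest masks."""
    for ann in sorted(masks, key=lambda x: x["area"], reverse=True)[:top_k]:
        print(ann["area"], ann["bbox"], round(ann["predicted_iou"], 3))

# Example (inside entire_img, after masks = mask_generator.generate(image)):
# summarize_masks(masks)
```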
For the point prompt, the green star is placed at [1064, 1205]; for the box prompt, the box coordinates are [1305, 244, 2143, 1466].

The full code is below; feel free to try it yourself.
```python
# coding=utf-8
import numpy as np
import matplotlib.pyplot as plt
import cv2
from pathlib import Path

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor


def show_anns(anns):
    """Overlay all generated masks on the current axes with random colors."""
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    img = np.ones((sorted_anns[0]['segmentation'].shape[0],
                   sorted_anns[0]['segmentation'].shape[1], 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)


def show_mask(mask, ax, random_color=False):
    """Draw a single mask on the given axes."""
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    """Draw positive (green) and negative (red) point prompts."""
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)


def show_box(box, ax):
    """Draw a box prompt given as [x0, y0, x1, y1]."""
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green',
                               facecolor=(0, 0, 0, 0), lw=2))


def process_img(img_path):
    """Load an image with OpenCV and convert it from BGR to RGB."""
    image = cv2.imread(img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image


def entire_img(img_path):
    """Generate masks for the whole image and save the visualization."""
    image = process_img(img_path)
    sam = sam_model_registry["vit_h"](checkpoint="./models/sam_vit_h_4b8939.pth")
    sam.to(device="cuda")
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image)
    plt.figure(figsize=(20, 20))
    plt.imshow(image)
    show_anns(masks)
    plt.axis('off')
    plt.savefig(str(Path(img_path).name))


def predict(img_path, type='point'):
    """Predict a mask from a point or box prompt and save the visualization."""
    image = process_img(img_path)
    sam = sam_model_registry["vit_h"](checkpoint="./models/sam_vit_h_4b8939.pth")
    sam.to(device="cuda")
    predictor = SamPredictor(sam)
    predictor.set_image(image)
    if type == 'point':
        # [X, Y]
        input_point = np.array([[1064, 1205]])
        input_label = np.array([1])
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )
    elif type == 'bbox':
        input_box = np.array([1305, 244, 2143, 1466])
        masks, scores, logits = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_box[None, :],
            multimask_output=False,
        )
    index = np.argmax(scores)
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(masks[index], plt.gca())
    if type == 'point':
        show_points(input_point, input_label, plt.gca())
    elif type == 'bbox':
        show_box(input_box, plt.gca())
    plt.title(f"Score: {scores[index]:.3f}", fontsize=18)
    plt.savefig(str(Path(img_path).stem) + f'{scores[index]:.3f}.png')


def main():
    img_path = './notebooks/images/onepiece.jpg'
    # entire_img(img_path)
    predict(img_path, type='bbox')
    # predict(img_path)


if __name__ == "__main__":
    main()
```
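The prompt types listed earlier include masks, which the code above does not demonstrate. Following the pattern used in the official predictor notebook (this snippet is a sketch, not part of the article's original script), the low-resolution logits from a first point prediction can be fed back through `mask_input` to refine the result:

```python
# Refine a point prediction by feeding the best low-res logits back as a mask prompt.
# Assumes `predictor`, `input_point`, and `input_label` are set up as in predict() above.
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)
mask_input = logits[np.argmax(scores), :, :]  # low-resolution (256x256) mask logits

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)
```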