当前位置:   article > 正文

Segment Anything实现_seg anything

seg anything

1、创造虚拟环境

conda create -n seganything python=3.8
  • 1

2、安装torch>=1.7,torchvision>=0.8,torch网址:https://pytorch.org/get-started/previous-versions/根据自己的cuda版本选择对应的下载命令

conda install pytorch==1.7.0 torchvision==0.8.0 torchaudio==0.7.0 cudatoolkit=9.2 -c pytorch
  • 1

3、验证是否安装成功,返回True,表示安装成功

import torch
print(torch.cuda.is_available())
  • 1
  • 2

4、下载segment anything 工程并安装

git clone git@github.com:facebookresearch/segment-anything.git
cd segment-anything
pip install -e .
  • 1
  • 2
  • 3

5、安装其他依赖包

pip install opencv-python pycocotools matplotlib onnxruntime onnx
  • 1

6、下载权重
7、可提示的seg_anything,对此图片进行可提示分割
输入影像:
在这里插入图片描述
实现代码:

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
from segment_anything import sam_model_registry, SamPredictor
#显示提示points
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   

#显示mask
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
#显示box
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) 


#输入影像
image = cv2.imread('..../1.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
#输入提示
input_point = np.array([[1299, 815]])
input_label = np.array([1])
plt.figure(figsize=(10,10))
plt.imshow(image)
show_points(input_point, input_label, plt.gca())
plt.savefig('.../1_promt.jpg')

#加载模型
sam_checkpoint = "./checkpoint/sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

predictor.set_image(image)

masks, scores, logits = predictor.predict(point_coords=input_point,point_labels=input_label,multimask_output=True,)
print(masks.shape)

for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.savefig('.../'+str(i)+'.jpg')

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60

结果如下所示:
如下图所示,绿色的星型标志为point的位置
在这里插入图片描述
下图为mask1的结果:
在这里插入图片描述
下图为mask2的结果:
在这里插入图片描述
下图为mask3的结果:
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/482684
推荐阅读
相关标签
  

闽ICP备14008679号