赞
踩
SAM(Segment Anything Model)是由 Meta 的研究人员团队创建和训练的深度学习模型。在 Segment everything 研究论文中,SAM 被称为“基础模型”。
基础模型是在大量数据上训练的机器学习模型(通常通过自监督或半监督学习),其目的是在更具体的任务上使用和重新训练。SAM 是一个预训练模型,旨在适应其他任务(特别是通过微调)。
下载安装SAMhttps://github.com/facebookresearch/segment-anything
安装 Segment Anything:
pip install git+https://github.com/facebookresearch/segment-anything.git
或在本地克隆存储库并使用
- git clone git@github.com:facebookresearch/segment-anything.git
- cd segment-anything; pip install -e .
Github页面里点击下载一个或者多个模型:
模型文件放到项目的目录即可。
H,L,B分别表示huge,large,base,从大到小。根据硬件能力选择合适的模型。
下列依次:ViT-H SAM模型(vit_h),ViT-L SAM 模型(vit_1), ViT-B SAM 模型(vit_b)
https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
方法一:使用官方命令
建立input,output文件夹
在input中存放待分割的图片,output用作存放输出的mask。
在这我使用的是vit_h
python3 scripts/amg.py --checkpoint ./sam_vit_h_4b8939.pth --model-type default --input ./input.jpeg --output output
官方命令即执行amg.py文件,并传入了一些参数,当传入参数固定时可以直接写在amg.py文件中。
方法二:
- # coding=gb2312
- from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry
- import cv2
- import numpy as np
- import torch
- import matplotlib.pyplot as plt
- device = "cuda"
- sam = sam_model_registry["default"](checkpoint="你下载的权重的位置")
- #sam_vit_h_4b8939.pth 是预训练的默认权重,需要单独下载
- sam.to(device=device)
- mask_generator = SamAutomaticMaskGenerator(sam)
-
- def show_anns(anns):
- if len(anns) == 0:
- return
- sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
- ax = plt.gca()
- ax.set_autoscale_on(False)
-
- img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
- img[:,:,3] = 0
- for ann in sorted_anns:
- m = ann['segmentation']
- color_mask = np.concatenate([np.random.random(3), [0.35]])
- img[m] = color_mask
- ax.imshow(img)
-
- image = cv2.imread('图片 位置.jpeg')
- masks = mask_generator.generate(image)
- plt.figure(figsize=(20,20))
- plt.imshow(image)
- show_anns(masks)
- plt.axis('off')
- plt.show()
参考:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。