当前位置:   article > 正文

SAM+使用SAM应用数据集完成分割_sam微调 数据集

sam微调 数据集

什么是SAM

        SAM(Segment Anything Model)是由 Meta 的研究人员团队创建和训练的深度学习模型。在 Segment everything 研究论文中,SAM 被称为“基础模型”。

        基础模型是在大量数据上训练的机器学习模型(通常通过自监督或半监督学习),其目的是在更具体的任务上使用和重新训练。SAM 是一个预训练模型,旨在适应其他任务(特别是通过微调)。

sam安装

下载安装SAMhttps://github.com/facebookresearch/segment-anything

安装 Segment Anything:

pip install git+https://github.com/facebookresearch/segment-anything.git

或在本地克隆存储库并使用

  1. git clone git@github.com:facebookresearch/segment-anything.git
  2. cd segment-anything; pip install -e .

Github页面里点击下载一个或者多个模型:

模型文件放到项目的目录即可。

H,L,B分别表示huge,large,base,从大到小。根据硬件能力选择合适的模型。

下列依次ViT-H SAM模型(vit_h),ViT-L SAM 模型(vit_1), ViT-B SAM 模型(vit_b)

​​​​https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth 

https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

 使用

 方法一:使用官方命令

建立input,output文件夹

在input中存放待分割的图片,output用作存放输出的mask。

在这我使用的是vit_h

python3 scripts/amg.py --checkpoint ./sam_vit_h_4b8939.pth --model-type default --input ./input.jpeg --output output

官方命令即执行amg.py文件,并传入了一些参数,当传入参数固定时可以直接写在amg.py文件中。

方法二:

  1. # coding=gb2312
  2. from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry
  3. import cv2
  4. import numpy as np
  5. import torch
  6. import matplotlib.pyplot as plt
  7. device = "cuda"
  8. sam = sam_model_registry["default"](checkpoint="你下载的权重的位置")
  9. #sam_vit_h_4b8939.pth 是预训练的默认权重,需要单独下载
  10. sam.to(device=device)
  11. mask_generator = SamAutomaticMaskGenerator(sam)
  12. def show_anns(anns):
  13. if len(anns) == 0:
  14. return
  15. sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
  16. ax = plt.gca()
  17. ax.set_autoscale_on(False)
  18. img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
  19. img[:,:,3] = 0
  20. for ann in sorted_anns:
  21. m = ann['segmentation']
  22. color_mask = np.concatenate([np.random.random(3), [0.35]])
  23. img[m] = color_mask
  24. ax.imshow(img)
  25. image = cv2.imread('图片 位置.jpeg')
  26. masks = mask_generator.generate(image)
  27. plt.figure(figsize=(20,20))
  28. plt.imshow(image)
  29. show_anns(masks)
  30. plt.axis('off')
  31. plt.show()

 

参考: 

https://zhuanlan.zhihu.com/p/627535252

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/298282
推荐阅读
相关标签
  

闽ICP备14008679号