当前位置:   article > 正文

如何利用SAM(segment-anything)制作自己的分割数据集_segment-anything训练自己的数据

segment-anything训练自己的数据

1. 环境搭建

        1.1 GitHub 地址：https://github.com/facebookresearch/segment-anything（The repository provides code for running inference with the SegmentAnything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.）

        1.2 步骤        

        

该代码需要 python>=3.8，以及 pytorch>=1.7 和 torchvision>=0.8。请按照 PyTorch 官网（https://pytorch.org/get-started/locally/）的说明安装 PyTorch 和 TorchVision 依赖项。强烈建议安装支持 CUDA 的 PyTorch 和 TorchVision。

安装 Segment Anything:

pip install git+https://github.com/facebookresearch/segment-anything.git

或在本地克隆存储库并安装

  1. git clone git@github.com:facebookresearch/segment-anything.git
  2. cd segment-anything; pip install -e .

以下可选依赖项对于掩模后处理、以 COCO 格式保存掩模、示例笔记本以及以 ONNX 格式导出模型是必需的。 运行示例笔记本还需要 jupyter。

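pip install opencv-python pycocotools matplotlib onnxruntime onnx

注意：第 2 节的代码实际加载的是 MobileSAM（`mobile_sam` 包，模型类型 `vit_t`，权重 `../weights/mobile_sam.pt`），它不包含在上面安装的 segment-anything 中，需要另外从 https://github.com/ChaoningZhang/MobileSAM 获取；第 4 节的验证代码还导入了 `skimage`，需要额外执行 `pip install scikit-image`。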

 2. 制作蒙版

相关代码如下

  1. import numpy as np
  2. import torch
  3. import matplotlib.pyplot as plt
  4. import cv2
  # Overlay the SAM masks on the current matplotlib axes.
  def show_anns(anns, start=10, stop=200, alpha=0.35):
      """Draw masks from *anns* as randomly coloured, semi-transparent overlays.

      The masks are sorted by their 'area' value, largest first, and only
      ``sorted_anns[start:stop]`` are drawn. The defaults keep the original
      behaviour: the 10 largest masks are skipped (presumably to hide large
      background regions) and at most 190 masks are drawn.

      Args:
          anns: list of mask records; each needs an 'area' value and a boolean
              'segmentation' array whose first two dims are (height, width).
          start: index of the first area-sorted mask to draw.
          stop: index one past the last area-sorted mask to draw.
          alpha: opacity of every mask overlay.

      Returns:
          None. Draws onto ``plt.gca()``; does nothing for an empty list.
      """
      if not anns:
          return
      sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
      ax = plt.gca()
      ax.set_autoscale_on(False)
      height, width = sorted_anns[0]['segmentation'].shape[:2]
      # RGBA canvas: white but fully transparent until a mask paints over it.
      overlay = np.ones((height, width, 4))
      overlay[:, :, 3] = 0
      for ann in sorted_anns[start:stop]:
          overlay[ann['segmentation']] = np.concatenate([np.random.random(3), [alpha]])
      ax.imshow(overlay)
  import sys

  # Load the test image. OpenCV loads BGR; convert to RGB before display and
  # mask generation.
  image = cv2.imread('./images/05.jpg')
  if image is None:
      # cv2.imread returns None instead of raising on a missing/unreadable file.
      raise FileNotFoundError("cannot read './images/05.jpg'")
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  print(image.shape)
  plt.figure(figsize=(20, 20))
  plt.imshow(image)
  plt.axis('off')
  plt.show()

  # NOTE: this loads MobileSAM (`mobile_sam` package, model type "vit_t"), not
  # the `segment_anything` package from the install section. The package and
  # the weights are looked up relative to the parent directory, so this is
  # presumably meant to run from a sub-folder of a MobileSAM checkout.
  sys.path.append("..")
  from mobile_sam import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

  sam_checkpoint = "../weights/mobile_sam.pt"
  model_type = "vit_t"
  device = "cuda" if torch.cuda.is_available() else "cpu"

  sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
  sam.to(device=device)
  sam.eval()

  # Generate masks for the whole image with the automatic mask generator.
  mask_generator = SamAutomaticMaskGenerator(sam)
  masks = mask_generator.generate(image)

  plt.figure(figsize=(20, 20))
  plt.imshow(image)
  show_anns(masks)
  plt.axis('off')
  plt.show()
  # Save the union of all masks as one binary image.
  def save_mask(anns, path='res.jpg'):
      """Merge every mask in *anns* into one binary image and optionally save it.

      Every pixel covered by at least one mask becomes 255; all others stay 0.
      Because every mask writes the same value, the order of *anns* is
      irrelevant and no sorting is needed.

      Args:
          anns: list of mask records; each needs a boolean 'segmentation'
              array whose first two dims are (height, width).
          path: output file for ``cv2.imwrite``; pass None to skip writing.
              NOTE: JPEG is lossy and blurs the 0/255 boundary; a '.png' path
              keeps the mask exact. The default stays 'res.jpg' because the
              next section reads that file.

      Returns:
          (height, width, 1) uint8 array of the merged mask, or None if *anns*
          is empty.
      """
      if not anns:
          return None
      height, width = anns[0]['segmentation'].shape[:2]
      # Explicit uint8 so imwrite does not have to down-convert a float image.
      merged = np.zeros((height, width, 1), dtype=np.uint8)
      for ann in anns:
          merged[ann['segmentation']] = 255
      if path is not None:
          cv2.imwrite(path, merged)
      return merged
  # Write the union of all SAM masks to res.jpg. The union does not depend on
  # mask order, so the masks are passed as-is without pre-sorting.
  save_mask(masks)

3. 制作 COCO 格式数据集（可用于语义分割、目标检测和实例分割）

以下代码接在第 2 节的蒙版代码之后运行，会直接使用其中生成的 `masks` 以及已导入的 `plt`。

  # Extract the outline of every mask and draw all outlines on one canvas.
  import cv2
  import numpy as np

  image = cv2.imread('./images/05.jpg')
  if image is None:
      raise FileNotFoundError("cannot read './images/05.jpg'")
  # Elliptical 5x5 kernel for a morphological opening (removes small specks).
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
  img_ = np.zeros_like(image)  # black canvas, same size as the photo
  # `mask_show` is not defined anywhere in this article, so build one 0/255
  # uint8 image per SAM mask directly from `masks` (section 2). Using 0/255
  # rather than 0/1 keeps the intensity step strong enough for the 50/150
  # Canny thresholds below.
  gray_images = [np.where(ann['segmentation'], 255, 0).astype(np.uint8) for ann in masks]
  for gray_image in gray_images:
      gray_image = cv2.morphologyEx(gray_image, cv2.MORPH_OPEN, kernel)
      edges = cv2.Canny(gray_image, 50, 150)
      contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
      cv2.drawContours(img_, contours, -1, (255, 255, 255), 2)
  cv2.imwrite("counte2.png", img_)
  # Mask minus edges: carve the drawn outlines out of the merged mask
  # (presumably so that touching objects end up as separate blobs).
  image1 = cv2.imread('res.jpg', cv2.IMREAD_GRAYSCALE)
  image2 = cv2.imread('counte2.png', cv2.IMREAD_GRAYSCALE)
  if image1 is None or image2 is None:
      raise FileNotFoundError("cannot read 'res.jpg' or 'counte2.png'")
  # cv2.subtract saturates at 0 for uint8 instead of wrapping around.
  img = cv2.subtract(image1, image2)
  # Elliptical 7x7 opening to clean up thin residue left after the subtraction.
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
  dst2 = cv2.morphologyEx(img, cv2.MORPH_OPEN, kernel)
  cv2.imwrite("res3.jpg", dst2)
  plt.figure(figsize=(20, 20))
  plt.imshow(dst2, cmap='gray')
  plt.axis('off')
  plt.show()
  # Store the cleaned mask as a COCO-format annotation file.
  import json

  orig_img = cv2.imread('./images/05.jpg')
  image = cv2.imread('res3.jpg')
  if orig_img is None or image is None:
      raise FileNotFoundError("cannot read './images/05.jpg' or 'res3.jpg'")
  # Each outer contour of the cleaned mask becomes one annotation.
  edges = cv2.Canny(image, 50, 150)
  contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

  images = [
      {
          'file_name': '05.jpg',
          'height': int(orig_img.shape[0]),
          'width': int(orig_img.shape[1]),
          'id': 1,
      },
  ]
  categories = [
      {
          'id': 1,
          'name': 'qituan',
      },
  ]

  annotations = []
  for contour in contours:
      # A COCO polygon needs at least 3 points. pycocotools' frPyObjects
      # interprets a 4-value polygon as a bbox and rejects a 2-value one, so
      # degenerate contours would break or silently corrupt annToMask.
      if len(contour) < 3:
          continue
      x, y, w, h = cv2.boundingRect(contour)
      annotations.append({
          # contour is (N, 1, 2); flatten to [x1, y1, x2, y2, ...].
          'segmentation': [contour.reshape(-1).tolist()],
          'bbox': [x, y, w, h],
          'area': cv2.contourArea(contour),
          'iscrowd': 0,
          'image_id': 1,
          'category_id': 1,
          'id': len(annotations) + 1,  # 1-based and consecutive
      })

  coco_data = {
      'images': images,
      'annotations': annotations,
      'categories': categories,
  }
  print(coco_data)
  output_file_path = 'coco_data.json'
  # Serialize the data and write it to a JSON file.
  with open(output_file_path, 'w', encoding='utf-8') as f:
      json.dump(coco_data, f, indent=4)

4. 验证COCO格式数据

  1. import cv2
  2. import random
  3. import json, os
  4. from pycocotools.coco import COCO
  5. from skimage import io
  6. from matplotlib import pyplot as plt
  7. import numpy as np
  train_json = 'coco_data.json'
  train_path = './images/'
  coco = COCO(train_json)

  list_imgIds = coco.getImgIds(catIds=1)
  if not list_imgIds:
      raise ValueError(f'no image in {train_json} has an annotation of category 1')
  img = coco.loadImgs(list_imgIds[0])[0]
  img_path = os.path.join(train_path, img['file_name'])
  # Read the image once; each visualisation below draws on its own copy.
  image = cv2.imread(img_path)
  if image is None:
      raise FileNotFoundError(img_path)
  img_annIds = coco.getAnnIds(imgIds=img['id'], catIds=1, iscrowd=None)
  anns = coco.loadAnns(img_annIds)

  # 20x20-inch figures because the author's images are 2000x2000 px.
  plt.rcParams['figure.figsize'] = (20.0, 20.0)

  # Segmentation: fill every polygon with a random colour.
  img1 = image.copy()
  for ann in anns:
      # Flat [x1, y1, x2, y2, ...] -> (N, 1, 2) int32, the layout of cv2 contours.
      contour_restored = np.asarray(ann['segmentation'][0], dtype=np.int32).reshape(-1, 1, 2)
      color = tuple(np.random.randint(0, 255, 3).tolist())
      cv2.drawContours(img1, [contour_restored], -1, color, thickness=cv2.FILLED)
  # OpenCV images are BGR while matplotlib expects RGB.
  plt.imshow(cv2.cvtColor(img1, cv2.COLOR_BGR2RGB))
  plt.show()

  # Object detection: draw every [x, y, w, h] bbox in yellow (BGR 0,255,255).
  # Defined before the loop so an image without annotations still displays.
  image1 = image.copy()
  for ann in anns:
      x, y, w, h = ann['bbox']
      cv2.rectangle(image1, (int(x), int(y)), (int(x + w), int(y + h)), (0, 255, 255), 2)
  plt.imshow(cv2.cvtColor(image1, cv2.COLOR_BGR2RGB))
  plt.show()

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号