当前位置:   article > 正文

Segment Anything Model (SAM)本地部署,及应用于自己的数据集完成分割_sam部署pycharm

sam部署pycharm

1.什么是segment-anything?

Segment Anything Model (SAM): a new AI model from Meta AI that can "cut out" any object, in any image, with a single click.

官方网站:Segment Anything | Meta AI

官网提供了一些demo,交互式地分割展示,可以试一下

官网demo提供了3种分割方式。

Hover&Click(鼠标悬浮及点击):

点击想要分割的内容,自动完成分割。左键正选,右键反选。

Box

鼠标圈一个方框,自动分割。

Everything

一键分割所有。

2.Github项目

GitHub网址:https://github.com/facebookresearch/segment-anything

 readme很详细,可按照步骤来搞。

克隆远程仓库到本地:

git clone https://github.com/facebookresearch/segment-anything

安装Segment Anything :

cd segment-anything; pip install -e .

特别注意:-e 后面有一个点(.),表示安装当前目录下的项目,不要漏掉。

Github页面里点击下载一个或者多个模型:

模型文件放到项目的目录即可。

H,L,B分别表示huge,large,base,从大到小。根据硬件能力选择合适的模型。

创建虚拟环境及安装相应的包

3.使用自己的数据集进行分割

法1:使用官方命令

如下,建立input,output文件夹

在input中存放待分割的图片,output用作存放输出的mask。

  1. 在pycharm终端中使用下面的命令:
  2. 使用base模型
  3. python scripts/amg.py --checkpoint sam_vit_b_01ec64.pth --model-type vit_b --input C:\GithubFile\segment-anything\input --output C:\GithubFile\segment-anything\output
  4. 使用huge模型
  5. python scripts/amg.py --checkpoint sam_vit_h_4b8939.pth --model-type vit_h --input C:\GithubFile\segment-anything\input --output C:\GithubFile\segment-anything\output

 官方命令即执行amg.py文件,并传入了一些参数,当传入参数固定时可以直接写在amg.py文件中。

加入default=你的路径,并且将required改为False, 其他各项类似。help中会有一些辅助说明。

输入:

结果:

法2:使用第三方脚本

代码参考:segment-anything本地部署使用_yunteng521的博客-CSDN博客

  1. import cv2
  2. import os
  3. import numpy as np
  4. from segment_anything import sam_model_registry, SamPredictor
  5. input_dir = 'input'
  6. output_dir = 'output'
  7. crop_mode = True # 是否裁剪到最小范围
  8. # alpha_channel是否保留透明通道
  9. print('最好是每加一个点就按w键predict一次')
  10. os.makedirs(output_dir, exist_ok=True)
  11. image_files = [f for f in os.listdir(input_dir) if
  12. f.lower().endswith(('.png', '.jpg', '.jpeg', '.JPG', '.JPEG', '.PNG'))]
  13. # sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
  14. sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
  15. # 添加使用的模型,用A就把B注释掉
  16. _ = sam.to(device="cuda") # 注释掉这一行,会用cpu运行,速度会慢很多
  17. predictor = SamPredictor(sam) # SAM预测图像
  18. def mouse_click(event, x, y, flags, param): # 鼠标点击事件
  19. global input_point, input_label, input_stop # 全局变量,输入点,
  20. if not input_stop: # 判定标志是否停止输入响应了!
  21. if event == cv2.EVENT_LBUTTONDOWN: # 鼠标左键
  22. input_point.append([x, y])
  23. input_label.append(1) # 1表示前景点
  24. elif event == cv2.EVENT_RBUTTONDOWN: # 鼠标右键
  25. input_point.append([x, y])
  26. input_label.append(0) # 0表示背景点
  27. else:
  28. if event == cv2.EVENT_LBUTTONDOWN or event == cv2.EVENT_RBUTTONDOWN: # 提示添加不了
  29. print('此时不能添加点,按w退出mask选择模式')
  30. def apply_mask(image, mask, alpha_channel=True): # 应用并且响应mask
  31. if alpha_channel:
  32. alpha = np.zeros_like(image[..., 0]) # 制作掩体
  33. alpha[mask == 1] = 255 # 兴趣地方标记为1,且为白色
  34. image = cv2.merge((image[..., 0], image[..., 1], image[..., 2], alpha)) # 融合图像
  35. else:
  36. image = np.where(mask[..., None] == 1, image, 0)
  37. return image
  38. def apply_color_mask(image, mask, color, color_dark=0.5): # 对掩体进行赋予颜色
  39. for c in range(3):
  40. image[:, :, c] = np.where(mask == 1, image[:, :, c] * (1 - color_dark) + color_dark * color[c], image[:, :, c])
  41. return image
  42. def get_next_filename(base_path, filename): # 进行下一个图像
  43. name, ext = os.path.splitext(filename)
  44. for i in range(1, 3):
  45. new_name = f"{name}_{i}{ext}"
  46. if not os.path.exists(os.path.join(base_path, new_name)):
  47. return new_name
  48. return None
  49. def save_masked_image(image, mask, output_dir, filename, crop_mode_): # 保存掩盖部分的图像(感兴趣的图像)
  50. if crop_mode_:
  51. y, x = np.where(mask)
  52. y_min, y_max, x_min, x_max = y.min(), y.max(), x.min(), x.max()
  53. cropped_mask = mask[y_min:y_max + 1, x_min:x_max + 1]
  54. cropped_image = image[y_min:y_max + 1, x_min:x_max + 1]
  55. masked_image = apply_mask(cropped_image, cropped_mask)
  56. else:
  57. masked_image = apply_mask(image, mask)
  58. filename = filename[:filename.rfind('.')] + '.png'
  59. new_filename = get_next_filename(output_dir, filename)
  60. if new_filename:
  61. if masked_image.shape[-1] == 4:
  62. cv2.imwrite(os.path.join(output_dir, new_filename), masked_image, [cv2.IMWRITE_PNG_COMPRESSION, 9])
  63. else:
  64. cv2.imwrite(os.path.join(output_dir, new_filename), masked_image)
  65. print(f"Saved as {new_filename}")
  66. else:
  67. print("Could not save the image. Too many variations exist.")
# Interactive annotation loop. The outer loop visits one image at a time; the
# first inner loop collects prompt points and dispatches key presses; the
# innermost loop lets the user browse the masks returned by the predictor.
# NOTE(review): image_files[current_index] raises IndexError when the input
# folder contains no matching image.
current_index = 0
cv2.namedWindow("image")
cv2.setMouseCallback("image", mouse_click)
input_point = []
input_label = []
input_stop = False
while True:
    filename = image_files[current_index]
    # NOTE(review): cv2.imread presumably returns None for an unreadable file,
    # in which case the .copy() below would fail - confirm and guard if needed.
    image_orign = cv2.imread(os.path.join(input_dir, filename))
    image_crop = image_orign.copy()  # Copy of the original that is passed to save_masked_image.
    image = cv2.cvtColor(image_orign.copy(), cv2.COLOR_BGR2RGB)  # RGB copy handed to the predictor.
    selected_mask = None
    logit_input = None
    while True:
        # print(input_point)
        input_stop = False  # Let mouse_click accept points again.
        image_display = image_orign.copy()
        display_info = f'{filename} | Press s to save | Press w to predict | Press d to next image | Press a to previous image | Press space to clear | Press q to remove last point '
        cv2.putText(image_display, display_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
        for point, label in zip(input_point, input_label):  # Draw every prompt point with its label.
            color = (0, 255, 0) if label == 1 else (0, 0, 255)
            cv2.circle(image_display, tuple(point), 5, color, -1)
        if selected_mask is not None:
            # A new random colour is drawn on every pass of this loop.
            color = tuple(np.random.randint(0, 256, 3).tolist())
            # apply_color_mask tints image_display in place; selected_image is
            # not read again in this loop.
            selected_image = apply_color_mask(image_display, selected_mask, color)
        cv2.imshow("image", image_display)
        key = cv2.waitKey(1)
        if key == ord(" "):
            # Space: drop all points and the current mask.
            input_point = []
            input_label = []
            selected_mask = None
            logit_input = None
        elif key == ord("w"):
            input_stop = True  # mouse_click rejects new points while this is set.
            if len(input_point) > 0 and len(input_label) > 0:
                # Run the prediction.
                # NOTE(review): set_image is called on every 'w' press, even
                # when the image has not changed.
                predictor.set_image(image)  # Set the input image.
                input_point_np = np.array(input_point)  # Prompt points, converted to an array for predict.
                input_label_np = np.array(input_label)  # Labels of the prompt points.
                # Pass the prompts to the predictor; it returns the masks.
                masks, scores, logits = predictor.predict(
                    point_coords=input_point_np,
                    point_labels=input_label_np,
                    mask_input=logit_input[None, :, :] if logit_input is not None else None,
                    multimask_output=True,
                )
                mask_idx = 0
                num_masks = len(masks)  # Number of masks returned.
                while (1):
                    color = tuple(np.random.randint(0, 256, 3).tolist())  # Random overlay colour, redrawn every pass.
                    image_select = image_orign.copy()
                    selected_mask = masks[mask_idx]  # Mask currently shown; 'a'/'d' switch between masks.
                    selected_image = apply_color_mask(image_select, selected_mask, color)
                    mask_info = f'Total: {num_masks} | Current: {mask_idx} | Score: {scores[mask_idx]:.2f} | Press w to confirm | Press d to next mask | Press a to previous mask | Press q to remove last point | Press s to save'
                    cv2.putText(selected_image, mask_info, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2,
                                cv2.LINE_AA)
                    # Show the overlay in the window.
                    cv2.imshow("image", selected_image)
                    key = cv2.waitKey(10)
                    if key == ord('q') and len(input_point) > 0:
                        # Removes the last point only; the masks shown here are
                        # not predicted again.
                        input_point.pop(-1)
                        input_label.pop(-1)
                    elif key == ord('s'):
                        save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
                    elif key == ord('a'):
                        # Previous mask, wrapping around to the last one.
                        if mask_idx > 0:
                            mask_idx -= 1
                        else:
                            mask_idx = num_masks - 1
                    elif key == ord('d'):
                        # Next mask, wrapping around to the first one.
                        if mask_idx < num_masks - 1:
                            mask_idx += 1
                        else:
                            mask_idx = 0
                    elif key == ord('w'):
                        break
                    elif key == ord(" "):
                        input_point = []
                        input_label = []
                        selected_mask = None
                        logit_input = None
                        break
                # NOTE(review): this also runs after the space-key branch above,
                # so the logit_input = None reset there is overwritten here.
                logit_input = logits[mask_idx, :, :]
                print('max score:', np.argmax(scores), ' select:', mask_idx)
        elif key == ord('a'):
            # Previous image, clamped at the first one; the points are cleared.
            current_index = max(0, current_index - 1)
            input_point = []
            input_label = []
            break
        elif key == ord('d'):
            # Next image, clamped at the last one; the points are cleared.
            current_index = min(len(image_files) - 1, current_index + 1)
            input_point = []
            input_label = []
            break
        elif key == 27:
            # Key code 27 (Esc): leave the per-image loop; the check below
            # then ends the outer loop as well.
            break
        elif key == ord('q') and len(input_point) > 0:
            input_point.pop(-1)
            input_label.pop(-1)
        elif key == ord('s') and selected_mask is not None:
            save_masked_image(image_crop, selected_mask, output_dir, filename, crop_mode_=crop_mode)
    if key == 27:
        break

使用样例:

先用鼠标点击进行标注(左键添加前景点,右键添加背景点),

后点击w预测,

然后点击s保存,

点击d下一张。

每张图片可以标记不止一个mask

个人感觉,这个脚本不太好用。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/92357
推荐阅读
相关标签
  

闽ICP备14008679号