当前位置:   article > 正文

计算机视觉 | YOLO 和 SAM 强强联合能干什么大事

yolo sam

点击下方卡片,关注“小白玩转Python”公众号

在这篇博客中,我们将探索计算机视觉和图像分析的迷人领域,探讨两种开创性模型之间的动态协同:YOLO(You Only Look Once)和 SAM(Segment Anything Model)。YOLO 因其在目标检测方面的革命性进展而备受赞誉,与在分割领域具有强大实力的 SAM 相结合,承诺带来令人兴奋的能力融合。

那么,什么是 SAM(Segment Anything Model)?:

SAM 由 Meta 在 2023 年推出,是一种革命性的图像分割模型。以其卓越的性能而著称,SAM 已巩固了其作为最先进的分割模型之一的地位。SAM 在图像分割技术方面代表了一项突破性的进展,提供了前所未有的精确性和多功能性。与传统的受特定对象类型或环境限制的分割模型不同,SAM 凭借先进的神经网络架构和在大型数据集上进行的广泛训练,能够以无与伦比的准确性分割图像中的几乎任何对象。

架构:

如论文中的图表所示,SAM 的架构采用多阶段的图像分割方法。其核心是一系列互联的神经网络模块,每个模块都针对分割过程的不同方面进行处理。

f058ea5fc4c78520ec5867305589a25a.png

架构的初始阶段涉及特征提取,输入图像通过卷积层处理以提取相关特征。这些特征随后通过一系列编码和解码层传递,在提取高层语义信息的同时保留空间细节。SAM 的关键创新在于其注意力机制,使模型在分割过程中能够有选择地关注图像中的相关区域。这种注意力机制通过一组注意力模块实现,基于上下文线索和特征重要性动态调整模型的关注点。

此外,SAM 还引入了跳跃连接,以促进不同网络层之间的信息流动。这些连接使模型能够利用低层和高层特征,增强其捕捉复杂细节和上下文的能力。总体而言,SAM 的架构经过精心设计,优化了分割过程,利用注意力机制和跳跃连接等先进的神经网络技术,实现了精确且多功能的分割结果。这种复杂的设计使 SAM 在广泛的应用中表现出色,从医学成像到自动驾驶,确立了其在计算机视觉领域的开创性地位。

SAM 的关键特性:

  • 多功能性:SAM 设计用于分割图像中的任何事物,从日常物品到复杂场景,具备出色的准确性和细节。

  • 鲁棒性:得益于其复杂的架构和在多样化数据集上的广泛训练,SAM 在各种场景中表现出色,包括不同的光照条件和物体方向。

  • 规模:SAM 能够处理不同分辨率和规模的图像,适用于高分辨率图像和实时应用。

  • 上下文理解:SAM 整合了上下文信息,以提高分割精度,甚至在杂乱场景中也能有效地区分对象与其环境。

  • 效率:尽管具备先进功能,SAM 仍保持高效率,确保快速处理速度,非常适合实时应用。

  • 适应性:SAM 可以为特定任务或数据集进行微调和定制,允许无缝集成到各种应用和行业中。

论文:https://arxiv.org/pdf/2304.02643.pdf

现在,让我们深入探讨如何将 YOLO 与 SAM 嵌入在一起。但是,为什么我们需要将这两个模型结合在一起?

将 YOLO(You Only Look Once)与 SAM(Segment Anything Model)结合起来,提供了一个强大的协同效应,增强了两个模型的能力。 YOLO 在快速识别图像中的对象方面表现出色,而 SAM 在高精度分割对象方面具有优势。通过将 YOLO 与 SAM 嵌入在一起,我们可以利用这两个模型的优势,实现更全面和准确的图像分析。这种集成不仅可以检测对象,还可以精确地描绘它们的边界,为下游任务提供更丰富的上下文信息。此外,将 YOLO 与 SAM 嵌入在一起,可以更稳健和高效地处理复杂的视觉数据,在自动驾驶、医学成像和监控系统等应用中具有不可估量的价值。

实施 SAM 处理图像:

步骤1:首先从 GitHub 仓库下载 SAM 模型。

  1. import os
  2. HOME = os.getcwd()
  3. pip install roboflow ultralytics 'git+https://github.com/facebookresearch/segment-anything.git'

步骤2:安装 SAM 模型的权重,可以从 SAM 的 GitHub 仓库获取。

  1. %cd {HOME}/weights
  2. !wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

步骤3:验证是否已成功下载 SAM 权重文件。

  1. import os
  2. CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
  3. print(CHECKPOINT_PATH, "; exist:", os.path.isfile(CHECKPOINT_PATH))

步骤4:加载模型。

  1. import torch
  2. from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
  3. DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  4. MODEL_TYPE = "vit_h"
  5. sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)

步骤5:初始化掩码生成器。

mask_generator = SamAutomaticMaskGenerator(sam)

步骤6:为图像生成掩码。sam_result 变量包含生成的掩码。

  1. import cv2
  2. import supervision as sv
  3. image_bgr = cv2.imread("path/to/image")
  4. image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
  5. sam_result = mask_generator.generate(image_rgb)

掩码生成返回一个包含多个掩码的列表,每个掩码都是一个包含各种数据的字典。这些键包括:

  • Segmentation:掩码

  • area:掩码的像素面积

  • bbox:掩码的边界框,格式为 XYWH

  • predicted_iou:模型对掩码质量的自我预测

  • point_coords:生成该掩码的采样输入点

  • stability_score:掩码质量的附加度量

  • crop_box:用于生成该掩码的图像裁剪框,格式为 XYWH

  1. print(len(masks))
  2. print(sam_results[0].keys())

步骤7:结果

  1. mask_annotator = sv.MaskAnnotator(color_lookup = sv.ColorLookup.INDEX)
  2. detections = sv.Detections.from_sam(sam_result=sam_result)
  3. annotated_image = mask_annotator.annotate(scene=image_bgr.copy(), detections=detections)
  4. sv.plot_images_grid(
  5. images=[image_bgr, annotated_image],
  6. grid_size=(1, 2),
  7. titles=['source image', 'segmented image']
  8. )

c6f5cf0211f9981502c4225d14f016a7.png

现在,实施 YOLO+SAM 处理视频:

步骤1:下载 YOLO、SAM 权重和其他依赖项。

  1. from ultralytics import YOLO
  2. from IPython.display import display, Image
  3. model = YOLO(MODEL)
  4. model.fuse()
  1. %cd {HOME}
  2. import sys
  3. !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'
  1. !mkdir {HOME}/weights
  2. %cd {HOME}/weights
  3. !wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

步骤2:确保正确安装了权重,并用掩码生成器初始化 SAM。

  1. import os
  2. CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
  3. print(CHECKPOINT_PATH, "; exist:", os.path.isfile(CHECKPOINT_PATH))
  1. import torch
  2. import matplotlib.pyplot as plt
  3. from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
  4. DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  5. sam = sam_model_registry["vit_h"](checkpoint=CHECKPOINT_PATH).to(device=DEVICE)
  6. mask_predictor = SamPredictor(sam)

步骤3:嵌入 YOLO 检测和 SAM 掩码。

  1. CLASS_NAMES_DICT = model.model.names
  2. # class_ids of interest - based on the number of classses
  3. CLASS_ID = [item for item in range(0,len(CLASS_NAMES_DICT))]
  4. CLASS_NAMES_DICT
  1. import cv2
  2. import numpy as np
  3. import torch
  4. # Replace the following line with your actual VIDEO_PATH
  5. VIDEO_PATH = "/path/to/input_video"
  6. OUTPUT_VIDEO_PATH = "/path/to/save/output_video"
  7. # This will contain the resulting mask predictions for local use
  8. mask_frames = []
  9. def get_video_dimensions(cap):
  10. width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  11. height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  12. return width, height
  13. def add_color_to_mask(mask, color):
  14. # Convert the color tensor to CPU
  15. color = torch.tensor(color).cpu().numpy()
  16. # Create a binary mask based on the original mask
  17. color_mask = np.zeros_like(mask.cpu().numpy(), dtype=np.uint8)
  18. color_mask[mask.cpu().numpy() > 0] = 1 # Set non-zero values to 1
  19. # Expand the color tensor and apply it to the binary mask
  20. colored_mask = color_mask[..., None] * color
  21. return colored_mask
  22. def draw_class_names(frame, class_names, positions, color, font_size=0.5):
  23. for class_name, position in zip(class_names, positions):
  24. cv2.putText(frame, class_name, position, cv2.FONT_HERSHEY_SIMPLEX, font_size, color, 2, cv2.LINE_AA)
  25. def draw_yolov8_boxes(frame, boxes, color):
  26. for box in boxes:
  27. box = list(map(int, box))
  28. cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), color, 2)
  29. constant_mask_color = np.array([0, 0, 255], dtype=np.uint8) # Red color for masks
  30. output_class_color = (0, 255, 0) # Green color for class names
  31. yolov8_box_color = (255, 0, 0) # Blue color for YOLOv8 bounding boxes
  32. cap = cv2.VideoCapture(VIDEO_PATH)
  33. width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  34. height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  35. fourcc = cv2.VideoWriter_fourcc(*'XVID')
  36. output_video = cv2.VideoWriter(OUTPUT_VIDEO_PATH, fourcc, 15.0, (width, height))
  37. frame_num = 1
  38. while cap.isOpened():
  39. ret, frame = cap.read()
  40. if not ret:
  41. break
  42. # Check if the frame is empty or None
  43. if frame is None:
  44. continue # Skip processing for empty frames
  45. # Run frame through YOLOv8 to get detections
  46. detections = model.predict(frame, conf=0.7)
  47. # Check if there are fish detections
  48. if len(detections[0].boxes) == 0:
  49. continue # Skip processing for frames without fish detections
  50. # Run frame and detections through SAM to get masks
  51. transformed_boxes = mask_predictor.transform.apply_boxes_torch(
  52. detections[0].boxes.xyxy, list(get_video_dimensions(cap))
  53. )
  54. mask_predictor.set_image(frame)
  55. masks, _, _ = mask_predictor.predict_torch(
  56. boxes=transformed_boxes,
  57. multimask_output=False,
  58. point_coords=None,
  59. point_labels=None
  60. )
  61. # Check if the mask is empty
  62. if masks[0][0].numel() == 0:
  63. continue # Skip processing for empty masks
  64. # Combine mask predictions into a single mask, each with the same color
  65. class_ids = detections[0].boxes.cpu().cls
  66. merged_with_colors = add_color_to_mask(masks[0][0], constant_mask_color)
  67. for i in range(1, len(masks)):
  68. curr_mask_with_colors = add_color_to_mask(masks[i][0], constant_mask_color)
  69. merged_with_colors = np.bitwise_or(merged_with_colors, curr_mask_with_colors)
  70. # Draw YOLOv8 bounding boxes on the frame
  71. draw_yolov8_boxes(frame, detections[0].boxes.xyxy, yolov8_box_color)
  72. # Draw class names on the frame with a slightly larger font
  73. class_names = [CLASS_NAMES_DICT[int(class_id)] for class_id in class_ids]
  74. draw_class_names(frame, class_names, [(int(box[0]), int(box[1])) for box in detections[0].boxes.xyxy], output_class_color, font_size=0.7)
  75. # Overlay the SAM masks onto the frame
  76. frame_with_masks = cv2.addWeighted(frame, 1, merged_with_colors, 0.5, 0)
  77. # Write the frame with masks, YOLOv8 boxes, and class names to the output video
  78. output_video.write(frame_with_masks)
  79. frame_num += 1
  80. cap.release()
  81. output_video.release()
  82. cv2.destroyAllWindows()

701ddf90b023dffdad2678eec3dad6cd.png

上述示例适用于标准 YOLO 训练的数据集。此外,此方法也可用于自定义数据集。

应用:

  • 自动驾驶车辆:增强目标检测和分割能力,确保自动驾驶汽车的安全导航和决策。

  • 医学成像:通过精确识别和分割医学图像中的异常,提高诊断准确性,如 X 光片、MRI 和 CT 扫描。

  • 监控系统:通过精确检测和分割感兴趣的对象,提高公共场所的安全监控。

  • 工业自动化:通过检测和分割装配线上制造产品中的缺陷,优化质量控制过程。

  • 农业:通过精确识别和分割农业图像中的植物和害虫,协助作物监测和害虫检测。

  • 环境监测:通过检测和分割卫星图像中的树木、水体和野生动物,帮助监测和分析环境变化。

  • 增强现实:通过精确检测和分割现实世界中的物体,提升 AR 应用的沉浸式用户体验。

  • 零售分析:通过精确检测和分割零售环境中的产品,改善客户分析和库存管理。

总之,SAM(Segment Anything Model)和 YOLO(You Only Look Once)的融合代表了图像分析领域的重大进步,在各个领域具有深远的影响。这一整合结合了 YOLO 在目标检测方面的敏锐性和 SAM 在分割方面的精确性,使我们能够从视觉数据中获得更深入的见解。从优化自动驾驶车辆的感知系统到帮助医学专家诊断疾病,SAM+YOLO 的协同潜力远远超越了传统边界。

·  END  ·

HAPPY LIFE

60c50eb004d8dce632bf55a08e9cf924.png

本文仅供学习交流使用,如有侵权请联系作者删除

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号