当前位置:   article > 正文

SAM2视频模块使用(翻译自video_predictor_example.ipynb)_sam2 视频分割教程

sam2 视频分割教程

使用 SAM 2 进行视频分割
笔记本介绍如何使用 SAM 2 进行视频交互式分割。它将涵盖以下内容:

在帧上添加点击,以获取并完善小掩码(时空掩码)
在整个视频中传播点击以获取掩码
同时分割和跟踪多个对象
我们使用分段或掩码来指单个帧上的物体模型预测,使用小掩码来指整个视频中的时空掩码。

如果使用 jupyter 在本地运行,请首先使用软件仓库中的安装说明在您的环境中安装 segment-anything-2。

导入库

  1. import os
  2. import torch
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from PIL import Image
  1. # 为整个 notebook 使用 bfloat16
  2. torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
  3. # 如果 CUDA 设备的属性为 8 或更高版本,则为 Ampere GPU 开启 tfloat32
  4. # 详细信息参考 https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
  5. if torch.cuda.get_device_properties(0).major >= 8:
  6. torch.backends.cuda.matmul.allow_tf32 = True
  7. torch.backends.cudnn.allow_tf32 = True

加载 SAM 2 视频预测器

  1. from sam2.build_sam import build_sam2_video_predictor
  2. # 指定 sam2 模型的检查点文件路径
  3. sam2_checkpoint = "../checkpoints/sam2_hiera_large.pt"
  4. # 指定模型配置文件路径
  5. model_cfg = "sam2_hiera_l.yaml"
  6. # 使用指定的模型配置和检查点文件构建视频预测器
  7. predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)

 sam2_checkpoint修改为自己的权重文件所在位置

model_cfg修改为与你权重相匹配的,一般在sam2_configs文件夹

选择视频示例
我们假设视频存储为 JPEG 帧列表,文件名为 <frame_index>.jpg。

对于自定义视频,您可以使用 ffmpeg (https://ffmpeg.org/) 提取其 JPEG 帧,如下所示:

ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg' 其中,-q:v 生成高质量的 JPEG 帧。
其中 -q:v 生成高质量的 JPEG 帧,而 -start_number 0 则要求 ffmpeg 从 00000.jpg 开始生成 JPEG 文件。

  1. # `video_dir` 是一个包含 JPEG 帧的目录,文件名格式如 `<frame_index>.jpg`
  2. video_dir = "./videos/bedroom"
  3. # 扫描该目录中的所有 JPEG 帧文件名
  4. frame_names = [
  5. p for p in os.listdir(video_dir)
  6. if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
  7. ]
  8. frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
  9. # 查看第一帧视频帧
  10. frame_idx = 0
  11. plt.figure(figsize=(12, 8))
  12. plt.title(f"frame {frame_idx}")
  13. plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))

注意:记得加上plt.show()否则可能没有输出

(图见SAM2中video_predictor_example.ipynb,懒的放了)

初始化推理状态
SAM 2 需要有状态的推理来进行交互式视频分割,因此我们需要在这段视频上初始化推理状态。

在初始化过程中,它会加载 video_path 中的所有 JPEG 帧,并将其像素存储在 inference_state 中(如下图进度条所示)。

  1. # 初始化推理状态
  2. inference_state = predictor.init_state(video_path=video_dir)

例 1:分割并跟踪一个对象
注意:如果您之前使用此 inference_state 运行过任何跟踪,请先通过 reset_state 重置它。

(下面的单元格只是为了说明;这里不需要调用 reset_state,因为这个 inference_state 只是刚刚初始化)。

  1. # 重置推理状态
  2. predictor.reset_state(inference_state)

步骤 1:在框架上添加第一次点击
首先,让我们尝试分割左侧的孩子。

在这里,我们通过向 add_new_points API 发送坐标和标签,在 (x, y) = (210, 350) 处添加标签为 1 的正点击。

注意:标签 1 表示正点击(添加一个区域),标签 0 表示负点击(删除一个区域)。

  1. ann_frame_idx = 0 # 交互的帧索引
  2. ann_obj_id = 1 # 给每个交互对象一个唯一的ID(可以是任何整数)
  3. # 添加一个正点击 (x, y) = (210, 350) 来开始
  4. points = np.array([[210, 350]], dtype=np.float32)
  5. # 对于 labels,`1` 表示正点击,`0` 表示负点击
  6. labels = np.array([1], np.int32)
  7. # 向预测器添加新点
  8. _, out_obj_ids, out_mask_logits = predictor.add_new_points(
  9. inference_state=inference_state,
  10. frame_idx=ann_frame_idx,
  11. obj_id=ann_obj_id,
  12. points=points,
  13. labels=labels,
  14. )
  15. # 在当前(交互)帧上显示结果
  16. plt.figure(figsize=(12, 8))
  17. plt.title(f"frame {ann_frame_idx}")
  18. plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
  19. show_points(points, labels, plt.gca())
  20. show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
  21. plt.show()

注意:记得加上plt.show()否则可能没有输出

步骤 2:增加第二次点击以完善预测
嗯,看来虽然我们想要分割左侧的孩子,但模型只预测了短裤的遮罩--这有可能发生,因为单次点击会对目标对象产生歧义。我们可以通过再次点击孩子的上衣来完善这一帧的遮罩。

在这里,我们在 (x, y) = (250, 220) 处进行第二次正面点击,标签为 1,以扩展遮罩。

注意:在调用 add_new_points 时,我们需要发送所有点击及其标签(即不仅仅是最后一次点击)。

  1. ann_frame_idx = 0 # 交互的帧索引
  2. ann_obj_id = 1 # 给每个交互对象一个唯一的ID(可以是任何整数)
  3. # 添加第二个正点击 (x, y) = (250, 220) 以优化掩码
  4. # 将所有点击(及其标签)发送到 `add_new_points`
  5. points = np.array([[210, 350], [250, 220]], dtype=np.float32)
  6. # 对于 labels,`1` 表示正点击,`0` 表示负点击
  7. labels = np.array([1, 1], np.int32)
  8. # 向预测器添加新点
  9. _, out_obj_ids, out_mask_logits = predictor.add_new_points(
  10. inference_state=inference_state,
  11. frame_idx=ann_frame_idx,
  12. obj_id=ann_obj_id,
  13. points=points,
  14. labels=labels,
  15. )
  16. # 在当前(交互)帧上显示结果
  17. plt.figure(figsize=(12, 8))
  18. plt.title(f"frame {ann_frame_idx}")
  19. plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
  20. show_points(points, labels, plt.gca())
  21. show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
  22. plt.show()

点击第 2 次细化后,我们就能得到第 0 帧上整个儿童的分割蒙版。

第 3 步:传播提示,在整个视频中获取小掩码
为了在整个视频中获取掩码,我们使用 propagate_in_video API 传播提示信息。

  1. # 在整个视频中运行传播并将结果收集到一个字典中
  2. video_segments = {} # video_segments 包含每帧的分割结果
  3. for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
  4. video_segments[out_frame_idx] = {
  5. out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
  6. for i, out_obj_id in enumerate(out_obj_ids)
  7. }
  8. # 每隔几帧渲染一次分割结果
  9. vis_frame_stride = 15
  10. plt.close("all")
  11. for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
  12. plt.figure(figsize=(6, 4))
  13. plt.title(f"frame {out_frame_idx}")
  14. plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
  15. for out_obj_id, out_mask in video_segments[out_frame_idx].items():
  16. show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
  17. plt.show()

步骤 4:添加新的提示以进一步完善小掩码
在上面的输出小掩码中,第 150 帧的边界细节似乎存在一些瑕疵。

通过 SAM 2,我们可以交互式地修正模型预测。我们可以在该帧的 (x, y) = (82, 415) 处添加一个标签为 0 的负点击,以完善子掩码。在这里,我们使用不同的 frame_idx 参数调用 add_new_points 应用程序接口,以指示我们要细化的帧索引。

  1. # 设定需要进一步细化的帧索引和对象ID
  2. ann_frame_idx = 150 # 需要细化的帧索引
  3. ann_obj_id = 1 # 与我们交互的对象的唯一ID(可以是任何整数)
  4. # 显示细化前的分割结果
  5. plt.figure(figsize=(12, 8))
  6. plt.title(f"frame {ann_frame_idx} -- before refinement")
  7. plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
  8. show_mask(video_segments[ann_frame_idx][ann_obj_id], plt.gca(), obj_id=ann_obj_id)
  9. plt.show() # 确保图像在细化前显示
  10. # 在该帧添加一个负点击 (x, y) = (82, 415) 以细化分割结果
  11. points = np.array([[82, 415]], dtype=np.float32)
  12. # 标签为`1`表示正点击,`0`表示负点击
  13. labels = np.array([0], np.int32)
  14. _, _, out_mask_logits = predictor.add_new_points(
  15. inference_state=inference_state,
  16. frame_idx=ann_frame_idx,
  17. obj_id=ann_obj_id,
  18. points=points,
  19. labels=labels,
  20. )
  21. # 显示细化后的分割结果
  22. plt.figure(figsize=(12, 8))
  23. plt.title(f"frame {ann_frame_idx} -- after refinement")
  24. plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
  25. show_points(points, labels, plt.gca())
  26. show_mask((out_mask_logits > 0.0).cpu().numpy(), plt.gca(), obj_id=ann_obj_id)
  27. plt.show() # 确保图像在细化后显示

第 5 步:传播提示(再次),在整个视频中获取小掩码
让我们更新整个视频的字幕。在此,我们再次调用 propagate_in_video,在添加上述新的细化点击后传播所有提示信息。

  1. # 运行分割传播并在字典中收集结果
  2. video_segments = {} # video_segments 包含每帧的分割结果
  3. for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
  4. video_segments[out_frame_idx] = {
  5. out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
  6. for i, out_obj_id in enumerate(out_obj_ids)
  7. }
  8. # 每隔几帧渲染分割结果
  9. vis_frame_stride = 15
  10. plt.close("all")
  11. for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
  12. plt.figure(figsize=(6, 4))
  13. plt.title(f"frame {out_frame_idx}")
  14. plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
  15. for out_obj_id, out_mask in video_segments[out_frame_idx].items():
  16. show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
  17. plt.show() # 确保每个渲染的图像显示出来

现在,所有框架上的线段都很美观。

例1完整代码如下:

  1. import os
  2. import torch
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from PIL import Image
  6. torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
  7. if torch.cuda.get_device_properties(0).major >= 8:
  8. # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
  9. torch.backends.cuda.matmul.allow_tf32 = True
  10. torch.backends.cudnn.allow_tf32 = True
  11. from sam2.build_sam import build_sam2_video_predictor
  12. sam2_checkpoint = r"E:\segment-anything-2-main\checkpoints\sam2_hiera_large.pt"
  13. model_cfg = r"E:\segment-anything-2-main\sam2_configs\sam2_hiera_l.yaml"
  14. predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
  15. def show_mask(mask, ax, obj_id=None, random_color=False):
  16. if random_color:
  17. color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
  18. else:
  19. cmap = plt.get_cmap("tab10")
  20. cmap_idx = 0 if obj_id is None else obj_id
  21. color = np.array([*cmap(cmap_idx)[:3], 0.6])
  22. h, w = mask.shape[-2:]
  23. mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
  24. ax.imshow(mask_image)
  25. def show_points(coords, labels, ax, marker_size=200):
  26. pos_points = coords[labels==1]
  27. neg_points = coords[labels==0]
  28. ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
  29. ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
  30. # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
  31. video_dir = r"E:\segment-anything-2-main\notebooks\videos\bedroom"
  32. #video_dir = r"E:\segment-anything-2-main\jc-imgs"
  33. # scan all the JPEG frame names in this directory
  34. frame_names = [
  35. p for p in os.listdir(video_dir)
  36. if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
  37. ]
  38. frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
  39. # take a look the first video frame
  40. frame_idx = 0
  41. plt.figure(figsize=(12, 8))
  42. plt.title(f"frame {frame_idx}")
  43. plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
  44. #plt.show()
  45. inference_state = predictor.init_state(video_path=video_dir)
  46. predictor.reset_state(inference_state)
  47. ann_frame_idx = 0 # the frame index we interact with
  48. ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
  49. # Let's add a positive click at (x, y) = (210, 350) to get started
  50. points = np.array([[210, 350]], dtype=np.float32)
  51. # for labels, `1` means positive click and `0` means negative click
  52. labels = np.array([1], np.int32)
  53. _, out_obj_ids, out_mask_logits = predictor.add_new_points(
  54. inference_state=inference_state,
  55. frame_idx=ann_frame_idx,
  56. obj_id=ann_obj_id,
  57. points=points,
  58. labels=labels,
  59. )
  60. # show the results on the current (interacted) frame
  61. plt.figure(figsize=(12, 8))
  62. plt.title(f"frame {ann_frame_idx}")
  63. plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
  64. show_points(points, labels, plt.gca())
  65. show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
  66. #plt.show()
  67. ann_frame_idx = 0 # 交互的帧索引
  68. ann_obj_id = 1 # 给每个交互对象一个唯一的ID(可以是任何整数)
  69. # 添加第二个正点击 (x, y) = (250, 220) 以优化掩码
  70. # 将所有点击(及其标签)发送到 `add_new_points`
  71. points = np.array([[210, 350], [250, 220]], dtype=np.float32)
  72. # 对于 labels,`1` 表示正点击,`0` 表示负点击
  73. labels = np.array([1, 1], np.int32)
  74. # 向预测器添加新点
  75. _, out_obj_ids, out_mask_logits = predictor.add_new_points(
  76. inference_state=inference_state,
  77. frame_idx=ann_frame_idx,
  78. obj_id=ann_obj_id,
  79. points=points,
  80. labels=labels,
  81. )
  82. # 在当前(交互)帧上显示结果
  83. plt.figure(figsize=(12, 8))
  84. plt.title(f"frame {ann_frame_idx}")
  85. plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
  86. show_points(points, labels, plt.gca())
  87. show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
  88. plt.show()
  89. # 在整个视频中运行传播并将结果收集到一个字典中
  90. video_segments = {} # video_segments 包含每帧的分割结果
  91. for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
  92. video_segments[out_frame_idx] = {
  93. out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
  94. for i, out_obj_id in enumerate(out_obj_ids)
  95. }
  96. # 每隔几帧渲染一次分割结果
  97. vis_frame_stride = 1
  98. plt.close("all")
  99. for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
  100. plt.figure(figsize=(6, 4))
  101. plt.title(f"frame {out_frame_idx}")
  102. plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
  103. for out_obj_id, out_mask in video_segments[out_frame_idx].items():
  104. show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
  105. plt.show()
  106. # 设定需要进一步细化的帧索引和对象ID
  107. ann_frame_idx = 15 # 需要细化的帧索引
  108. ann_obj_id = 1 # 与我们交互的对象的唯一ID(可以是任何整数)
  109. # 显示细化前的分割结果
  110. plt.figure(figsize=(12, 8))
  111. plt.title(f"frame {ann_frame_idx} -- before refinement")
  112. plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
  113. show_mask(video_segments[ann_frame_idx][ann_obj_id], plt.gca(), obj_id=ann_obj_id)
  114. plt.show() # 确保图像在细化前显示
  115. # 在该帧添加一个负点击 (x, y) = (82, 415) 以细化分割结果
  116. points = np.array([[82, 415]], dtype=np.float32)
  117. # 标签为`1`表示正点击,`0`表示负点击
  118. labels = np.array([0], np.int32)
  119. _, _, out_mask_logits = predictor.add_new_points(
  120. inference_state=inference_state,
  121. frame_idx=ann_frame_idx,
  122. obj_id=ann_obj_id,
  123. points=points,
  124. labels=labels,
  125. )
  126. # 显示细化后的分割结果
  127. plt.figure(figsize=(12, 8))
  128. plt.title(f"frame {ann_frame_idx} -- after refinement")
  129. plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
  130. show_points(points, labels, plt.gca())
  131. show_mask((out_mask_logits > 0.0).cpu().numpy(), plt.gca(), obj_id=ann_obj_id)
  132. plt.show() # 确保图像在细化后显示

可以让gpt写一个分割结果整合成视频的代码,方便使用。

例2见http://t.csdnimg.cn/tL4cL

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/秋刀鱼在做梦/article/detail/1007607
推荐阅读
相关标签
  

闽ICP备14008679号