赞
踩
使用 SAM 2 进行视频分割
本笔记本介绍如何使用 SAM 2 进行视频交互式分割。它将涵盖以下内容:
在帧上添加点击,以获取并完善小掩码(时空掩码)
在整个视频中传播点击以获取掩码
同时分割和跟踪多个对象
我们使用分段或掩码来指单个帧上的物体模型预测,使用小掩码来指整个视频中的时空掩码。
如果使用 jupyter 在本地运行,请首先使用软件仓库中的安装说明在您的环境中安装 segment-anything-2。
导入库
- import os
- import torch
- import numpy as np
- import matplotlib.pyplot as plt
- from PIL import Image
- # 为整个 notebook 使用 bfloat16
- torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-
- # 如果 CUDA 设备的属性为 8 或更高版本,则为 Ampere GPU 开启 tfloat32
- # 详细信息参考 https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
- if torch.cuda.get_device_properties(0).major >= 8:
- torch.backends.cuda.matmul.allow_tf32 = True
- torch.backends.cudnn.allow_tf32 = True
加载 SAM 2 视频预测器
- from sam2.build_sam import build_sam2_video_predictor
-
- # 指定 sam2 模型的检查点文件路径
- sam2_checkpoint = "../checkpoints/sam2_hiera_large.pt"
-
- # 指定模型配置文件路径
- model_cfg = "sam2_hiera_l.yaml"
-
- # 使用指定的模型配置和检查点文件构建视频预测器
- predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
-
sam2_checkpoint修改为自己的权重文件所在位置
model_cfg修改为与你权重相匹配的,一般在sam2_configs文件夹
选择视频示例
我们假设视频存储为 JPEG 帧列表,文件名为 <frame_index>.jpg。
对于自定义视频,您可以使用 ffmpeg (https://ffmpeg.org/) 提取其 JPEG 帧,如下所示:
ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg' 其中,-q:v 生成高质量的 JPEG 帧。
其中 -q:v 生成高质量的 JPEG 帧,而 -start_number 0 则要求 ffmpeg 从 00000.jpg 开始生成 JPEG 文件。
- # `video_dir` 是一个包含 JPEG 帧的目录,文件名格式如 `<frame_index>.jpg`
- video_dir = "./videos/bedroom"
-
- # 扫描该目录中的所有 JPEG 帧文件名
- frame_names = [
- p for p in os.listdir(video_dir)
- if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
- ]
- frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
-
- # 查看第一帧视频帧
- frame_idx = 0
- plt.figure(figsize=(12, 8))
- plt.title(f"frame {frame_idx}")
- plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
注意:记得加上plt.show()否则可能没有输出
(图见SAM2中video_predictor_example.ipynb,懒的放了)
初始化推理状态
SAM 2 需要有状态的推理来进行交互式视频分割,因此我们需要在这段视频上初始化推理状态。
在初始化过程中,它会加载 video_path 中的所有 JPEG 帧,并将其像素存储在 inference_state 中(如下图进度条所示)。
- # 初始化推理状态
- inference_state = predictor.init_state(video_path=video_dir)
例 1:分割并跟踪一个对象
注意:如果您之前使用此 inference_state 运行过任何跟踪,请先通过 reset_state 重置它。
(下面的单元格只是为了说明;这里不需要调用 reset_state,因为这个 inference_state 只是刚刚初始化)。
- # 重置推理状态
- predictor.reset_state(inference_state)
步骤 1:在框架上添加第一次点击
首先,让我们尝试分割左侧的孩子。
在这里,我们通过向 add_new_points API 发送坐标和标签,在 (x, y) = (210, 350) 处添加标签为 1 的正点击。
注意:标签 1 表示正点击(添加一个区域),标签 0 表示负点击(删除一个区域)。
- ann_frame_idx = 0 # 交互的帧索引
- ann_obj_id = 1 # 给每个交互对象一个唯一的ID(可以是任何整数)
-
- # 添加一个正点击 (x, y) = (210, 350) 来开始
- points = np.array([[210, 350]], dtype=np.float32)
- # 对于 labels,`1` 表示正点击,`0` 表示负点击
- labels = np.array([1], np.int32)
-
- # 向预测器添加新点
- _, out_obj_ids, out_mask_logits = predictor.add_new_points(
- inference_state=inference_state,
- frame_idx=ann_frame_idx,
- obj_id=ann_obj_id,
- points=points,
- labels=labels,
- )
-
- # 在当前(交互)帧上显示结果
- plt.figure(figsize=(12, 8))
- plt.title(f"frame {ann_frame_idx}")
- plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
- show_points(points, labels, plt.gca())
- show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
- plt.show()
注意:记得加上plt.show()否则可能没有输出
步骤 2:增加第二次点击以完善预测
嗯,看来虽然我们想要分割左侧的孩子,但模型只预测了短裤的遮罩--这有可能发生,因为单次点击会对目标对象产生歧义。我们可以通过再次点击孩子的上衣来完善这一帧的遮罩。
在这里,我们在 (x, y) = (250, 220) 处进行第二次正面点击,标签为 1,以扩展遮罩。
注意:在调用 add_new_points 时,我们需要发送所有点击及其标签(即不仅仅是最后一次点击)。
- ann_frame_idx = 0 # 交互的帧索引
- ann_obj_id = 1 # 给每个交互对象一个唯一的ID(可以是任何整数)
-
- # 添加第二个正点击 (x, y) = (250, 220) 以优化掩码
- # 将所有点击(及其标签)发送到 `add_new_points`
- points = np.array([[210, 350], [250, 220]], dtype=np.float32)
- # 对于 labels,`1` 表示正点击,`0` 表示负点击
- labels = np.array([1, 1], np.int32)
-
- # 向预测器添加新点
- _, out_obj_ids, out_mask_logits = predictor.add_new_points(
- inference_state=inference_state,
- frame_idx=ann_frame_idx,
- obj_id=ann_obj_id,
- points=points,
- labels=labels,
- )
-
- # 在当前(交互)帧上显示结果
- plt.figure(figsize=(12, 8))
- plt.title(f"frame {ann_frame_idx}")
- plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
- show_points(points, labels, plt.gca())
- show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
- plt.show()
点击第 2 次细化后,我们就能得到第 0 帧上整个儿童的分割蒙版。
第 3 步:传播提示,在整个视频中获取小掩码
为了在整个视频中获取掩码,我们使用 propagate_in_video API 传播提示信息。
- # 在整个视频中运行传播并将结果收集到一个字典中
- video_segments = {} # video_segments 包含每帧的分割结果
- for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
- video_segments[out_frame_idx] = {
- out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
- for i, out_obj_id in enumerate(out_obj_ids)
- }
-
- # 每隔几帧渲染一次分割结果
- vis_frame_stride = 15
- plt.close("all")
- for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
- plt.figure(figsize=(6, 4))
- plt.title(f"frame {out_frame_idx}")
- plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
- for out_obj_id, out_mask in video_segments[out_frame_idx].items():
- show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
- plt.show()
步骤 4:添加新的提示以进一步完善小掩码
在上面的输出小掩码中,第 150 帧的边界细节似乎存在一些瑕疵。
通过 SAM 2,我们可以交互式地修正模型预测。我们可以在该帧的 (x, y) = (82, 415) 处添加一个标签为 0 的负点击,以完善子掩码。在这里,我们使用不同的 frame_idx 参数调用 add_new_points 应用程序接口,以指示我们要细化的帧索引。
- # 设定需要进一步细化的帧索引和对象ID
- ann_frame_idx = 150 # 需要细化的帧索引
- ann_obj_id = 1 # 与我们交互的对象的唯一ID(可以是任何整数)
-
- # 显示细化前的分割结果
- plt.figure(figsize=(12, 8))
- plt.title(f"frame {ann_frame_idx} -- before refinement")
- plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
- show_mask(video_segments[ann_frame_idx][ann_obj_id], plt.gca(), obj_id=ann_obj_id)
- plt.show() # 确保图像在细化前显示
-
- # 在该帧添加一个负点击 (x, y) = (82, 415) 以细化分割结果
- points = np.array([[82, 415]], dtype=np.float32)
- # 标签为`1`表示正点击,`0`表示负点击
- labels = np.array([0], np.int32)
- _, _, out_mask_logits = predictor.add_new_points(
- inference_state=inference_state,
- frame_idx=ann_frame_idx,
- obj_id=ann_obj_id,
- points=points,
- labels=labels,
- )
-
- # 显示细化后的分割结果
- plt.figure(figsize=(12, 8))
- plt.title(f"frame {ann_frame_idx} -- after refinement")
- plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
- show_points(points, labels, plt.gca())
- show_mask((out_mask_logits > 0.0).cpu().numpy(), plt.gca(), obj_id=ann_obj_id)
- plt.show() # 确保图像在细化后显示
第 5 步:传播提示(再次),在整个视频中获取小掩码
让我们更新整个视频的字幕。在此,我们再次调用 propagate_in_video,在添加上述新的细化点击后传播所有提示信息。
- # 运行分割传播并在字典中收集结果
- video_segments = {} # video_segments 包含每帧的分割结果
- for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
- video_segments[out_frame_idx] = {
- out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
- for i, out_obj_id in enumerate(out_obj_ids)
- }
-
- # 每隔几帧渲染分割结果
- vis_frame_stride = 15
- plt.close("all")
- for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
- plt.figure(figsize=(6, 4))
- plt.title(f"frame {out_frame_idx}")
- plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
- for out_obj_id, out_mask in video_segments[out_frame_idx].items():
- show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
- plt.show() # 确保每个渲染的图像显示出来
现在,所有框架上的线段都很美观。
例1完整代码如下:
- import os
- import torch
- import numpy as np
- import matplotlib.pyplot as plt
- from PIL import Image
-
-
- torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-
- if torch.cuda.get_device_properties(0).major >= 8:
- # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
- torch.backends.cuda.matmul.allow_tf32 = True
- torch.backends.cudnn.allow_tf32 = True
-
- from sam2.build_sam import build_sam2_video_predictor
-
- sam2_checkpoint = r"E:\segment-anything-2-main\checkpoints\sam2_hiera_large.pt"
- model_cfg = r"E:\segment-anything-2-main\sam2_configs\sam2_hiera_l.yaml"
-
- predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
-
- def show_mask(mask, ax, obj_id=None, random_color=False):
- if random_color:
- color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
- else:
- cmap = plt.get_cmap("tab10")
- cmap_idx = 0 if obj_id is None else obj_id
- color = np.array([*cmap(cmap_idx)[:3], 0.6])
- h, w = mask.shape[-2:]
- mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
- ax.imshow(mask_image)
-
-
- def show_points(coords, labels, ax, marker_size=200):
- pos_points = coords[labels==1]
- neg_points = coords[labels==0]
- ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
- ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
-
- # `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
- video_dir = r"E:\segment-anything-2-main\notebooks\videos\bedroom"
- #video_dir = r"E:\segment-anything-2-main\jc-imgs"
- # scan all the JPEG frame names in this directory
- frame_names = [
- p for p in os.listdir(video_dir)
- if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
- ]
- frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
- # take a look the first video frame
- frame_idx = 0
- plt.figure(figsize=(12, 8))
- plt.title(f"frame {frame_idx}")
- plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
- #plt.show()
-
- inference_state = predictor.init_state(video_path=video_dir)
- predictor.reset_state(inference_state)
-
- ann_frame_idx = 0 # the frame index we interact with
- ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
-
- # Let's add a positive click at (x, y) = (210, 350) to get started
- points = np.array([[210, 350]], dtype=np.float32)
- # for labels, `1` means positive click and `0` means negative click
- labels = np.array([1], np.int32)
- _, out_obj_ids, out_mask_logits = predictor.add_new_points(
- inference_state=inference_state,
- frame_idx=ann_frame_idx,
- obj_id=ann_obj_id,
- points=points,
- labels=labels,
- )
- # show the results on the current (interacted) frame
- plt.figure(figsize=(12, 8))
- plt.title(f"frame {ann_frame_idx}")
- plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
- show_points(points, labels, plt.gca())
- show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
- #plt.show()
- ann_frame_idx = 0 # 交互的帧索引
- ann_obj_id = 1 # 给每个交互对象一个唯一的ID(可以是任何整数)
- # 添加第二个正点击 (x, y) = (250, 220) 以优化掩码
- # 将所有点击(及其标签)发送到 `add_new_points`
- points = np.array([[210, 350], [250, 220]], dtype=np.float32)
- # 对于 labels,`1` 表示正点击,`0` 表示负点击
- labels = np.array([1, 1], np.int32)
- # 向预测器添加新点
- _, out_obj_ids, out_mask_logits = predictor.add_new_points(
- inference_state=inference_state,
- frame_idx=ann_frame_idx,
- obj_id=ann_obj_id,
- points=points,
- labels=labels,
- )
- # 在当前(交互)帧上显示结果
- plt.figure(figsize=(12, 8))
- plt.title(f"frame {ann_frame_idx}")
- plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
- show_points(points, labels, plt.gca())
- show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
- plt.show()
- # 在整个视频中运行传播并将结果收集到一个字典中
- video_segments = {} # video_segments 包含每帧的分割结果
- for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
- video_segments[out_frame_idx] = {
- out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
- for i, out_obj_id in enumerate(out_obj_ids)
- }
- # 每隔几帧渲染一次分割结果
- vis_frame_stride = 1
- plt.close("all")
- for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
- plt.figure(figsize=(6, 4))
- plt.title(f"frame {out_frame_idx}")
- plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
- for out_obj_id, out_mask in video_segments[out_frame_idx].items():
- show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
- plt.show()
- # 设定需要进一步细化的帧索引和对象ID
- ann_frame_idx = 15 # 需要细化的帧索引
- ann_obj_id = 1 # 与我们交互的对象的唯一ID(可以是任何整数)
- # 显示细化前的分割结果
- plt.figure(figsize=(12, 8))
- plt.title(f"frame {ann_frame_idx} -- before refinement")
- plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
- show_mask(video_segments[ann_frame_idx][ann_obj_id], plt.gca(), obj_id=ann_obj_id)
- plt.show() # 确保图像在细化前显示
- # 在该帧添加一个负点击 (x, y) = (82, 415) 以细化分割结果
- points = np.array([[82, 415]], dtype=np.float32)
- # 标签为`1`表示正点击,`0`表示负点击
- labels = np.array([0], np.int32)
- _, _, out_mask_logits = predictor.add_new_points(
- inference_state=inference_state,
- frame_idx=ann_frame_idx,
- obj_id=ann_obj_id,
- points=points,
- labels=labels,
- )
- # 显示细化后的分割结果
- plt.figure(figsize=(12, 8))
- plt.title(f"frame {ann_frame_idx} -- after refinement")
- plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
- show_points(points, labels, plt.gca())
- show_mask((out_mask_logits > 0.0).cpu().numpy(), plt.gca(), obj_id=ann_obj_id)
- plt.show() # 确保图像在细化后显示
可以让gpt写一个分割结果整合成视频的代码,方便使用。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。