CV - Open-Source Image Instance Segmentation Algorithm SAM2 (Segment Anything): Video Segmentation Tutorial (2)
Welcome to follow my CSDN: https://spike.blog.csdn.net/
Article URL: https://spike.blog.csdn.net/article/details/143220597
Disclaimer: This article is based on personal knowledge and public materials and is intended for academic exchange only. Discussion is welcome; reposting is not permitted.
SAM2 (Segment Anything Model 2) introduces a streaming memory architecture for video segmentation, enabling real-time video processing, improving segmentation accuracy, and reducing the amount of user interaction required. It performs strongly on object recognition and segmentation tasks in both images and videos, providing a solid foundation for a wide range of downstream applications.
1. Environment Setup
Import the Python packages and configure CUDA:
import os
# if using Apple MPS, fall back to CPU for unsupported ops
# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
# select the device for computation
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")
if device.type == "cuda":
    # use bfloat16 for the entire notebook
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
Load the SAM2 model:
from sam2.build_sam import build_sam2_video_predictor
sam2_checkpoint = "sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)
For more details, see: Open-Source Instance Segmentation Algorithm SAM2 (Segment Anything 2): Configuration and Inference Tutorial (1)
Visualization helper functions:
- show_mask: display a mask.
- show_points: display annotated points, distinguishing positive from negative clicks.
- show_box: display a box.
That is:
def show_mask(mask, ax, obj_id=None, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
Display a single frame of the test video. The test video is a product showcase of a model trying on shoes:
# `video_dir` a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = "[your folder]/sam2/mydata/test_video"
# scan all the JPEG frame names in this directory
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
# take a look at the first video frame
frame_idx = 0
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
To download the image, select the JupyterLab cell and use Shift + Command + double-click.
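If you prefer to save the frame programmatically instead of downloading it through the JupyterLab UI, a minimal sketch is shown below; the output file names frame_preview.png and frame_raw.jpg are just examples, not part of the tutorial's pipeline:
# save the currently displayed matplotlib figure to disk
plt.savefig("frame_preview.png", dpi=150, bbox_inches="tight")
# or save the raw frame itself, without figure decorations
Image.open(os.path.join(video_dir, frame_names[frame_idx])).save("frame_raw.jpg")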
2. Object Segmentation - Positive Click
Load all image frames of the video and reset the instance state. SAM2 keeps a memory of instances, so the state must be reset (reset_state) whenever you start over:
inference_state = predictor.init_state(video_path=video_dir)
predictor.reset_state(inference_state)
Use a positive click as input to obtain the mask of the relevant region:
- ann_frame_idx = 0: the frame to interact with.
- ann_obj_id = 1: the object ID, a unique identifier for the target.
- points = np.array([[600, 350]], dtype=np.float32): the input point positions; note this is a list, so multiple points can be provided.
- labels = np.array([1], np.int32): the input point labels, where 1 marks a positive (target) region and 0 marks a negative (non-target) region, which helps keep the mask from spreading too far.
- predictor.add_new_points_or_box(): generates the mask on the interacted frame (frame_idx=ann_frame_idx) from the points or box.
That is:
ann_frame_idx = 0 # the frame index we interact with
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
# Let's add a positive click at (x, y) = (600, 350) to get started
points = np.array([[600, 350]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1], np.int32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)
# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
Looking at the image, the target is the foot, but part of the shoelaces is included in the instance, as shown below:
Use 2 positive clicks, i.e. [[600, 350], [900, 350]], to refine the instance region and make the contour more precise:
ann_frame_idx = 0 # the frame index we interact with
ann_obj_id = 1 # give a unique id to each object we interact with (it can be any integers)
# Let's add a 2nd positive click at (x, y) = (900, 350) to refine the mask
# sending all clicks (and their labels) to `add_new_points_or_box`
points = np.array([[600, 350], [900, 350]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1, 1], np.int32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)
# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
Looking at the image, both positive clicks are on the foot, and the shoelace region is now completely excluded. Output:
Apply the predictor, whose state is now set, to multi-frame object segmentation across the video, automatically tracking the target instance:
- Call predictor.propagate_in_video(inference_state) to predict results for all video frames.
- The results are cached in video_segments for plotting.
That is:
# run propagation throughout the video and collect the results in a dict
video_segments = {} # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }
# render the segmentation results every few frames
vis_frame_stride = 20
plt.close("all")
# multi-frame visualization grid
grid_size = (2, 5)
fig, axs = plt.subplots(grid_size[0], grid_size[1], figsize=(24, 7))
for idx, out_frame_idx in enumerate(range(0, len(frame_names), vis_frame_stride)):
    row = idx // grid_size[1]
    col = idx % grid_size[1]
    ax = axs[row, col]
    ax.set_title(f"frame {out_frame_idx}")
    ax.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, ax, obj_id=out_obj_id)
    ax.axis('on')
plt.tight_layout()
plt.show()
Looking at the images, every frame of the video is segmented accurately:
3. Object Segmentation - Box
Reset the video state, select the target with a box, and call predictor.add_new_points_or_box() to perform object segmentation:
- The point parameters are points and labels, i.e. [(x, y)] and 0/1.
- The box parameter is box, i.e. (x_min, y_min, x_max, y_max).
That is:
predictor.reset_state(inference_state)  # reset the instance state first
ann_frame_idx = 0 # the frame index we interact with
ann_obj_id = 4 # give a unique id to each object we interact with (it can be any integers)
# Let's add a box at (x_min, y_min, x_max, y_max) = (290, 100, 1000, 580) to get started
box = np.array([290, 100, 1000, 580], dtype=np.float32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    box=box,
)
# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_box(box, plt.gca())
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
Looking at the image, segmentation is performed within the input box, and the target is the shoe, as shown below:
A box and points (points & labels) can also be used together, which further emphasizes the instance of interest inside the box:
ann_frame_idx = 0 # the frame index we interact with
ann_obj_id = 4 # give a unique id to each object we interact with (it can be any integers)
# Let's add a positive click at (x, y) = (900, 60) to refine the mask
points = np.array([[900, 60]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1], np.int32)
# note that we also need to send the original box input along with
# the new refinement click together into `add_new_points_or_box`
box = np.array([300, 0, 1050, 450], dtype=np.float32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
    box=box,
)
# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_box(box, plt.gca())
show_points(points, labels, plt.gca())
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])
Looking at the image, the subject is now more clearly delineated:
Apply the predictor, whose state is now set, to multi-frame object segmentation across the video, automatically tracking the target instance:
# run propagation throughout the video and collect the results in a dict
video_segments = {} # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }
# render the segmentation results every few frames
vis_frame_stride = 20
plt.close("all")
grid_size = (2, 5)
fig, axs = plt.subplots(grid_size[0], grid_size[1], figsize=(24, 7))
for idx, out_frame_idx in enumerate(range(0, len(frame_names), vis_frame_stride)):
    row = idx // grid_size[1]
    col = idx % grid_size[1]
    ax = axs[row, col]
    ax.set_title(f"frame {out_frame_idx}")
    ax.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, ax, obj_id=out_obj_id)
    ax.axis('on')
Looking at the images, the earlier example segmented the foot, while this one segments the shoe; both are segmented accurately. The result is shown below:
4. Object Segmentation - Multi-Instance
Reset the instance state (reset_state) again. Here prompts caches the multi-instance inputs for plotting, and positive clicks (label=1) are again used to select the target to segment:
predictor.reset_state(inference_state)
prompts = {} # hold all the clicks we add for visualization
ann_frame_idx = 0 # the frame index we interact with
ann_obj_id = 2 # give a unique id to each object we interact with (it can be any integers)
# Let's add a positive click at (x, y) = (900, 500) to get started on the first object
points = np.array([[900, 500]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1], np.int32)
prompts[ann_obj_id] = points, labels
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)
# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
for i, out_obj_id in enumerate(out_obj_ids):
    show_points(*prompts[out_obj_id], plt.gca())
    show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)
Looking at the image, the segmentation quality of the shoe is mediocre, as shown below:
On top of the positive click (label=1), adding negative clicks (label=0) improves the result:
# add the first object
ann_frame_idx = 0 # the frame index we interact with
ann_obj_id = 2 # give a unique id to each object we interact with (it can be any integers)
# Let's add negative clicks at (600, 350) and (800, 200) to refine the first object
# sending all clicks (and their labels) to `add_new_points_or_box`
points = np.array([[900, 500], [600, 350], [800, 200]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1, 0, 0], np.int32)
prompts[ann_obj_id] = points, labels
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)
# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
for i, out_obj_id in enumerate(out_obj_ids):
    show_points(*prompts[out_obj_id], plt.gca())
    show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)
Looking at the image, the shoe is selected and the foot is excluded, as shown below:
On top of the 1st instance (ann_obj_id = 2), add a 2nd instance (ann_obj_id = 3):
ann_frame_idx = 0 # the frame index we interact with
ann_obj_id = 3 # give a unique id to each object we interact with (it can be any integers)
# Let's now move on to the second object we want to track (giving it object id `3`)
# with a positive click at (x, y) = (900, 350)
points = np.array([[900, 350]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1], np.int32)
prompts[ann_obj_id] = points, labels
# `add_new_points_or_box` returns masks for all objects added so far on this interacted frame
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)
# show the results on the current (interacted) frame on all objects
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
for i, out_obj_id in enumerate(out_obj_ids):
    show_points(*prompts[out_obj_id], plt.gca())
    show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)
Looking at the image, there are 2 instances: the foot in red and the shoe in green. The result is shown below:
Apply the predictor, whose state now includes the 2 instance targets, to multi-frame object segmentation across the video, automatically tracking the target instances:
# run propagation throughout the video and collect the results in a dict
video_segments = {} # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }
# render the segmentation results every few frames
vis_frame_stride = 20
plt.close("all")
grid_size = (2, 5)
fig, axs = plt.subplots(grid_size[0], grid_size[1], figsize=(24, 7))
for idx, out_frame_idx in enumerate(range(0, len(frame_names), vis_frame_stride)):
    row = idx // grid_size[1]
    col = idx % grid_size[1]
    ax = axs[row, col]
    ax.set_title(f"frame {out_frame_idx}")
    ax.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, ax, obj_id=out_obj_id)
    ax.axis('on')
The video segmentation result for the 2 instances (shoe and foot) is shown below:
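If you want to export the multi-frame results as a playable video rather than a matplotlib grid, a minimal sketch along the lines below overlays each object's mask on its frame and writes an mp4 with OpenCV. The file name overlay_video.mp4, the per-object BGR colors, the blending factor alpha, and the 25 fps frame rate are arbitrary choices for illustration, not part of the SAM2 API; the H264 codec requirement matches the note in Section 5.
import cv2
alpha = 0.5
colors = {2: (0, 255, 0), 3: (0, 0, 255)}  # assumed BGR color per obj_id
first = cv2.imread(os.path.join(video_dir, frame_names[0]))
h, w = first.shape[:2]
writer = cv2.VideoWriter("overlay_video.mp4", cv2.VideoWriter_fourcc(*"H264"), 25, (w, h))
for frame_idx in range(len(frame_names)):
    frame = cv2.imread(os.path.join(video_dir, frame_names[frame_idx])).astype(np.float32)
    for obj_id, mask in video_segments.get(frame_idx, {}).items():
        m = mask.squeeze().astype(bool)  # (1, H, W) -> (H, W)
        color = np.array(colors.get(obj_id, (255, 0, 0)), dtype=np.float32)
        frame[m] = (1 - alpha) * frame[m] + alpha * color  # blend mask color into the frame
    writer.write(frame.astype(np.uint8))
writer.release()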
5. Miscellaneous
Video processing website: https://online-video-cutter.com/cn/
Video processing script. Note that the codec must be H264, otherwise playback fails:
fourcc = cv2.VideoWriter_fourcc(*'H264')
import os
import cv2
import numpy as np
def process_video(input_video_path, output_video_path, frame_output_dir, start_time, frame_count, frame_interval):
    # create the output folder
    if not os.path.exists(frame_output_dir):
        os.makedirs(frame_output_dir)
    # open the video file
    cap = cv2.VideoCapture(input_video_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    start_frame = start_time * fps
    # set the starting frame
    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print("Total frames: ", total_frames, fps)
    # get the video width and height
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # define the video codec and the output file
    fourcc = cv2.VideoWriter_fourcc(*'H264')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
    extracted_frames = 0
    current_frame = 0
    while extracted_frames < frame_count:
        ret, frame = cap.read()
        if not ret:
            break
        if current_frame % frame_interval == 0:
            # save each extracted frame to the folder
            frame_filename = os.path.join(frame_output_dir, f'{extracted_frames + 1:05d}.jpg')
            cv2.imwrite(frame_filename, frame)
            frame = frame.astype(np.uint8)
            # write the frame into the new video file
            out.write(frame)
            extracted_frames += 1
        current_frame += 1
    # release resources
    cap.release()
    out.release()
    cv2.destroyAllWindows()

def main():
    # example usage
    input_video_path = 'video.mp4'
    output_video_path = 'output_video.mp4'
    frame_output_dir = 'output_frames'
    start_time = 3  # start from the 3rd second
    frame_count = 200  # extract 200 frames
    frame_interval = 4  # take 1 frame every 4 frames
    process_video(input_video_path, output_video_path, frame_output_dir, start_time, frame_count, frame_interval)

if __name__ == '__main__':
    main()
Reference: Python cv2 .mp4 codec unable to be displayed in browser
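If the mp4 written by OpenCV still fails to play in the browser (for example, when the local OpenCV build lacks an H264 encoder), a possible workaround is to re-encode the file with ffmpeg. This is only a sketch: it assumes ffmpeg is installed on the system, and the input/output file names follow the script above (the output name is hypothetical).
import subprocess
# re-encode the OpenCV output with H.264 and a browser-friendly pixel format
subprocess.run([
    "ffmpeg", "-y",
    "-i", "output_video.mp4",   # file written by cv2.VideoWriter above
    "-c:v", "libx264",          # H.264 video codec
    "-pix_fmt", "yuv420p",      # pixel format widely supported by browsers
    "output_video_h264.mp4",    # hypothetical output name
], check=True)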