
CV - Open-Source Image Instance Segmentation Algorithm SAM2 (Segment Anything) - Video Segmentation Tutorial (2)

Welcome to follow my CSDN: https://spike.blog.csdn.net/
Original article: https://spike.blog.csdn.net/article/details/143220597

Disclaimer: This article is based on personal knowledge and publicly available materials, and is intended for academic exchange only. Discussion is welcome; reposting is not permitted.


SAM2

SAM2 (Segment Anything Model 2) introduces a streaming memory architecture for video segmentation, enabling real-time video processing, improving segmentation accuracy, and reducing the amount of user interaction required. It performs well on object recognition and segmentation tasks in both images and videos, providing a solid foundation for a variety of downstream applications.

1. Environment Setup

Import the Python packages and set up the CUDA configuration:

import os
# if using Apple MPS, fall back to CPU for unsupported ops
# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

# select the device for computation
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "cuda":
    # use bfloat16 for the entire notebook
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )

Load the SAM2 model:

from sam2.build_sam import build_sam2_video_predictor

sam2_checkpoint = "sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)

For related details, see: Open-Source Instance Segmentation Algorithm SAM2 (Segment Anything 2) Setup and Inference Tutorial (1)

Visualization helper functions:

  • show_mask: displays a mask
  • show_points: displays annotation points, distinguishing positive from negative clicks
  • show_box: displays a bounding box

That is:

def show_mask(mask, ax, obj_id=None, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

Display a single frame of the test video. The test video is a product showcase of a model trying on shoes:

# `video_dir` is a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = "[your folder]/sam2/mydata/test_video"

# scan all the JPEG frame names in this directory
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

# take a look at the first video frame
frame_idx = 0
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))

To download an image: select the JupyterLab cell and use Shift + Command + double-click to download the image.
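
Alternatively, the current frame figure can be saved to disk programmatically with standard matplotlib (a minimal sketch; the output filename is just an example):

# save the currently displayed figure to a PNG file for later inspection
plt.savefig("frame_preview.png", dpi=150, bbox_inches="tight")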


2. Object Segmentation - Positive Clicks

Load all video frames and reset the inference state. SAM2 keeps a memory of instances, so the state must be reset (reset_state) before each new interaction session:

inference_state = predictor.init_state(video_path=video_dir)
predictor.reset_state(inference_state)

Use a positive click as input to obtain the mask of the corresponding region:

  • ann_frame_idx = 0: the frame to interact with
  • ann_obj_id = 1: the object ID, a unique identifier for the target
  • points = np.array([[600, 350]], dtype=np.float32): the input point coordinates; note this is a list, so multiple points can be provided
  • labels = np.array([1], np.int32): the input point labels, where 1 marks a positive (target) region and 0 marks a negative (non-target) region, which helps keep the mask from spreading too far
  • predictor.add_new_points_or_box(): generates the mask from the points or box on the interaction frame frame_idx=ann_frame_idx

That is:

ann_frame_idx = 0  # the frame index we interact with
ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)

# Let's add a positive click at (x, y) = (600, 350) to get started
points = np.array([[600, 350]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1], np.int32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])

Looking at the result, the target is the foot, but part of the shoelaces is included in the instance, as shown below:

[Figure: single positive click result]

Use two positive clicks, i.e. [[600, 350], [900, 350]], to refine the instance region and obtain a more precise contour:

ann_frame_idx = 0  # the frame index we interact with
ann_obj_id = 1  # give a unique id to each object we interact with (it can be any integers)

# Let's add a 2nd positive click at (x, y) = (900, 350) to refine the mask
# sending all clicks (and their labels) to `add_new_points_or_box`
points = np.array([[600, 350], [900, 350]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1, 1], np.int32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])

Looking at the result, both positive clicks select the foot, and the shoelaces are now completely excluded:

[Figure: two positive clicks result]

Apply the predictor, with its state now established, to multi-frame segmentation across the video, automatically tracking the target instance:

  • Call predictor.propagate_in_video(inference_state) to propagate predictions over all video frames.
  • The results are cached in video_segments for rendering.

That is:

# run propagation throughout the video and collect the results in a dict
video_segments = {}  # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

# render the segmentation results every few frames
vis_frame_stride = 20
plt.close("all")

# render multiple frames in a grid
grid_size = (2, 5)
fig, axs = plt.subplots(grid_size[0], grid_size[1], figsize=(24, 7))
for idx, out_frame_idx in enumerate(range(0, len(frame_names), vis_frame_stride)):
    row = idx // grid_size[1]
    col = idx % grid_size[1]
    ax = axs[row, col]
    ax.set_title(f"frame {out_frame_idx}")
    ax.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, ax, obj_id=out_obj_id)
    ax.axis('on')

plt.tight_layout()
plt.show()

Looking at the results, every frame of the video is segmented accurately:

[Figure: propagated segmentation across video frames]


3. Object Segmentation - Box

Reset the video state, select the target with a box, and call predictor.add_new_points_or_box() to perform segmentation:

  • Point prompts use the points and labels parameters, i.e. [(x, y)] and [0/1]
  • Box prompts use the box parameter, i.e. (x_min, y_min, x_max, y_max)

That is:

predictor.reset_state(inference_state)  # reset the instance state before adding new prompts

ann_frame_idx = 0  # the frame index we interact with
ann_obj_id = 4  # give a unique id to each object we interact with (it can be any integers)

# Let's add a box at (x_min, y_min, x_max, y_max) = (290, 100, 1000, 580) to get started
box = np.array([290, 100, 1000, 580], dtype=np.float32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    box=box,
)

# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_box(box, plt.gca())
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])

Looking at the result, segmentation is performed within the input box, and the selected target is the shoe, as shown below:

[Figure: box prompt result]

A box and points (points & labels) can also be used together, which further emphasizes the main instance to segment inside the box:

ann_frame_idx = 0  # the frame index we interact with
ann_obj_id = 4  # give a unique id to each object we interact with (it can be any integers)

# Let's add a positive click at (x, y) = (900, 60) to refine the mask
points = np.array([[900, 60]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1], np.int32)
# note that we also need to send the original box input along with
# the new refinement click together into `add_new_points_or_box`
box = np.array([300, 0, 1050, 450], dtype=np.float32)
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
    box=box,
)

# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_box(box, plt.gca())
show_points(points, labels, plt.gca())
show_mask((out_mask_logits[0] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_ids[0])

Looking at the result, the main subject is more clearly defined:

[Figure: box + point prompt result]

Apply the predictor, with its state now established, to multi-frame segmentation across the video, automatically tracking the target instance:

# run propagation throughout the video and collect the results in a dict
video_segments = {}  # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

# render the segmentation results every few frames
vis_frame_stride = 20
plt.close("all")
grid_size = (2, 5)
fig, axs = plt.subplots(grid_size[0], grid_size[1], figsize=(24, 7))
for idx, out_frame_idx in enumerate(range(0, len(frame_names), vis_frame_stride)):
    row = idx // grid_size[1]
    col = idx % grid_size[1]
    ax = axs[row, col]
    ax.set_title(f"frame {out_frame_idx}")
    ax.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, ax, obj_id=out_obj_id)
    ax.axis('on')

Looking at the results, the previous section segmented the foot while this one segments the shoe; it is tracked just as accurately, as shown below:

[Figure: propagated box-prompt segmentation across video frames]


4. Object Segmentation - Multi-Instance

Reset the instance state again (reset_state). Here the prompts dict caches the click prompts of each instance for plotting. Continue using positive clicks (label=1) to select the target to segment:

predictor.reset_state(inference_state)

prompts = {}  # hold all the clicks we add for visualization

ann_frame_idx = 0  # the frame index we interact with
ann_obj_id = 2  # give a unique id to each object we interact with (it can be any integers)

# Let's add a positive click at (x, y) = (900, 500) to get started on the first object
points = np.array([[900, 500]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1], np.int32)
prompts[ann_obj_id] = points, labels
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
for i, out_obj_id in enumerate(out_obj_ids):
    show_points(*prompts[out_obj_id], plt.gca())
    show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)

Looking at the result, the segmentation quality of the shoe is mediocre, as shown below:

[Figure: single positive click on the shoe]

正点击(label=1) 的基础上,再加上 负点击(label=0),效果提升,即:

# add the first object
ann_frame_idx = 0  # the frame index we interact with
ann_obj_id = 2  # give a unique id to each object we interact with (it can be any integers)

# Let's add two negative clicks at (x, y) = (600, 350) and (800, 200) to refine the first object
# sending all clicks (and their labels) to `add_new_points_or_box`
points = np.array([[900, 500], [600, 350], [800, 200]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1, 0, 0], np.int32)
prompts[ann_obj_id] = points, labels
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# show the results on the current (interacted) frame
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
for i, out_obj_id in enumerate(out_obj_ids):
    show_points(*prompts[out_obj_id], plt.gca())
    show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)

Looking at the result, the shoe is selected and the foot is excluded, as shown below:

[Figure: positive + negative clicks result]

On top of the first instance (ann_obj_id = 2), add a second instance (ann_obj_id = 3):

ann_frame_idx = 0  # the frame index we interact with
ann_obj_id = 3  # give a unique id to each object we interact with (it can be any integers)

# Let's now move on to the second object we want to track (giving it object id `3`)
# with a positive click at (x, y) = (900, 350)
points = np.array([[900, 350]], dtype=np.float32)
# for labels, `1` means positive click and `0` means negative click
labels = np.array([1], np.int32)
prompts[ann_obj_id] = points, labels

# `add_new_points_or_box` returns masks for all objects added so far on this interacted frame
_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
    inference_state=inference_state,
    frame_idx=ann_frame_idx,
    obj_id=ann_obj_id,
    points=points,
    labels=labels,
)

# show the results on the current (interacted) frame on all objects
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
show_points(points, labels, plt.gca())
for i, out_obj_id in enumerate(out_obj_ids):
    show_points(*prompts[out_obj_id], plt.gca())
    show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)

Looking at the result, there are two instances: the red mask is the foot and the green one is the shoe, as shown below:

[Figure: two instances on the interaction frame]

Apply the predictor, whose state now includes both instances, to multi-frame segmentation across the video, automatically tracking both targets:

# run propagation throughout the video and collect the results in a dict
video_segments = {}  # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

# render the segmentation results every few frames
vis_frame_stride = 20
plt.close("all")
grid_size = (2, 5)
fig, axs = plt.subplots(grid_size[0], grid_size[1], figsize=(24, 7))
for idx, out_frame_idx in enumerate(range(0, len(frame_names), vis_frame_stride)):
    row = idx // grid_size[1]
    col = idx % grid_size[1]
    ax = axs[row, col]
    ax.set_title(f"frame {out_frame_idx}")
    ax.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, ax, obj_id=out_obj_id)
    ax.axis('on')

The video segmentation results for the two instances (shoe and foot) are shown below:

[Figure: multi-instance segmentation across video frames]
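
If the per-frame masks need to be consumed by other tools, they can also be exported as image files. A minimal sketch (reusing the video_segments dict built above; the output folder name and file naming scheme are just examples):

import os
import numpy as np
from PIL import Image

mask_output_dir = "output_masks"  # example output folder
os.makedirs(mask_output_dir, exist_ok=True)

for frame_idx, obj_masks in video_segments.items():
    for obj_id, mask in obj_masks.items():
        # each mask has shape (1, H, W); squeeze to (H, W) and scale booleans to 0/255
        mask_img = np.squeeze(mask).astype(np.uint8) * 255
        out_name = f"{frame_idx:05d}_obj{obj_id}.png"
        Image.fromarray(mask_img).save(os.path.join(mask_output_dir, out_name))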

5. Miscellaneous

Online video editing site: https://online-video-cutter.com/cn/

Video preprocessing script. Note that the codec must be set to H264, otherwise playback will be abnormal:

  • fourcc = cv2.VideoWriter_fourcc(*'H264')
import os

import cv2
import numpy as np


def process_video(input_video_path, output_video_path, frame_output_dir, start_time, frame_count, frame_interval):
    # 创建输出文件夹
    if not os.path.exists(frame_output_dir):
        os.makedirs(frame_output_dir)

    # open the video file
    cap = cv2.VideoCapture(input_video_path)
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    start_frame = start_time * fps

    # set the starting frame
    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print("Total frames: ", total_frames, fps)

    # get the video width and height
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # define the video codec and the output file
    fourcc = cv2.VideoWriter_fourcc(*'H264')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

    extracted_frames = 0
    current_frame = 0

    while extracted_frames < frame_count:
        ret, frame = cap.read()
        if not ret:
            break

        if current_frame % frame_interval == 0:
            # save each extracted frame to the folder
            frame_filename = os.path.join(frame_output_dir, f'{extracted_frames + 1:05d}.jpg')
            cv2.imwrite(frame_filename, frame)
            frame = frame.astype(np.uint8)

            # write the frame to the new video file
            out.write(frame)
            extracted_frames += 1

        current_frame += 1

    # release resources
    cap.release()
    out.release()
    cv2.destroyAllWindows()





def main():
    # example usage
    input_video_path = 'video.mp4'
    output_video_path = 'output_video.mp4'
    frame_output_dir = 'output_frames'
    start_time = 3  # start from the 3rd second
    frame_count = 200  # extract 200 frames
    frame_interval = 4  # take 1 frame every 4 frames

    process_video(input_video_path, output_video_path, frame_output_dir, start_time, frame_count, frame_interval)


if __name__ == '__main__':
    main()

Reference: Python cv2 .mp4 codec unable to be displayed in browser
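
If the local OpenCV build does not ship an H.264 encoder, a common workaround is to write the video with whatever codec is available and then re-encode it for browser playback. A minimal sketch (assumes ffmpeg is installed on the system; the filenames are just examples):

import subprocess

# re-encode the cv2 output with the H.264 codec so that browsers can play it
subprocess.run([
    "ffmpeg", "-y",
    "-i", "output_video.mp4",   # video written by cv2.VideoWriter
    "-c:v", "libx264",          # H.264 encoder
    "-pix_fmt", "yuv420p",      # widely supported pixel format
    "output_video_h264.mp4",
], check=True)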

