
Vision - Grounded SAM2, an Open-Source Visual Segmentation Framework: Video Inference Tutorial (2)

Welcome to follow my CSDN blog: https://spike.blog.csdn.net/
Original article: https://spike.blog.csdn.net/article/details/143388189

Disclaimer: This article is based on personal knowledge and publicly available materials, and is intended for academic exchange only. Discussion is welcome; reposting is not permitted.


Grounded SAM

Grounded SAM2 is a vision AI framework that integrates several state-of-the-art models, combining GroundingDINO, Florence-2, SAM2, and others to achieve breakthrough progress on multiple visual tasks such as open-vocabulary object detection, segmentation, and tracking. It locates targets in an image from natural-language descriptions, generates fine-grained segmentation masks, and continuously tracks the targets across a video sequence while keeping their IDs consistent.

Paper: Grounded SAM: Assembling Open-World Models for Diverse Visual Tasks; the SAM component has been upgraded from version 1.0 to 2.0.

For environment setup, refer to: Grounded SAM2, an Open-Source Visual Segmentation Framework: Configuration and Inference Tutorial.


Import the required Python packages; the core libraries include torch, sam2, and grounding_dino:

import os
import cv2
import torch
import numpy as np
import supervision as sv
from torchvision.ops import box_convert
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor, build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor 
from grounding_dino.groundingdino.util.inference import load_model, load_image, predict
from utils.track_utils import sample_points_from_masks
from utils.video_utils import create_video_from_images

import matplotlib.pyplot as plt

Configure the models and parameters of the pipeline, where:

  • GroundingDINO model: use the larger SwinB model, which clearly outperforms the SwinT model; remember to update the config file accordingly.
  • Text prompt (TEXT_PROMPT): "shoes. legs.", i.e., shoes and legs.
  • Prompt type for the video predictor (PROMPT_TYPE_FOR_VIDEO): box.

That is:

"""
Hyperparam for Ground and Tracking
"""
GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinB_cfg.py"
GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swinb_cogcoor.pth"
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25
VIDEO_PATH = "/nfs_beijing_ai/chenlong/llm/vision_test_data/video5_shoes.mp4"
TEXT_PROMPT = "shoes. legs."
OUTPUT_VIDEO_PATH = "./video5_shoes_demo.mp4"
SOURCE_VIDEO_FRAME_DIR = "./video5_custom_video_frames"
SAVE_TRACKING_RESULTS_DIR = "./video5_tracking_results"
PROMPT_TYPE_FOR_VIDEO = "box" # choose from ["point", "box", "mask"]
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[Info] DEVICE: {DEVICE}")

The main difference between groundingdino_swinb_cogcoor.pth and groundingdino_swint_ogc.pth lies in size and performance. Both are based on the Swin Transformer backbone, but they differ in parameter count and capacity:

  1. groundingdino_swinb_cogcoor.pth is the SwinB variant of GroundingDINO, with a file size of 938 MB. "SwinB" stands for "Swin Transformer Base". It handles complex tasks better and offers higher accuracy, at the cost of somewhat slower inference.
  2. groundingdino_swint_ogc.pth is the SwinT variant of GroundingDINO, with a file size of 694 MB. "SwinT" stands for "Swin Transformer Tiny", a lighter backbone. Compared with SwinB, it is more lightweight and better suited to latency-sensitive scenarios, with a slight sacrifice in accuracy.
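Switching between the two variants only requires changing the config and checkpoint paths. A minimal sketch; the SwinT config file name is an assumption based on the GroundingDINO repository layout:

# SwinB (larger, more accurate; used in this tutorial)
GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinB_cfg.py"
GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swinb_cogcoor.pth"

# SwinT (lighter, faster; config path assumed to follow the repository layout)
# GROUNDING_DINO_CONFIG = "grounding_dino/groundingdino/config/GroundingDINO_SwinT_OGC.py"
# GROUNDING_DINO_CHECKPOINT = "gdino_checkpoints/groundingdino_swint_ogc.pth"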

Configure the GroundingDINO model and the SAM2 models:

"""
Step 1: Environment settings and model initialization for Grounding DINO and SAM 2
"""
# https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha2/groundingdino_swinb_cogcoor.pth
# https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
# build grounding dino model from local path
grounding_model = load_model(
    model_config_path=GROUNDING_DINO_CONFIG, 
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT,
    device=DEVICE
)

# init sam image predictor and video predictor model
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

video_predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
sam2_image_model = build_sam2(model_cfg, sam2_checkpoint)
image_predictor = SAM2ImagePredictor(sam2_image_model)

Parse the input data. The video info includes width, height, fps (frame rate), and total_frames (frame count). Create the frame directory and write out the frames; 200 frames are written in total:

  • import supervision as sv: the supervision library is used for video processing.
"""
Custom video input directly using video files
"""
video_info = sv.VideoInfo.from_video_path(VIDEO_PATH)  # get video info
# VideoInfo(width=1280, height=720, fps=30, total_frames=200)
print(video_info)
frame_generator = sv.get_video_frames_generator(VIDEO_PATH, stride=1, start=0, end=None)

# saving video to frames
source_frames = Path(SOURCE_VIDEO_FRAME_DIR)
source_frames.mkdir(parents=True, exist_ok=True)

with sv.ImageSink(
    target_dir_path=source_frames, 
    overwrite=True, 
    image_name_pattern="{:05d}.jpg"
) as sink:
    for frame in tqdm(frame_generator, desc="Saving Video Frames"):
        sink.save_image(frame)
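As a quick check (a small sketch, not part of the original script), the number of saved frames can be compared against video_info.total_frames:

# count the JPEG frames written by the ImageSink
saved_frames = sorted(source_frames.glob("*.jpg"))
print(f"[Info] saved frames: {len(saved_frames)} / {video_info.total_frames}")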

Display multiple frames of the video, given the image directory (image_dir) and the number of files (total_images):

import os
import matplotlib.pyplot as plt
from PIL import Image

def show_images_in_grid(image_dir, total_images):
    # get the list of image files
    image_files = sorted(os.listdir(image_dir))[:total_images]
    
    if total_images < 8:
        print("Fewer than 8 images in total; make sure the folder contains at least 8 images.")
        return

    num = 8
    # pick 8 evenly spaced images to display
    gap = total_images // num
    
    # set up the 2x4 image grid layout
    fig, axs = plt.subplots(2, 4, figsize=(15, 6))
    
    for idx in range(num):
        image_file = image_files[gap*idx]
        # compute the row and column position
        row = idx // 4
        col = idx % 4
        
        # open the image and display it
        image_path = os.path.join(image_dir, image_file)
        image = Image.open(image_path)
        
        axs[row, col].imshow(image)
        axs[row, col].set_title(f"Image {idx + 1}")
        axs[row, col].axis('on')

    plt.tight_layout()
    plt.show()

# example usage
image_dir = source_frames
total_images = video_info.total_frames  # the number of frames saved above (200 in this example)
show_images_in_grid(image_dir, total_images)

The input video frame collection:

(Figure: sampled input video frames)

The SAM2 VideoPredictor takes the paths of the video frames as input:

# scan all the JPEG frame names in this directory
frame_names = [
    p for p in os.listdir(SOURCE_VIDEO_FRAME_DIR)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

# init video predictor state
inference_state = video_predictor.init_state(video_path=SOURCE_VIDEO_FRAME_DIR)
ann_frame_idx = 0  # the frame index we interact with

Use GroundingDINO to detect the target categories; note that the model runs in torch.float32:

"""
Step 2: Prompt Grounding DINO for box coordinates on a specific frame
"""
# prompt grounding dino to get the box coordinates on specific frame
img_path = os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[ann_frame_idx])
image_source, image = load_image(img_path)

# FIXME: figure out how this influences the G-DINO model
torch.autocast(device_type=DEVICE, dtype=torch.float32).__enter__()

print(f"[Info] TEXT_PROMPT: {TEXT_PROMPT}")
boxes, confidences, labels = predict(
    model=grounding_model,
    image=image,
    caption=TEXT_PROMPT,
    box_threshold=BOX_THRESHOLD,
    text_threshold=TEXT_THRESHOLD,
)

# process the box prompt for SAM 2
h, w, _ = image_source.shape
boxes = boxes * torch.Tensor([w, h, w, h])
input_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
confidences = confidences.numpy().tolist()
class_names = labels

print(input_boxes)
# process the detection results
OBJECTS = class_names

print(f"[Info] OBJECTS: {OBJECTS}")

Output:

[Info] TEXT_PROMPT: shoes. legs.
[[2.9252551e+02 4.6411133e-01 1.0295570e+03 5.6271893e+02]
 [2.9232632e+02 1.8811917e+02 9.9246057e+02 5.6478638e+02]]
[Info] OBJECTS: ['legs', 'shoes']
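For a quick inspection of the detections, each label can be paired with its confidence score. A small sketch using the variables defined above (not part of the original script):

# confidences is a plain Python list after .numpy().tolist()
for label, conf in zip(labels, confidences):
    print(f"[Info] detected '{label}' with confidence {conf:.3f}")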

Display the annotated box results:

class_ids = np.array(list(range(len(class_names))))
img = cv2.imread(img_path)
detections = sv.Detections(
    xyxy=input_boxes,  # (n, 4)
    class_id=class_ids
)

box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)

label_annotator = sv.LabelAnnotator()
annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
plt.figure(figsize=(9, 6))
plt.title(f"annotated_frame")
plt.imshow(annotated_frame[:,:,::-1])

The GroundingDINO bounding-box (BBox) result on frame 1 is shown below:

(Figure: GroundingDINO bounding boxes on frame 1)

The SAM2 segmentation result; note that the model runs in torch.bfloat16:

image_predictor.set_image(image_source)
# FIXME: figure out how this influences the G-DINO model
torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()

if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# prompt SAM 2 image predictor to get the mask for the object
masks, scores, logits = image_predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_boxes,
    multimask_output=False,
)
# convert the mask shape to (n, H, W)
if masks.ndim == 4:
    masks = masks.squeeze(1)
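A small sanity check on the shapes before visualization (a sketch; n is the number of detected objects, H and W the frame size):

print(f"[Info] input_boxes shape: {input_boxes.shape}")  # (n, 4)
print(f"[Info] masks shape: {masks.shape}")              # (n, H, W)
print(f"[Info] OBJECTS: {OBJECTS}")                      # n class names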

Display the boxes and the segmented pixels:

img = cv2.imread(img_path)
detections = sv.Detections(
    xyxy=input_boxes,  # (n, 4)
    mask=masks.astype(bool),  # (n, h, w)
    class_id=class_ids
)

box_annotator = sv.BoxAnnotator()
annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)

mask_annotator = sv.MaskAnnotator()
annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
plt.figure(figsize=(9, 6))
plt.title(f"annotated_frame")
plt.imshow(annotated_frame[:,:,::-1])

The GroundingDINO + SAM2 box and segmentation result on frame 1 is shown below:

(Figure: boxes and segmentation masks on frame 1)

Register the target objects with the video predictor; this example uses the box prompt type:

"""
Step 3: Register each object's positive points to video predictor with separate add_new_points call
"""

assert PROMPT_TYPE_FOR_VIDEO in ["point", "box", "mask"], "SAM 2 video predictor only supports point/box/mask prompts"

# If you are using point prompts, we uniformly sample positive points based on the mask
if PROMPT_TYPE_FOR_VIDEO == "point":
    # sample the positive points from the mask for each object
    all_sample_points = sample_points_from_masks(masks=masks, num_points=10)

    for object_id, (label, points) in enumerate(zip(OBJECTS, all_sample_points), start=1):
        labels = np.ones((points.shape[0]), dtype=np.int32)
        _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=ann_frame_idx,
            obj_id=object_id,
            points=points,
            labels=labels,
        )
# Using box prompt
elif PROMPT_TYPE_FOR_VIDEO == "box":
    for object_id, (label, box) in enumerate(zip(OBJECTS, input_boxes), start=1):
        _, out_obj_ids, out_mask_logits = video_predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=ann_frame_idx,
            obj_id=object_id,
            box=box,
        )
# Using mask prompt is a more straightforward way
elif PROMPT_TYPE_FOR_VIDEO == "mask":
    for object_id, (label, mask) in enumerate(zip(OBJECTS, masks), start=1):
        labels = np.ones((1), dtype=np.int32)
        _, out_obj_ids, out_mask_logits = video_predictor.add_new_mask(
            inference_state=inference_state,
            frame_idx=ann_frame_idx,
            obj_id=object_id,
            mask=mask
        )
else:
    raise NotImplementedError("SAM 2 video predictor only supports point/box/mask prompts")

Apply the target regions to all video frames by calling video_predictor.propagate_in_video():

"""
Step 4: Propagate the video predictor to get the segmentation results for each frame
"""
video_segments = {}  # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }
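To verify the propagation, the per-frame results can be summarized, for example by counting frames in which an object has an empty mask. A sketch using only video_segments (not part of the original script):

print(f"[Info] segmented frames: {len(video_segments)}")
# count, for each object id, the frames where its mask is empty
empty_counts = {}
for frame_idx, segments in video_segments.items():
    for obj_id, mask in segments.items():
        if mask.sum() == 0:
            empty_counts[obj_id] = empty_counts.get(obj_id, 0) + 1
print(f"[Info] frames with empty masks per object id: {empty_counts}")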

Write the boxes and target regions into all video frames:

  • The display color in supervision is tied to the class_id parameter; changing class_id changes the color. A custom palette sketch is shown after the code below.
  • For the color scheme, see supervision.draw.color.ColorPalette.
"""
Step 5: Visualize the segment results across the video and save them
"""

if not os.path.exists(SAVE_TRACKING_RESULTS_DIR):
    os.makedirs(SAVE_TRACKING_RESULTS_DIR)

ID_TO_OBJECTS = {i: obj for i, obj in enumerate(OBJECTS, start=1)}

for frame_idx, segments in video_segments.items():
    img = cv2.imread(os.path.join(SOURCE_VIDEO_FRAME_DIR, frame_names[frame_idx]))
    
    object_ids = list(segments.keys())
    masks = list(segments.values())
    masks = np.concatenate(masks, axis=0)
    
    detections = sv.Detections(
        xyxy=sv.mask_to_xyxy(masks),  # (n, 4)
        mask=masks, # (n, h, w)
        class_id=np.array(object_ids, dtype=np.int32)-1,
    )
    box_annotator = sv.BoxAnnotator()
    annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections)
    label_annotator = sv.LabelAnnotator()
    annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
    mask_annotator = sv.MaskAnnotator()
    annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
    cv2.imwrite(os.path.join(SAVE_TRACKING_RESULTS_DIR, f"annotated_frame_{frame_idx:05d}.jpg"), annotated_frame)
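If custom colors are preferred over the defaults, the annotators can be given an explicit palette. A sketch, assuming the installed supervision version exposes ColorPalette.from_hex and the color parameter on the annotators:

# colors are picked per detection according to class_id
custom_palette = sv.ColorPalette.from_hex(["#ff595e", "#1982c4", "#8ac926"])
box_annotator = sv.BoxAnnotator(color=custom_palette)
label_annotator = sv.LabelAnnotator(color=custom_palette)
mask_annotator = sv.MaskAnnotator(color=custom_palette)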

The output result; note that in one of the frames the shoes are not segmented:

(Figure: tracking results across video frames)

The video-writing logic; note that the codec is H264, which supports playback in browsers:

  • If your OpenCV build does not support H264, install opencv-python via conda instead of pip, which generally provides better H264 encoder support. A codec-fallback sketch is shown after the code below.
"""
Step 6: Convert the annotated frames to video
"""
import cv2
import os
from tqdm import tqdm

def create_video_from_images(image_folder, output_video_path, frame_rate=5):
    # define valid extension
    valid_extensions = [".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"]
    
    # get all image files in the folder
    image_files = [f for f in os.listdir(image_folder) 
                   if os.path.splitext(f)[1] in valid_extensions]
    image_files.sort()  # sort the files in alphabetical order
    print(image_files)
    if not image_files:
        raise ValueError("No valid image files found in the specified folder.")
    
    # load the first image to get the dimensions of the video
    first_image_path = os.path.join(image_folder, image_files[0])
    first_image = cv2.imread(first_image_path)
    height, width, _ = first_image.shape
    
    # create a video writer
    # fourcc = cv2.VideoWriter_fourcc(*'mp4v') # codec for saving the video
    fourcc = cv2.VideoWriter_fourcc(*'H264')
    video_writer = cv2.VideoWriter(output_video_path, fourcc, frame_rate, (width, height))
    
    # write each image to the video
    for image_file in tqdm(image_files):
        image_path = os.path.join(image_folder, image_file)
        image = cv2.imread(image_path)
        video_writer.write(image)
    
    # source release
    video_writer.release()
    print(f"Video saved at {output_video_path}")

create_video_from_images(SAVE_TRACKING_RESULTS_DIR, OUTPUT_VIDEO_PATH)
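If the H264 writer cannot be opened with the current OpenCV build, a pragmatic fallback is to try several codecs in order. A sketch with a hypothetical helper open_video_writer (not part of the original script); 'mp4v' is widely available but less browser-friendly:

def open_video_writer(output_path, frame_rate, size):
    # try codecs in order of preference and return the first writer that opens
    for codec in ("H264", "avc1", "mp4v"):
        fourcc = cv2.VideoWriter_fourcc(*codec)
        writer = cv2.VideoWriter(output_path, fourcc, frame_rate, size)
        if writer.isOpened():
            print(f"[Info] using codec: {codec}")
            return writer
    raise RuntimeError("No usable codec found for cv2.VideoWriter.")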

