当前位置: 首页 > article >正文

Vision - 视觉分割开源算法 SAM2(Segment Anything 2) 配置与推理 教程 (1)

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/143220597

免责声明:本文来源于个人知识与公开资料,仅用于学术交流,欢迎讨论,不支持转载。


SAM2

SAM2(Segment Anything Model 2) 视觉分割算法是 计算机视觉(CV) 的关键技术,精确的将图像中的不同对象区分开来,使用深度学习模型来分析图像中的像素分布,生成 Mask,标识出每个对象的边界。通过在多层裁剪和不同尺度下计算 Mask 的稳定性评分,确保结果的高精度和稳定性。

Paper: Segment Anything in Images and Videos

1. 环境配置

运行代码:

git clone https://github.com/facebookresearch/sam2.git

注意:项目文件比较大,可以直接使用 zip 包,或者使用 GitHub 代理。

构建环境:

conda create -n sam2 python=3.10
conda activate sam2

安装 PyTorch 包:

pip3 install torch torchvision torchaudio

python

import torch
print(torch.__version__)  # 2.5.0+cu124
print(torch.cuda.is_available())  # True
exit()

环境依赖:Python ≥ 3.10PyTorch ≥ 2.3.1

配置 CUDA 环境变量:

export CUDA_HOME=/usr/local/cuda  # change to your CUDA toolkit path
echo $CUDA_HOME

安装 SAM2 项目:

pip install --no-build-isolation -e .
pip install --no-build-isolation -e ".[notebooks]"  # 适配 Jupyter

--no-build-isolation 是禁用构建隔离,避免 CUDA 无法访问。

将 conda 导入 Jupyter 环境:

pip install ipykernel
python -m ipykernel install --user --name sam2 --display-name "sam2"

环境变量 export PYTORCH_ENABLE_MPS_FALLBACK=1,PyTorch 将会在遇到 MPS 不支持的操作时,自动切换到 CPU 处理。

2. 测试推理

导入 Python 包:

import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "4"  # 自定义使用的卡

配置 Torch 运行设备,与运行精度,即:

# select the device for computation
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "cuda":
    # use bfloat16 for the entire notebook
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    cuda_major = torch.cuda.get_device_properties(0).major
    print(f"[Info] cuda_major: {cuda_major}")
    if cuda_major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )

注意:

  • 启用 bfloat16 数据类型,自动混合精度计算,有助于提高模型训练的速度和效率,同时保持较高的精度。
    • torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
  • 获取第一个 CUDA 设备(通常是 GPU)的主要版本号,判断设备是否支持 Ampere 架构的 GPU (>=8 支持)
    • cuda_major = torch.cuda.get_device_properties(0).major
  • 如果检测 CUDA 设备的主要版本号 >= 8,即 Ampere 架构的 GPU,则启用 tfloat32 (TensorFloat-32),在 Ampere 设备上提高矩阵运算性能的优化技术。
    • torch.backends.cuda.matmul.allow_tf32 = True
    • torch.backends.cudnn.allow_tf32 = True

显示标注的 mask 信息 show_anns() 即:

np.random.seed(3)

def show_anns(anns, borders=True):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.5]])
        img[m] = color_mask 
        if borders:
            import cv2
            contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 
            # Try to smooth contours
            contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
            cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1) 

    ax.imshow(img)

读取图像,显示图像:

# image = Image.open('notebooks/images/cars.jpg')
image = Image.open('[your path]/llm/vision_test_data/image2.png')
image = np.array(image.convert("RGB"))
# image.shape (569, 1138, 3)

plt.figure(figsize=(20, 20))
plt.imshow(image)
plt.axis('on')  # 现在坐标
plt.show()

构建自动 Mask 生成器,使用默认参数,注意选择模型 sam2.1_hiera_large.pt ,以及配置参数 sam2.1_hiera_l.yaml,即:

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

sam2_checkpoint = "sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)

mask_generator = SAM2AutomaticMaskGenerator(sam2)

生成图像的 Mask:

masks = mask_generator.generate(image)
print(len(masks))
# 43
print(masks[0].keys())
# dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])

图像预测效果:

plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('on')
plt.show() 

默认效果:

Seg

自定义生成器:

mask_generator_2 = SAM2AutomaticMaskGenerator(
    model=sam2,
    points_per_side=64,
    points_per_batch=128,
    pred_iou_thresh=0.7,
    stability_score_thresh=0.92,
    stability_score_offset=0.7,
    crop_n_layers=1,
    box_nms_thresh=0.7,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=25.0,
    use_m2m=True,
)

参数说明:

  1. model (Sam):用于生成 mask 预测的 SAM2 模型。
  2. points_per_side (int or None):沿图像一侧采样的点数。总点数为 points_per_side**2。如果为 None,则需要 point_grids 提供显式点采样。
  3. points_per_batch (int):模型同时处理的点数。点数越多,速度越快,占用更多 GPU 内存。
  4. pred_iou_thresh (float):使用模型预测的 mask 质量的过滤阈值,范围在 [0,1]
  5. stability_score_thresh (float):使用 mask 在二值化过程中,变化的稳定性作为过滤阈值,范围在 [0,1]
  6. stability_score_offset (float):计算稳定性评分时用于调整 mask 的偏移量。
  7. box_nms_thresh (float):非极大值抑制中使用的 Box IoU 阈值,用于过滤重复 mask。
  8. crop_n_layers (int):如果 >0,会在图像裁剪后,再次运行 mask 预测。设置要运行的层数,每层有 2**i_layer 个图像裁剪。
  9. crop_n_points_downscale_factor (int):第 n 层采样的每边点数按 crop_n_points_downscale_factor**n 缩放。
  10. min_mask_region_area (int):如果 >0,后处理将移除面积小于 min_mask_region_area 的分离区域和 mask 中的孔洞。需要 OpenCV。
  11. use_m2m (bool):是否使用以前的 mask 预测进行一步优化,即在 mask 中,继续进行分割 mask。

运行图像分割:

masks2 = mask_generator_2.generate(image)
plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks2)
plt.axis('off')
plt.show() 

示例图像分割,更加细腻,即:

Seg


http://www.kler.cn/a/373507.html

相关文章:

  • ip命令网络配置详解
  • 链表:两数相加
  • 算法复杂度分析:深入剖析最好、最坏、平均、均摊时间复杂度
  • MySQL 复合索引测试
  • JavaFx -- chapter05(多用户服务器)
  • 北京迅为iTOP-LS2K0500开发板快速使用编译环境虚拟机Ubuntu基础操作及设置
  • ValueError: Object arrays cannot be loaded when allow_pickle=False
  • “换行”与“回车”
  • OpenCV 学习笔记
  • 同步和异步
  • AprilTag在相机标定中的应用简介
  • 20 Docker容器集群网络架构:三、Docker集群部署
  • window11使用wsl2安装Ubuntu22.04
  • Linux_04 Linux常用命令——tar
  • 深度学习(九):推荐系统的新引擎(9/10)
  • 【Java并发编程】信号量Semaphore详解
  • docker pull 拉取镜像失败,使用Docker离线包
  • 零基础学西班牙语,柯桥专业小语种培训泓畅学校
  • Si24R05:125K接收2.4G收发SoC芯片规格书
  • CSS行块标签的显示方式
  • 无人机之目标检测算法篇
  • 全自动采集、即时传输:RFID技术为BD数字化装备场尽力!
  • 嵌入式C语言字符串具体实现
  • linux离线安装Ollama并完成大模型配置(无网络)
  • 快速上手 Rust——实用示例
  • (五)Web前端开发进阶2——AJAX