当前位置: 首页 > article >正文

LLM - 视觉分割开源算法 SAM2(Segment Anything Model 2) 配置与推理 (1)

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/143220597

免责声明:本文来源于个人知识与公开资料,仅用于学术交流,欢迎讨论,不支持转载。


SAM2

SAM2(Segment Anything Model 2) 视觉分割算法是 计算机视觉(CV) 的关键技术,精确的将图像中的不同对象区分开来,使用深度学习模型来分析图像中的像素分布,生成 Mask,标识出每个对象的边界。通过在多层裁剪和不同尺度下计算 Mask 的稳定性评分,确保结果的高精度和稳定性。

Paper: Segment Anything in Images and Videos

1. 环境配置

运行代码:

git clone https://github.com/facebookresearch/sam2.git

注意:项目文件比较大,可以直接使用 zip 包,或者使用 GitHub 代理。

构建环境:

conda create -n sam2 python=3.10
conda activate sam2

安装 PyTorch 包:

pip3 install torch torchvision torchaudio

python

import torch
print(torch.__version__)  # 2.5.0+cu124
print(torch.cuda.is_available())  # True
exit()

环境依赖:Python ≥ 3.10PyTorch ≥ 2.3.1

配置 CUDA 环境变量:

export CUDA_HOME=/usr/local/cuda  # change to your CUDA toolkit path
echo $CUDA_HOME

安装 SAM2 项目:

pip install --no-build-isolation -e .
pip install --no-build-isolation -e ".[notebooks]"  # 适配 Jupyter

--no-build-isolation 是禁用构建隔离,避免 CUDA 无法访问。

将 conda 导入 Jupyter 环境:

pip install ipykernel
python -m ipykernel install --user --name sam2 --display-name "sam2"

环境变量 export PYTORCH_ENABLE_MPS_FALLBACK=1,PyTorch 将会在遇到 MPS 不支持的操作时,自动切换到 CPU 处理。

2. 测试推理

导入 Python 包:

import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "4"  # 自定义使用的卡

配置 Torch 运行设备,与运行精度,即:

# select the device for computation
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "cuda":
    # use bfloat16 for the entire notebook
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    cuda_major = torch.cuda.get_device_properties(0).major
    print(f"[Info] cuda_major: {cuda_major}")
    if cuda_major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )

注意:

  • 启用 bfloat16 数据类型,自动混合精度计算,有助于提高模型训练的速度和效率,同时保持较高的精度。
    • torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
  • 获取第一个 CUDA 设备(通常是 GPU)的主要版本号,判断设备是否支持 Ampere 架构的 GPU (>=8 支持)
    • cuda_major = torch.cuda.get_device_properties(0).major
  • 如果检测 CUDA 设备的主要版本号 >= 8,即 Ampere 架构的 GPU,则启用 tfloat32 (TensorFloat-32),在 Ampere 设备上提高矩阵运算性能的优化技术。
    • torch.backends.cuda.matmul.allow_tf32 = True
    • torch.backends.cudnn.allow_tf32 = True

显示标注的 mask 信息 show_anns() 即:

np.random.seed(3)

def show_anns(anns, borders=True):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:, :, 3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.5]])
        img[m] = color_mask 
        if borders:
            import cv2
            contours, _ = cv2.findContours(m.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 
            # Try to smooth contours
            contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
            cv2.drawContours(img, contours, -1, (0, 0, 1, 0.4), thickness=1) 

    ax.imshow(img)

读取图像,显示图像:

# image = Image.open('notebooks/images/cars.jpg')
image = Image.open('[your path]/llm/vision_test_data/image2.png')
image = np.array(image.convert("RGB"))
# image.shape (569, 1138, 3)

plt.figure(figsize=(20, 20))
plt.imshow(image)
plt.axis('on')  # 现在坐标
plt.show()

构建自动 Mask 生成器,使用默认参数,注意选择模型 sam2.1_hiera_large.pt ,以及配置参数 sam2.1_hiera_l.yaml,即:

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

sam2_checkpoint = "sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)

mask_generator = SAM2AutomaticMaskGenerator(sam2)

生成图像的 Mask:

masks = mask_generator.generate(image)
print(len(masks))
# 43
print(masks[0].keys())
# dict_keys(['segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'])

图像预测效果:

plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('on')
plt.show() 

默认效果:

Seg

自定义生成器:

mask_generator_2 = SAM2AutomaticMaskGenerator(
    model=sam2,
    points_per_side=64,
    points_per_batch=128,
    pred_iou_thresh=0.7,
    stability_score_thresh=0.92,
    stability_score_offset=0.7,
    crop_n_layers=1,
    box_nms_thresh=0.7,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=25.0,
    use_m2m=True,
)

参数说明:

  1. model (Sam):用于生成 mask 预测的 SAM2 模型。
  2. points_per_side (int or None):沿图像一侧采样的点数。总点数为 points_per_side**2。如果为 None,则需要 point_grids 提供显式点采样。
  3. points_per_batch (int):模型同时处理的点数。点数越多,速度越快,占用更多 GPU 内存。
  4. pred_iou_thresh (float):使用模型预测的 mask 质量的过滤阈值,范围在 [0,1]
  5. stability_score_thresh (float):使用 mask 在二值化过程中,变化的稳定性作为过滤阈值,范围在 [0,1]
  6. stability_score_offset (float):计算稳定性评分时用于调整 mask 的偏移量。
  7. box_nms_thresh (float):非极大值抑制中使用的 Box IoU 阈值,用于过滤重复 mask。
  8. crop_n_layers (int):如果 >0,会在图像裁剪后,再次运行 mask 预测。设置要运行的层数,每层有 2**i_layer 个图像裁剪。
  9. crop_n_points_downscale_factor (int):第 n 层采样的每边点数按 crop_n_points_downscale_factor**n 缩放。
  10. min_mask_region_area (int):如果 >0,后处理将移除面积小于 min_mask_region_area 的分离区域和 mask 中的孔洞。需要 OpenCV。
  11. use_m2m (bool):是否使用以前的 mask 预测进行一步优化,即在 mask 中,继续进行分割 mask。

运行图像分割:

masks2 = mask_generator_2.generate(image)
plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks2)
plt.axis('off')
plt.show() 

示例图像分割,更加细腻,即:

Seg


http://www.kler.cn/news/364326.html

相关文章:

  • C++11 28-纯虚函数的默认实现 The default implementation of pure virtual functions
  • 《人工智能往事》—— 简而言之,AI 已经包围了我们。AI 就是我们。
  • ref属性的作用对象类型
  • 项目管理新趋势!2024年,Jira与禅道你更倾向谁?
  • 无人机之低空管控技术
  • 外包干了2个月,技术明显退步
  • Windows无法打开组策略 | Windows家庭版如何添加和打开组策略
  • JavaWeb开发全攻略:从零到精通,掌握核心技术与最佳实践,打造高性能Web应用!
  • 9月模拟手游下载量迎来激增,两款新游跻身全球下载榜前十!
  • 【有啥问啥】智能座舱中的ADDW认证是什么?
  • [蓝桥杯 2024 省 C] 回文数组
  • Go语言开发环境搭建
  • 《a16z : 2024 年加密货币现状报告》解析
  • 云计算与SaaS赋能的工业软件服务新形态
  • 第五十一章 安全元素的详细信息 - EncryptedKey 详情
  • 2-解决联想拯救者Y7000p在ubuntu20.04未找到wifi适配器,安装rtl8852ce网卡驱动问题
  • django报错问题Error 0x800B0109(CERT_E_UNTRUSTEDROOT)(已解决)
  • 政府办公人员常见的办公软件技能
  • Python作业
  • JavaScript 中四种常见的数据类型判断方法
  • SSCI/SCI/EI/Scopus/期刊合集,周期短,快速发表,见刊快!
  • 微服务之网关、网关路由、网关登录校验
  • MySQL数据库—多表查询
  • 数码管显示屏驱动高亮LED驱动芯片VK16K33A数码管控制电路
  • 电脑程序变化监控怎么设置?实时监控电脑程序变化的五大方法,手把手教会你!
  • qt QMainWindow详解