当前位置: 首页 > article >正文

Segment Anything C++ 项目【Part2:修改源码+自动分割】

目录

写在前面

原始代码

test.cpp

代码理解

整理理解

参数理解

修改代码

再次编译


写在前面

为啥要写这篇文章?在上一篇博客中Segment Anything C++ 项目【Part1:本地跑通+全网最详细】-CSDN博客,我们已经实现了:跑通原始代码,生成exe,手动点击,才会给出运行出结果。但是,对于源码,我们并没有解析,因为有时候,我们确实需要修改代码,以适应我们自己的项目,比如,我想自动分割,怎么办呢?让我们一起来看看吧!

原始代码

原始代码主要有两部分test.cpp和sam.cpp。

test.cpp

我们先来看第一部分test.cpp:

//test.cpp代码如下:
#include <atomic>
#include <opencv2/opencv.hpp>
#include <thread>

#define STRIP_FLAG_HELP 1
#include <gflags/gflags.h>

#include "sam.h"

DEFINE_string(pre_model, "models/sam_preprocess.onnx", "Path to the preprocessing model");
DEFINE_string(sam_model, "models/sam_vit_h_4b8939.onnx", "Path to the sam model");
DEFINE_string(image, "images/input.jpg", "Path to the image to segment");
DEFINE_string(pre_device, "cpu", "cpu or cuda:0(1,2,3...)");
DEFINE_string(sam_device, "cpu", "cpu or cuda:0(1,2,3...)");
DEFINE_bool(h, false, "Show help");

bool parseDeviceName(const std::string& name, Sam::Parameter::Provider& provider) {
  if (name == "cpu") {
    provider.deviceType = 0;
    return true;
  }
  if (name.substr(0, 5) == "cuda:") {
    provider.deviceType = 1;
    provider.gpuDeviceId = std::stoi(name.substr(5));
    return true;
  }
  return false;
}

int main(int argc, char** argv) {
  gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
  if (FLAGS_h) {
    std::cout << "Example: ./sam_cpp_test -pre_model=\"models/sam_preprocess.onnx\" "
                 "-sam_model=\"models/sam_vit_h_4b8939.onnx\" "
                 "-image=\"images/input.jpg\" -pre_device=\"cpu\" -sam_device=\"cpu\""
              << std::endl;
    return 0;
  }

  std::cout << "Preprocess device: " << FLAGS_pre_device << "; Sam device: " << FLAGS_sam_device
            << std::endl;

  Sam::Parameter param(FLAGS_pre_model, FLAGS_sam_model, std::thread::hardware_concurrency());
  if (!parseDeviceName(FLAGS_pre_device, param.providers[0]) ||
      !parseDeviceName(FLAGS_sam_device, param.providers[1])) {
    std::cerr << "Unable to parse device name" << std::endl;
  }

  std::cout << "Loading model..." << std::endl;
  Sam sam(param);  // FLAGS_pre_model, FLAGS_sam_model, std::thread::hardware_concurrency());

  auto inputSize = sam.getInputSize();
  if (inputSize.empty()) {
    std::cout << "Sam initialization failed" << std::endl;
    return -1;
  }

  cv::Mat image = cv::imread(FLAGS_image, -1);
  if (image.empty()) {
    std::cout << "Image loading failed" << std::endl;
    return -1;
  }
  std::cout << "Resize image to " << inputSize << std::endl;
  cv::resize(image, image, inputSize);
  std::cout << "Loading image..." << std::endl;
  if (!sam.loadImage(image)) {
    std::cout << "Image loading failed" << std::endl;
    return -1;
  }

  std::cout << "Now click on the image (press q/esc to quit; press c to clear selection; press a "
               "to run automatic segmentation)\n"
            << "Ctrl+Left click to select foreground, Ctrl+Right click to select background, "
            << "Middle click and drag to select a region\n";

  std::list<cv::Point3i> clickedPoints;
  cv::Point3i newClickedPoint(-1, 0, 0);
  cv::Rect roi;
  cv::Mat outImage = image.clone();

  auto g_windowName = "Segment Anything CPP Demo";
  cv::namedWindow(g_windowName, 0);
  cv::setMouseCallback(
      g_windowName,
      [](int event, int x, int y, int flags, void* userdata) {
        int code = -1;
        if (event == cv::EVENT_LBUTTONDOWN) {
          code = 2;
        } else if (event == cv::EVENT_RBUTTONDOWN) {
          code = 0;
        } else if (event == cv::EVENT_MBUTTONDOWN ||
                   ((flags & cv::EVENT_FLAG_MBUTTON) && event == cv::EVENT_MOUSEMOVE)) {
          code = 4;
        } else if (event == cv::EVENT_MBUTTONUP) {
          code = 5;
        }

        if (code >= 0) {
          if (code <= 2 && (flags & cv::EVENT_FLAG_CTRLKEY) == cv::EVENT_FLAG_CTRLKEY) {
            // If ctrl is pressed, then append it to the list later
            code += 1;
          }
          *(cv::Point3i*)userdata = {x, y, code};
        }
      },
      &newClickedPoint);

#define SHOW_TIME                                                     \
  std::cout << "Time elapsed: "                                       \
            << std::chrono::duration_cast<std::chrono::milliseconds>( \
                   std::chrono::system_clock::now() - timeNow)        \
                   .count()                                           \
            << " ms" << std::endl;

  bool bRunning = true;
  while (bRunning) {
    const auto timeNow = std::chrono::system_clock::now();

    if (newClickedPoint.x > 0) {
      std::list<cv::Point> points, nagativePoints;
      if (newClickedPoint.z == 5) {
        roi = {};
      } else if (newClickedPoint.z == 4) {
        if (roi.empty()) {
          roi = cv::Rect(newClickedPoint.x, newClickedPoint.y, 1, 1);
        } else {
          auto tl = roi.tl(), np = cv::Point(newClickedPoint.x, newClickedPoint.y);
          // construct a rectangle from two points
          roi = cv::Rect(cv::Point(std::min(tl.x, np.x), std::min(tl.y, np.y)),
                         cv::Point(std::max(tl.x, np.x), std::max(tl.y, np.y)));
          std::cout << "Box: " << roi << std::endl;
        }
      } else {
        if (newClickedPoint.z % 2 == 0) {
          clickedPoints = {newClickedPoint};
        } else {
          clickedPoints.push_back(newClickedPoint);
        }
      }

      for (auto& p : clickedPoints) {
        if (p.z >= 2) {
          points.push_back({p.x, p.y});
        } else {
          nagativePoints.push_back({p.x, p.y});
        }
      }

      newClickedPoint.x = -1;
      if (points.empty() && nagativePoints.empty() && roi.empty()) {
        continue;
      }

      cv::Mat mask = sam.getMask(points, nagativePoints, roi);
      SHOW_TIME

      // apply mask to image
      outImage = cv::Mat::zeros(image.size(), CV_8UC3);
      for (int i = 0; i < image.rows; i++) {
        for (int j = 0; j < image.cols; j++) {
          auto bFront = mask.at<uchar>(i, j) > 0;
          float factor = bFront ? 1.0 : 0.2;
          outImage.at<cv::Vec3b>(i, j) = image.at<cv::Vec3b>(i, j) * factor;
        }
      }

      for (auto& p : points) {
        cv::circle(outImage, p, 2, {0, 255, 255}, -1);
      }
      for (auto& p : nagativePoints) {
        cv::circle(outImage, p, 2, {255, 0, 0}, -1);
      }
    } else if (newClickedPoint.x == -2) {
      newClickedPoint.x = -1;
      int step = 40;
      cv::Size sampleSize = {image.cols / step, image.rows / step};

      std::cout << "Automatically generating masks with " << sampleSize.area()
                << " input points ..." << std::endl;

      auto mask = sam.autoSegment(
          sampleSize, [](double v) { std::cout << "\rProgress: " << int(v * 100) << "%\t"; });
      SHOW_TIME

      const double overlayFactor = 0.5;
      const int maxMaskValue = 255 * (1 - overlayFactor);
      outImage = cv::Mat::zeros(image.size(), CV_8UC3);

      static std::map<int, cv::Vec3b> colors;

      for (int i = 0; i < image.rows; i++) {
        for (int j = 0; j < image.cols; j++) {
          auto value = (int)mask.at<double>(i, j);
          if (value <= 0) {
            continue;
          }

          auto it = colors.find(value);
          if (it == colors.end()) {
            colors.insert(it, {value, cv::Vec3b(rand() % maxMaskValue, rand() % maxMaskValue,
                                                rand() % maxMaskValue)});
          }

          outImage.at<cv::Vec3b>(i, j) = it->second + image.at<cv::Vec3b>(i, j) * overlayFactor;
        }
      }

      // draw circles on the image to indicate the sample points
      for (int i = 0; i < sampleSize.height; i++) {
        for (int j = 0; j < sampleSize.width; j++) {
          cv::circle(outImage, {j * step, i * step}, 2, {0, 0, 255}, -1);
        }
      }
    }

    if (!roi.empty()) {
      cv::rectangle(outImage, roi, {255, 255, 255}, 2);
    }

    cv::imshow(g_windowName, outImage);
    int key = cv::waitKeyEx(100);
    switch (key) {
      case 27:
      case 'Q':
      case 'q': {
        bRunning = false;
      } break;
      case 'C':
      case 'c': {
        clickedPoints.clear();
        newClickedPoint.x = -1;
        roi = {};
        outImage = image.clone();
      } break;
      case 'A':
      case 'a': {
        clickedPoints.clear();
        newClickedPoint.x = -2;
        outImage = image.clone();
      }
    }
  }

  cv::destroyWindow(g_windowName);

  return 0;
}

代码理解

整理理解

这段代码是一个使用C++编写的命令行应用程序,它结合了OpenCV库、gflags库和一个名为SAM(Segment Anything Model)的模型,来实现图像分割功能。这个程序,允许用户:

通过鼠标与图像交互,指定前景点、背景点或选择区域,并根据这些输入,生成分割掩码。

我们逐步解析一下,代码的主要部分:

  • 包含头文件和定义宏
    • 包含了必要的头文件如<atomic><opencv2/opencv.hpp><thread>等。
    • 使用预处理器指令定义了一个宏STRIP_FLAG_HELP,这可能用于控制帮助信息的输出。
    • 包含了gflags库以处理命令行参数。
    • 包含了一个自定义的sam.h头文件,其中包含了与SAM模型相关的类和函数声明。
  • 定义命令行参数
    • 使用DEFINE_stringDEFINE_bool宏定义了一系列命令行参数,例如预处理模型路径、SAM模型路径、输入图像路径以及设备名称(CPU或CUDA设备)。
  • 解析设备名称
    • parseDeviceName函数负责将字符串形式的设备名称转换为可以被SAM参数对象理解的形式。
  • 主函数main
    • 解析命令行参数,并检查是否需要显示帮助信息。
    • 打印出所选的预处理,和SAM模型使用的设备。
    • 初始化SAM参数,并尝试加载模型。
    • 加载并调整输入图像大小,以适应模型的输入尺寸要求。
    • 设置鼠标回调函数,以捕获用户的点击事件,并相应地更新图像上的标记点或者ROI(感兴趣区域)。
    • 进入一个循环,在该循环中程序会不断监听用户的键盘输入和鼠标操作,根据用户的选择执行不同的操作,比如清除所有点、运行自动分割等。
    • 在每次更新后,都会重新计算掩码并将其应用到原始图像上,然后在窗口中显示结果图像。
    • 根据用户的按键输入,来决定是退出程序还是重置当前状态。
  • 性能监控
    • 代码中定义了SHOW_TIME宏,用于测量并打印某些操作花费的时间。
  • 交互逻辑
    • 用户可以通过:
      • ①“Ctrl键+左击”来标记前景点,
      • ②“Ctrl键+右击”来标记背景点,
      • ③“中间按钮点击并拖动”来选择一个矩形区域。
    • 用户也可以按'A'键触发自动分割,按'C'键清除所有选择,按'Q'或ESC键退出程序。

参数理解

pre_model和sam_model,分别用于指定预处理模型。和分割模型的路径。

  • pre_model:这个参数指向预处理模型(preprocessing model),其路径被设置为 "models/sam_preprocess.onnx"。预处理模型通常用于在将图像输入到主要的分割模型之前对图像进行处理。这可能包括归一化、调整大小或其他必要的步骤,以确保输入数据符合模型的期望格式和范围。预处理模型通常是优化过的,以快速处理输入数据,并且可能在不同的设备上运行(如CPU或GPU)。
  • sam_model:这个参数指向分割模型(segmentation model),其路径被设置为 "models/sam_vit_h_4b8939.onnx"。这个模型是执行实际分割任务的核心模型,它接收预处理后的图像作为输入,并输出每个像素属于特定对象的概率图(mask)。这个模型可能使用复杂的深度学习架构,如Vision Transformer(ViT),来识别图像中的对象并生成精确的分割掩码。

简而言之,pre_model 用于图像的预处理,而 sam_model 用于实际的图像分割任务。两者在图像分割流程中扮演不同的角色,但都对获得高质量的分割结果至关重要。

修改代码

由于默认是需要鼠标互动点击才给结果,但是,我想直接给结果,怎么办呢?我们可以修改代码如下:

#include <opencv2/opencv.hpp>
#include <gflags/gflags.h>
#include "sam.h"

DEFINE_string(pre_model, "models/sam_preprocess.onnx", "Path to the preprocessing model");
DEFINE_string(sam_model, "models/mobile_sam.onnx", "Path to the sam model");
DEFINE_string(image, "images/input.jpg", "Path to the image to segment");
DEFINE_string(pre_device, "cpu", "cpu or cuda:0(1,2,3...)");
DEFINE_string(sam_device, "cpu", "cpu or cuda:0(1,2,3...)");
DEFINE_bool(h, false, "Show help");

bool parseDeviceName(const std::string& name, Sam::Parameter::Provider& provider) {
    if (name == "cpu") {
        provider.deviceType = 0;
        return true;
    }
    if (name.substr(0, 5) == "cuda:") {
        provider.deviceType = 1;
        provider.gpuDeviceId = std::stoi(name.substr(5));
        return true;
    }
    return false;
}

int main(int argc, char** argv) {
    gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
    if (FLAGS_h) {
        std::cout << "Example: ./sam_cpp_test -pre_model=\"models/sam_preprocess.onnx\" "
            "-sam_model=\"models/sam_vit_h_4b8939.onnx\" "
            "-image=\"images/input.jpg\" -pre_device=\"cpu\" -sam_device=\"cpu\""
            << std::endl;
        return 0;
    }

    Sam::Parameter param(FLAGS_pre_model, FLAGS_sam_model, std::thread::hardware_concurrency());
    if (!parseDeviceName(FLAGS_pre_device, param.providers[0]) ||
        !parseDeviceName(FLAGS_sam_device, param.providers[1])) {
        std::cerr << "Unable to parse device name" << std::endl;
        return -1;
    }

    Sam sam(param);

    auto inputSize = sam.getInputSize();
    if (inputSize.empty()) {
        std::cout << "Sam initialization failed" << std::endl;
        return -1;
    }

    cv::Mat image = cv::imread(FLAGS_image, -1);
    if (image.empty()) {
        std::cout << "Image loading failed" << std::endl;
        return -1;
    }

    cv::resize(image, image, inputSize);
    if (!sam.loadImage(image)) {
        std::cout << "Image loading failed" << std::endl;
        return -1;
    }

    // 自动分割
    cv::Size sampleSize = { 10, 10 }; // 设置采样点网格大小
    cv::Mat maskAuto = sam.autoSegment(sampleSize);

    // 将掩码应用到原始图像上,创建一个带颜色的输出图像
    cv::Mat outImage = cv::Mat::zeros(image.size(), CV_8UC3);
    static std::map<int, cv::Vec3b> colors;

    for (int i = 0; i < image.rows; i++) {
        for (int j = 0; j < image.cols; j++) {
            auto value = (int)maskAuto.at<double>(i, j);
            if (value <= 0) {
                continue;
            }

            auto it = colors.find(value);
            if (it == colors.end()) {
                colors[value] = cv::Vec3b(rand() % 256, rand() % 256, rand() % 256); // 为每个分割区域分配随机颜色
            }

            outImage.at<cv::Vec3b>(i, j) = it->second * 0.5 + image.at<cv::Vec3b>(i, j) * 0.5; // 混合原始图像和颜色掩码
        }
    }

    // 保存输出图像
    if (!cv::imwrite("output-auto.png", outImage)) {
        std::cout << "Failed to save the output image." << std::endl;
        return -1;
    }

    std::cout << "Automatic segmentation completed and saved as 'output-auto.png'." << std::endl;

    return 0;
}

再次编译

并运行exe,就会看到结果:

发布于 2025-01-02 10:45・IP 属地北京


http://www.kler.cn/a/466676.html

相关文章:

  • html中下拉选框的基本实现方式及JavaScript动态修改选项内容情况总结
  • 初学vue3心得
  • IoC设计模式详解:控制反转的核心思想
  • 【数据结构05】排序
  • 类的定义和使用(python)
  • python多张图片生成/合成gif
  • 我的博客年度之旅:感恩、成长与展望
  • MySQL 表结构在线变更:优雅地解决停机问题
  • 【Rust自学】10.2. 泛型
  • 医学AI公开课第二期|写给癌症研究者的人工智能指南|公开课·25-01-03
  • 论述数据、数据库、数据库管理系统、数据库系统的概念。
  • 利用矢量数据库增强大型语言模型应用
  • Leffa 虚拟试衣论文笔记
  • Unity 3D柱状图效果
  • 【Python】基于blind-watermark库添加图片盲水印
  • 【漏洞复现】用友U8 CRM downloadfile 任意文件读取漏洞复现
  • Dubbo扩展点加载机制
  • 庐山派K230学习日记1 从点灯到吃灰
  • mysql error:1071 -Specified key was too long; max key length is 767 bytes
  • 【深度学习】RNN循环神经网络的原理
  • Golang的代码质量分析工具
  • C# 设计模式(结构型模式):组合模式
  • 基于jQuery的图片浏览插件(1)
  • 探索新一代框架:基于ECS架构的轻量化Web开发
  • C# 设计模式(结构型模式):桥接模式
  • 2024年大型语言模型(LLMs)的发展回顾