Segment Anything C++ 项目【Part2:修改源码+自动分割】
目录
写在前面
原始代码
test.cpp
代码理解
整理理解
参数理解
修改代码
再次编译
写在前面
为啥要写这篇文章?在上一篇博客中Segment Anything C++ 项目【Part1:本地跑通+全网最详细】-CSDN博客,我们已经实现了:跑通原始代码,生成exe,手动点击,才会给出运行出结果。但是,对于源码,我们并没有解析,因为有时候,我们确实需要修改代码,以适应我们自己的项目,比如,我想自动分割,怎么办呢?让我们一起来看看吧!
原始代码
原始代码主要有两部分test.cpp和sam.cpp。
test.cpp
我们先来看第一部分test.cpp:
//test.cpp代码如下:
#include <atomic>
#include <opencv2/opencv.hpp>
#include <thread>
#define STRIP_FLAG_HELP 1
#include <gflags/gflags.h>
#include "sam.h"
DEFINE_string(pre_model, "models/sam_preprocess.onnx", "Path to the preprocessing model");
DEFINE_string(sam_model, "models/sam_vit_h_4b8939.onnx", "Path to the sam model");
DEFINE_string(image, "images/input.jpg", "Path to the image to segment");
DEFINE_string(pre_device, "cpu", "cpu or cuda:0(1,2,3...)");
DEFINE_string(sam_device, "cpu", "cpu or cuda:0(1,2,3...)");
DEFINE_bool(h, false, "Show help");
bool parseDeviceName(const std::string& name, Sam::Parameter::Provider& provider) {
if (name == "cpu") {
provider.deviceType = 0;
return true;
}
if (name.substr(0, 5) == "cuda:") {
provider.deviceType = 1;
provider.gpuDeviceId = std::stoi(name.substr(5));
return true;
}
return false;
}
int main(int argc, char** argv) {
gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
if (FLAGS_h) {
std::cout << "Example: ./sam_cpp_test -pre_model=\"models/sam_preprocess.onnx\" "
"-sam_model=\"models/sam_vit_h_4b8939.onnx\" "
"-image=\"images/input.jpg\" -pre_device=\"cpu\" -sam_device=\"cpu\""
<< std::endl;
return 0;
}
std::cout << "Preprocess device: " << FLAGS_pre_device << "; Sam device: " << FLAGS_sam_device
<< std::endl;
Sam::Parameter param(FLAGS_pre_model, FLAGS_sam_model, std::thread::hardware_concurrency());
if (!parseDeviceName(FLAGS_pre_device, param.providers[0]) ||
!parseDeviceName(FLAGS_sam_device, param.providers[1])) {
std::cerr << "Unable to parse device name" << std::endl;
}
std::cout << "Loading model..." << std::endl;
Sam sam(param); // FLAGS_pre_model, FLAGS_sam_model, std::thread::hardware_concurrency());
auto inputSize = sam.getInputSize();
if (inputSize.empty()) {
std::cout << "Sam initialization failed" << std::endl;
return -1;
}
cv::Mat image = cv::imread(FLAGS_image, -1);
if (image.empty()) {
std::cout << "Image loading failed" << std::endl;
return -1;
}
std::cout << "Resize image to " << inputSize << std::endl;
cv::resize(image, image, inputSize);
std::cout << "Loading image..." << std::endl;
if (!sam.loadImage(image)) {
std::cout << "Image loading failed" << std::endl;
return -1;
}
std::cout << "Now click on the image (press q/esc to quit; press c to clear selection; press a "
"to run automatic segmentation)\n"
<< "Ctrl+Left click to select foreground, Ctrl+Right click to select background, "
<< "Middle click and drag to select a region\n";
std::list<cv::Point3i> clickedPoints;
cv::Point3i newClickedPoint(-1, 0, 0);
cv::Rect roi;
cv::Mat outImage = image.clone();
auto g_windowName = "Segment Anything CPP Demo";
cv::namedWindow(g_windowName, 0);
cv::setMouseCallback(
g_windowName,
[](int event, int x, int y, int flags, void* userdata) {
int code = -1;
if (event == cv::EVENT_LBUTTONDOWN) {
code = 2;
} else if (event == cv::EVENT_RBUTTONDOWN) {
code = 0;
} else if (event == cv::EVENT_MBUTTONDOWN ||
((flags & cv::EVENT_FLAG_MBUTTON) && event == cv::EVENT_MOUSEMOVE)) {
code = 4;
} else if (event == cv::EVENT_MBUTTONUP) {
code = 5;
}
if (code >= 0) {
if (code <= 2 && (flags & cv::EVENT_FLAG_CTRLKEY) == cv::EVENT_FLAG_CTRLKEY) {
// If ctrl is pressed, then append it to the list later
code += 1;
}
*(cv::Point3i*)userdata = {x, y, code};
}
},
&newClickedPoint);
#define SHOW_TIME \
std::cout << "Time elapsed: " \
<< std::chrono::duration_cast<std::chrono::milliseconds>( \
std::chrono::system_clock::now() - timeNow) \
.count() \
<< " ms" << std::endl;
bool bRunning = true;
while (bRunning) {
const auto timeNow = std::chrono::system_clock::now();
if (newClickedPoint.x > 0) {
std::list<cv::Point> points, nagativePoints;
if (newClickedPoint.z == 5) {
roi = {};
} else if (newClickedPoint.z == 4) {
if (roi.empty()) {
roi = cv::Rect(newClickedPoint.x, newClickedPoint.y, 1, 1);
} else {
auto tl = roi.tl(), np = cv::Point(newClickedPoint.x, newClickedPoint.y);
// construct a rectangle from two points
roi = cv::Rect(cv::Point(std::min(tl.x, np.x), std::min(tl.y, np.y)),
cv::Point(std::max(tl.x, np.x), std::max(tl.y, np.y)));
std::cout << "Box: " << roi << std::endl;
}
} else {
if (newClickedPoint.z % 2 == 0) {
clickedPoints = {newClickedPoint};
} else {
clickedPoints.push_back(newClickedPoint);
}
}
for (auto& p : clickedPoints) {
if (p.z >= 2) {
points.push_back({p.x, p.y});
} else {
nagativePoints.push_back({p.x, p.y});
}
}
newClickedPoint.x = -1;
if (points.empty() && nagativePoints.empty() && roi.empty()) {
continue;
}
cv::Mat mask = sam.getMask(points, nagativePoints, roi);
SHOW_TIME
// apply mask to image
outImage = cv::Mat::zeros(image.size(), CV_8UC3);
for (int i = 0; i < image.rows; i++) {
for (int j = 0; j < image.cols; j++) {
auto bFront = mask.at<uchar>(i, j) > 0;
float factor = bFront ? 1.0 : 0.2;
outImage.at<cv::Vec3b>(i, j) = image.at<cv::Vec3b>(i, j) * factor;
}
}
for (auto& p : points) {
cv::circle(outImage, p, 2, {0, 255, 255}, -1);
}
for (auto& p : nagativePoints) {
cv::circle(outImage, p, 2, {255, 0, 0}, -1);
}
} else if (newClickedPoint.x == -2) {
newClickedPoint.x = -1;
int step = 40;
cv::Size sampleSize = {image.cols / step, image.rows / step};
std::cout << "Automatically generating masks with " << sampleSize.area()
<< " input points ..." << std::endl;
auto mask = sam.autoSegment(
sampleSize, [](double v) { std::cout << "\rProgress: " << int(v * 100) << "%\t"; });
SHOW_TIME
const double overlayFactor = 0.5;
const int maxMaskValue = 255 * (1 - overlayFactor);
outImage = cv::Mat::zeros(image.size(), CV_8UC3);
static std::map<int, cv::Vec3b> colors;
for (int i = 0; i < image.rows; i++) {
for (int j = 0; j < image.cols; j++) {
auto value = (int)mask.at<double>(i, j);
if (value <= 0) {
continue;
}
auto it = colors.find(value);
if (it == colors.end()) {
colors.insert(it, {value, cv::Vec3b(rand() % maxMaskValue, rand() % maxMaskValue,
rand() % maxMaskValue)});
}
outImage.at<cv::Vec3b>(i, j) = it->second + image.at<cv::Vec3b>(i, j) * overlayFactor;
}
}
// draw circles on the image to indicate the sample points
for (int i = 0; i < sampleSize.height; i++) {
for (int j = 0; j < sampleSize.width; j++) {
cv::circle(outImage, {j * step, i * step}, 2, {0, 0, 255}, -1);
}
}
}
if (!roi.empty()) {
cv::rectangle(outImage, roi, {255, 255, 255}, 2);
}
cv::imshow(g_windowName, outImage);
int key = cv::waitKeyEx(100);
switch (key) {
case 27:
case 'Q':
case 'q': {
bRunning = false;
} break;
case 'C':
case 'c': {
clickedPoints.clear();
newClickedPoint.x = -1;
roi = {};
outImage = image.clone();
} break;
case 'A':
case 'a': {
clickedPoints.clear();
newClickedPoint.x = -2;
outImage = image.clone();
}
}
}
cv::destroyWindow(g_windowName);
return 0;
}
代码理解
整理理解
这段代码是一个使用C++编写的命令行应用程序,它结合了OpenCV库、gflags库和一个名为SAM(Segment Anything Model)的模型,来实现图像分割功能。这个程序,允许用户:
通过鼠标与图像交互,指定前景点、背景点或选择区域,并根据这些输入,生成分割掩码。
我们逐步解析一下,代码的主要部分:
- 包含头文件和定义宏:
- 包含了必要的头文件如
<atomic>
、<opencv2/opencv.hpp>
、<thread>
等。 - 使用预处理器指令定义了一个宏
STRIP_FLAG_HELP
,这可能用于控制帮助信息的输出。 - 包含了
gflags
库以处理命令行参数。 - 包含了一个自定义的
sam.h
头文件,其中包含了与SAM模型相关的类和函数声明。
- 包含了必要的头文件如
- 定义命令行参数:
- 使用
DEFINE_string
和DEFINE_bool
宏定义了一系列命令行参数,例如预处理模型路径、SAM模型路径、输入图像路径以及设备名称(CPU或CUDA设备)。
- 使用
- 解析设备名称:
parseDeviceName
函数负责将字符串形式的设备名称转换为可以被SAM参数对象理解的形式。
- 主函数
main
:- 解析命令行参数,并检查是否需要显示帮助信息。
- 打印出所选的预处理,和SAM模型使用的设备。
- 初始化SAM参数,并尝试加载模型。
- 加载并调整输入图像大小,以适应模型的输入尺寸要求。
- 设置鼠标回调函数,以捕获用户的点击事件,并相应地更新图像上的标记点或者ROI(感兴趣区域)。
- 进入一个循环,在该循环中程序会不断监听用户的键盘输入和鼠标操作,根据用户的选择执行不同的操作,比如清除所有点、运行自动分割等。
- 在每次更新后,都会重新计算掩码并将其应用到原始图像上,然后在窗口中显示结果图像。
- 根据用户的按键输入,来决定是退出程序还是重置当前状态。
- 性能监控:
- 代码中定义了
SHOW_TIME
宏,用于测量并打印某些操作花费的时间。
- 代码中定义了
- 交互逻辑:
- 用户可以通过:
- ①“Ctrl键+左击”来标记前景点,
- ②“Ctrl键+右击”来标记背景点,
- ③“中间按钮点击并拖动”来选择一个矩形区域。
- 用户也可以按'A'键触发自动分割,按'C'键清除所有选择,按'Q'或ESC键退出程序。
- 用户可以通过:
参数理解
pre_model和sam_model,分别用于指定预处理模型。和分割模型的路径。
pre_model
:这个参数指向预处理模型(preprocessing model),其路径被设置为"models/sam_preprocess.onnx"
。预处理模型通常用于在将图像输入到主要的分割模型之前对图像进行处理。这可能包括归一化、调整大小或其他必要的步骤,以确保输入数据符合模型的期望格式和范围。预处理模型通常是优化过的,以快速处理输入数据,并且可能在不同的设备上运行(如CPU或GPU)。sam_model
:这个参数指向分割模型(segmentation model),其路径被设置为"models/sam_vit_h_4b8939.onnx"
。这个模型是执行实际分割任务的核心模型,它接收预处理后的图像作为输入,并输出每个像素属于特定对象的概率图(mask)。这个模型可能使用复杂的深度学习架构,如Vision Transformer(ViT),来识别图像中的对象并生成精确的分割掩码。
简而言之,pre_model
用于图像的预处理,而 sam_model
用于实际的图像分割任务。两者在图像分割流程中扮演不同的角色,但都对获得高质量的分割结果至关重要。
修改代码
由于默认是需要鼠标互动点击才给结果,但是,我想直接给结果,怎么办呢?我们可以修改代码如下:
#include <opencv2/opencv.hpp>
#include <gflags/gflags.h>
#include "sam.h"
DEFINE_string(pre_model, "models/sam_preprocess.onnx", "Path to the preprocessing model");
DEFINE_string(sam_model, "models/mobile_sam.onnx", "Path to the sam model");
DEFINE_string(image, "images/input.jpg", "Path to the image to segment");
DEFINE_string(pre_device, "cpu", "cpu or cuda:0(1,2,3...)");
DEFINE_string(sam_device, "cpu", "cpu or cuda:0(1,2,3...)");
DEFINE_bool(h, false, "Show help");
bool parseDeviceName(const std::string& name, Sam::Parameter::Provider& provider) {
if (name == "cpu") {
provider.deviceType = 0;
return true;
}
if (name.substr(0, 5) == "cuda:") {
provider.deviceType = 1;
provider.gpuDeviceId = std::stoi(name.substr(5));
return true;
}
return false;
}
int main(int argc, char** argv) {
gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
if (FLAGS_h) {
std::cout << "Example: ./sam_cpp_test -pre_model=\"models/sam_preprocess.onnx\" "
"-sam_model=\"models/sam_vit_h_4b8939.onnx\" "
"-image=\"images/input.jpg\" -pre_device=\"cpu\" -sam_device=\"cpu\""
<< std::endl;
return 0;
}
Sam::Parameter param(FLAGS_pre_model, FLAGS_sam_model, std::thread::hardware_concurrency());
if (!parseDeviceName(FLAGS_pre_device, param.providers[0]) ||
!parseDeviceName(FLAGS_sam_device, param.providers[1])) {
std::cerr << "Unable to parse device name" << std::endl;
return -1;
}
Sam sam(param);
auto inputSize = sam.getInputSize();
if (inputSize.empty()) {
std::cout << "Sam initialization failed" << std::endl;
return -1;
}
cv::Mat image = cv::imread(FLAGS_image, -1);
if (image.empty()) {
std::cout << "Image loading failed" << std::endl;
return -1;
}
cv::resize(image, image, inputSize);
if (!sam.loadImage(image)) {
std::cout << "Image loading failed" << std::endl;
return -1;
}
// 自动分割
cv::Size sampleSize = { 10, 10 }; // 设置采样点网格大小
cv::Mat maskAuto = sam.autoSegment(sampleSize);
// 将掩码应用到原始图像上,创建一个带颜色的输出图像
cv::Mat outImage = cv::Mat::zeros(image.size(), CV_8UC3);
static std::map<int, cv::Vec3b> colors;
for (int i = 0; i < image.rows; i++) {
for (int j = 0; j < image.cols; j++) {
auto value = (int)maskAuto.at<double>(i, j);
if (value <= 0) {
continue;
}
auto it = colors.find(value);
if (it == colors.end()) {
colors[value] = cv::Vec3b(rand() % 256, rand() % 256, rand() % 256); // 为每个分割区域分配随机颜色
}
outImage.at<cv::Vec3b>(i, j) = it->second * 0.5 + image.at<cv::Vec3b>(i, j) * 0.5; // 混合原始图像和颜色掩码
}
}
// 保存输出图像
if (!cv::imwrite("output-auto.png", outImage)) {
std::cout << "Failed to save the output image." << std::endl;
return -1;
}
std::cout << "Automatic segmentation completed and saved as 'output-auto.png'." << std::endl;
return 0;
}
再次编译
并运行exe,就会看到结果:
发布于 2025-01-02 10:45・IP 属地北京