当前位置: 首页 > article >正文

计算3D目标框的NMS

3D障碍物目标框(中心点坐标XYZ、长宽高lwh、朝向角theta)的非极大值抑制

#include <iostream>
#include <vector>
#include <algorithm>
#include <opencv2/opencv.hpp>

// 定义3D目标框的结构体
struct BoundingBox3D
{
    double centerX, centerY, centerZ; // 中心点坐标
    double length, width, height;     // 长宽高
    double theta;                     // 朝向角
    double score;                     // 目标框得分

    BoundingBox3D(double x, double y, double z, double l, double w, double h, double t, double s)
        : centerX(x), centerY(y), centerZ(z), length(l), width(w), height(h), theta(t), score(s) {}
};

class NMS3D
{
public:
    // 构造函数,传入IoU阈值
    NMS3D(double iouThreshold) : iouThreshold_(iouThreshold) {}

    // 执行NMS
    std::vector<BoundingBox3D> executeNMS(const std::vector<BoundingBox3D> &boxes)
    {
        std::vector<BoundingBox3D> resultBoxes;

        // 按得分降序排序
        std::vector<BoundingBox3D> sortedBoxes = sortBoxesByScore(boxes);

        // 遍历排序后的框
        while (!sortedBoxes.empty())
        {
            // 保留得分最高的框
            BoundingBox3D topBox = sortedBoxes[0];
            resultBoxes.push_back(topBox);

            // 移除与当前框IoU大于阈值的框
            sortedBoxes.erase(sortedBoxes.begin());
            sortedBoxes = removeOverlappingBoxes(topBox, sortedBoxes);
        }

        return resultBoxes;
    }

private:
    // 按得分降序排序
    std::vector<BoundingBox3D> sortBoxesByScore(const std::vector<BoundingBox3D> &boxes)
    {
        std::vector<BoundingBox3D> sortedBoxes = boxes;
        std::sort(sortedBoxes.begin(), sortedBoxes.end(),
                  [](const BoundingBox3D &a, const BoundingBox3D &b)
                  {
                      return a.score > b.score;
                  });
        return sortedBoxes;
    }

    // 移除与指定框IoU大于阈值的框
    std::vector<BoundingBox3D> removeOverlappingBoxes(const BoundingBox3D &box,
                                                      const std::vector<BoundingBox3D> &boxes)
    {
        std::vector<BoundingBox3D> filteredBoxes;
        for (const auto &b : boxes)
        {
            if (calculateIoU(box, b) < iouThreshold_)
            {
                filteredBoxes.push_back(b);
            }
        }
        return filteredBoxes;
    }

    // 计算两个框的IoU(Intersection over Union)
    double calculateIoU(const BoundingBox3D &box1, const BoundingBox3D &box2)
    {
        // 计算两个框的相交部分的体积
        double intersectionVolume = calculateIntersectionVolume(box1, box2);

        // 计算两个框的并集部分的体积
        double unionVolume = box1.length * box1.width * box1.height +
                             box2.length * box2.width * box2.height -
                             intersectionVolume;

        // 计算IoU
        return intersectionVolume / unionVolume;
    }

    // 计算两个框的相交部分的体积
    double calculateIntersectionVolume(const BoundingBox3D &box1, const BoundingBox3D &box2)
    {
        // 计算平面重叠面积
        double intersectArea = calIntersectionArea(box1, box2);
        
        double intersectHeight = calculateOverlap(box1.centerZ, box1.height, box2.centerZ, box2.height);

        // 计算相交部分的体积
        return intersectArea * intersectHeight;
    }

    cv::Point rotatePoint(const cv::Point &point, double angle)
    {
        double rotatedX = point.x * cos(angle) - point.y * sin(angle);
        double rotatedY = point.x * sin(angle) + point.y * cos(angle);

        return cv::Point(rotatedX, rotatedY);
    }

    double calIntersectionArea(const BoundingBox3D &box1, const BoundingBox3D &box2)
    {
        cv::RotatedRect rect1(cv::Point2f(box1.centerX,box1.centerY),cv::Size2f(box1.width,box1.height),box1.theta);
        cv::RotatedRect rect2(cv::Point2f(box2.centerX,box2.centerY),cv::Size2f(box2.width,box2.height),box2.theta);

        std::vector<cv::Point2f> intersection;
        cv::rotatedRectangleIntersection(rect1,rect2, intersection);
        // std::cout <<rect1.center<< " "<<rect2.center<<std::endl;
        // std::cout <<rect1.size<< " "<<rect2.size<<std::endl;
        // std::cout << "intersection area:"<<intersection.size()<<std::endl;

        double union_area = cv::contourArea(intersection);
        // std::cout << "intersection area:"<<union_area<<std::endl;
        return union_area;
    }

    // 计算两个轴上的重叠部分长度
    double calculateOverlap(double center1, double size1, double center2, double size2)
    {
        double halfSize1 = size1 / 2;
        double halfSize2 = size2 / 2;
        double min1 = center1 - halfSize1;
        double max1 = center1 + halfSize1;
        double min2 = center2 - halfSize2;
        double max2 = center2 + halfSize2;

        // 计算重叠部分长度
        return std::max(0.0, std::min(max1, max2) - std::max(min1, min2));
    }

    double iouThreshold_; // IoU阈值
};

int main()
{
    std::vector<BoundingBox3D> inputBoxes;
    inputBoxes.push_back(BoundingBox3D(0.0, 0.0, 0.0, 200.0,200.0, 200.0, 45, 0.9));
    inputBoxes.push_back(BoundingBox3D(100,100, 10, 200.0, 200.0, 200.0, -45, 0.8));
    //inputBoxes.push_back(BoundingBox3D(2.0, 2.0, 2.0, 2.0, 1.0, 1.0, 0, 0.7));

    double iouThreshold = 0.5; // 可根据实际情况调整IoU阈值
    NMS3D nms(iouThreshold);
    std::vector<BoundingBox3D> resultBoxes = nms.executeNMS(inputBoxes);

    // 输出结果框
    for (const auto &box : resultBoxes)
    {
        std::cout << "Center: (" << box.centerX << ", " << box.centerY << ", " << box.centerZ << "), "
                  << "Dimensions: (" << box.length << ", " << box.width << ", " << box.height << "), "
                  << "Theta: " << box.theta << ", "
                  << "Score: " << box.score << std::endl;
    }
    return 0;
}

关于cv::contourArea可能计算不准的问题,是由于传入的点没有按照一定的顺序排列(顺时针或逆时针)。参考解决博客


http://www.kler.cn/a/134998.html

相关文章:

  • 青少年编程与数学 02-006 前端开发框架VUE 18课题、逻辑复用
  • 在php中,Fiber、Swoole、Swow这3个协程都是如何并行运行的?
  • Mysql常见知识点
  • 【Word_笔记】Word的修订模式内容改为颜色标记
  • Nginx 413 Request Entity Too Large
  • 使用centos搭建内网的yum源
  • 【最新Tomcat】IntelliJ IDEA通用配置Tomcat教程(超详细)
  • JDY蓝牙注意事项
  • 高精度算法【Java】(待更新中~)
  • 【以图会意】操作系统的加载流程
  • 【Linux】C文件系统详解(四)——磁盘的物理和抽象结构
  • Bert浅谈
  • vue使用本地图片设置为默认图
  • 【Django使用】django经验md文档10大模块。第4期:Django数据库增删改查
  • 整形数据和浮点型数据在内存中的存储差别
  • 微服务调用链路追踪
  • 使用Python实现3D曲线拟合
  • Java爬虫的使用案例及简单总结
  • 场景交互与场景漫游-路径漫游(7)
  • 腾讯云轻量数据库性能如何?轻量数据库租用配置价格表
  • JavaScript实现飞机发射子弹详解(内含源码)
  • 超聚变服务器关闭超线程CPU的步骤(完整版)
  • 【开源】基于Vue.js的在线课程教学系统的设计和实现
  • 右键菜单和弹出菜单的区别