当前位置: 首页 > article >正文

目标检测中的非极大值抑制(NMS)原理与实现解析

一、技术背景

在目标检测任务中,模型通常会对同一目标生成多个重叠的候选框(如锚框或预测框)。非极大值抑制(Non-Maximum Suppression, NMS) 是一种关键的后处理技术,用于去除冗余的检测结果,保留置信度最高且位置最优的边界框。本文将通过一段Python代码解析NMS的核心实现逻辑,并演示其在OpenCV环境中的实际效果。


二、算法核心思想

NMS的核心是通过以下步骤筛选边界框:

  1. 按置信度排序:优先处理置信度最高的预测框。
  2. 计算交并比(IoU):与当前框重叠度高的候选框将被抑制。
  3. 迭代筛选:重复上述过程直至处理完所有候选框。
    在这里插入图片描述

三、代码实现解析

1. 输入数据结构

输入为字典类型 predicts_dict,键为类别名称,值为该类别对应的边界框列表。每个边界框格式为 [x1, y1, x2, y2, score],表示左上角和右下角坐标及置信度。

predicts_dict = {'black1': [[83,54,165,163,0.8], [67,48,118,132,0.5], ...]}

2. 核心函数 non_max_suppress

def non_max_suppress(predicts_dict, threshold):
    for object_name, bbox in predicts_dict.items():
        bbox_array = np.array(bbox, dtype=float)
        # 提取坐标和置信度
        x1, y1, x2, y2, score = bbox_array[:,0], bbox_array[:,1], bbox_array[:,2], bbox_array[:,3], bbox_array[:,4]
        # 按置信度降序排序
        order = score.argsort()[::-1]
        area = (x2 - x1 + 1) * (y2 - y1 + 1)
        keep = []  # 保留的索引列表
        while order.size > 0:
            i = order[0]  # 当前最高分框
            keep.append(i)
            # 计算IoU
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])
            inter = np.maximum(0.0, xx2 - xx1 + 1) * np.maximum(0.0, yy2 - yy1 + 1)
            iou = inter / (area[i] + area[order[1:]] - inter)
            # 保留IoU低于阈值的框
            inds = np.where(iou <= threshold)[0]
            order = order[inds + 1]
        # 更新筛选后的结果
        predicts_dict[object_name] = bbox_array[keep].tolist()
    return predicts_dict

关键步骤说明:

  • 坐标提取与排序:将边界框转换为NumPy数组后,按置信度降序排列。
  • IoU计算:通过最大-最小值法计算交集区域,公式为:
    IoU = Intersection Union − Intersection \text{IoU} = \frac{\text{Intersection}}{\text{Union} - \text{Intersection}} IoU=UnionIntersectionIntersection
  • 动态索引更新:通过 order = order[inds + 1] 跳过被抑制的框,逐步缩小处理范围。
3. 可视化测试代码
  • 绘制原始预测框:在全黑图像上绘制未经过NMS处理的边界框及置信度。
  • NMS处理与对比:调用 non_max_suppress 后,在另一窗口展示抑制后的结果。
# 绘制原始框
for box in bbox:
    cv2.rectangle(img, (x1, y1), (x2, y2), (255,255,255), 2)
# 处理并绘制NMS后的框
predicts_dict_nms = non_max_suppress(predicts_dict, 0.1)
for box in bbox_nms:
    cv2.rectangle(img_cp, (x1, y1), (x2, y2), (255,255,255), 2)

四、优化与注意事项

  1. 阈值选择:阈值过小可能导致漏检,过大则冗余框增多(通常目标检测任务中阈值设为0.5)。
  2. 多类别处理:代码支持同时对多个类别独立进行NMS,如输入 black1black2 两个类别的预测结果。
  3. 坐标修正:代码中 +1 的操作是为了避免零宽度/高度,确保面积计算正确。
import cv2
import random
import numpy as np

def non_max_suppress(predicts_dict, threshold):
	for object_name, bbox in predicts_dict.items():  # 对每一个类别分别进行NMS;一次读取一对键值(即某个类别的所有框)
		bbox_array = np.array(bbox, dtype=np.float)
		print(bbox_array)
		# 下面分别获取框的左上角坐标(x1,y1),右下角坐标(x2,y2)及此框的置信度;这里需要注意的是图像左上角可以看做坐标点(0,0),右下角可以看做坐标点(1,1),也就是说从左往右x值增大,从上往下y值增大
		x1 = bbox_array[:, 0]
		y1 = bbox_array[:, 1]
		x2 = bbox_array[:, 2]
		y2 = bbox_array[:, 3]
		scores = bbox_array[:, 4]  # class confidence, ndarray
		print(scores, type(scores))        
		order = scores.argsort()[::-1]  # argsort函数返回的是数组值从小到大的索引值,[::-1]表示取反。即这里返回的是数组值从大到小的索引值
		areas = (x2 - x1 + 1) * (y2 - y1 + 1)  # 当前类所有框的面积(python会自动使用广播机制,相当于MATLAB中的.*即两矩阵对应元素相乘);x1=3,x2=5,习惯上计算x方向长度就是x=3、4、5这三个像素,即5-3+1=3,而不是5-3=2,所以需要加1
		print(areas, type(areas))    
		keep = []
		
		# 按confidence从高到低遍历bbx,移除所有与该矩形框的IoU值大于threshold的矩形框
		while order.size > 0:
			i = order[0]
			keep.append(i)  # 保留当前最大confidence对应的bbx索引
			# 获取所有与当前bbx的交集对应的左上角和右下角坐标,并计算IoU(注意这里是同时计算一个bbx与其他所有bbx的IoU)
			xx1 = np.maximum(x1[i], x1[order[1:]])  # 最大置信度的左上角坐标分别与剩余所有的框的左上角坐标进行比较,分别保存较大值;因此这里的xx1的维数应该是当前类的框的个数减1
			print("xx1:", xx1)
			yy1 = np.maximum(y1[i], y1[order[1:]])
			xx2 = np.minimum(x2[i], x2[order[1:]])
			yy2 = np.minimum(y2[i], y2[order[1:]])
			inter = np.maximum(0.0, xx2-xx1+1) * np.maximum(0.0, yy2-yy1+1)
			iou = inter / (areas[i] + areas[order[1:]] - inter)  # 注意这里都是采用广播机制,同时计算了置信度最高的框与其余框的IoU
			print(iou, type(iou))
			print(np.where(iou <= threshold))
			inds = np.where(iou <= threshold)[0]  # 保留iou小于等于阙值的框的索引值
			print('inds:', inds)
			order = order[inds + 1]  # 将order中的第inds+1处的值重新赋值给order;即更新保留下来的索引,加1是因为因为没有计算与自身的IOU,所以索引相差1,需要加上
		bbox = bbox_array[keep]
		predicts_dict[object_name] = bbox.tolist()
	return predicts_dict

# 下面在一张全黑图片上测试非极大值抑制的效果
img = np.zeros((600,600), np.uint8)
predicts_dict = {'black1': [[83, 54, 165, 163, 0.8], [67, 48, 118, 132, 0.5], [91, 38, 192, 171, 0.6]]}
# predicts_dict = {'black1': [[83, 54, 165, 163, 0.8], [67, 48, 118, 132, 0.5], [91, 38, 192, 171, 0.6]], 'black2': [[59, 120, 137, 368, 0.12], [54, 154, 148, 382, 0.13]] }
"""
# 在全黑的图像上画出设定的几个框
for object_name, bbox in predicts_dict.items():
	for box in bbox:
		x1, y1, x2, y2, score = box[0], box[1], box[2], box[3], box[-1]
		y_text = int(random.uniform(y1, y2))  # uniform()是不能直接访问的,需要导入 random 模块,然后通过 random 静态对象调用该方法。uniform() 方法将随机生成下一个实数,它在 [x, y) 范围内
		cv2.rectangle(img, (x1, y1), (x2, y2), (255, 255, 255), 2)
		cv2.putText(img, str(score), (x2 - 30, y_text), 2, 1, (255, 255, 0))
	cv2.namedWindow("black1_roi")  # 创建一个显示图像的窗口
	cv2.imshow("black1_roi", img)  # 在窗口中显示图像;注意这里的窗口名字如果不是刚刚创建的窗口的名字则会自动创建一个新的窗口并将图像显示在这个窗口
	cv2.waitKey(0)  # 如果不添这一句,在IDLE中执行窗口直接无响应。在命令行中执行的话,则是一闪而过。
cv2.destroyAllWindows()  # 最后释放窗口是个好习惯!
"""
# 在全黑图片上画出经过非极大值抑制后的框
img_cp = np.zeros((600,600), np.uint8)
predicts_dict_nms = non_max_suppress(predicts_dict, 0.1)
for object_name, bbox in predicts_dict_nms.items():
	for box in bbox:
		x1, y1, x2, y2, score = int(box[0]), int(box[1]), int(box[2]), int(box[3]), box[-1]
		y_text = int(random.uniform(y1, y2))  # uniform()是不能直接访问的,需要导入 random 模块,然后通过 random 静态对象调用该方法。uniform() 方法将随机生成下一个实数,它在 [x, y) 范围内
		cv2.rectangle(img_cp, (x1, y1), (x2, y2), (255, 255, 255), 2)
		cv2.putText(img_cp, str(score), (x2 - 30, y_text), 2, 1, (255, 255, 0))
	cv2.namedWindow("black1_nms")  # 创建一个显示图像的窗口
	cv2.imshow("black1_nms", img_cp)  # 在窗口中显示图像;注意这里的窗口名字如果不是刚刚创建的窗口的名字则会自动创建一个新的窗口并将图像显示在这个窗口
	cv2.waitKey(0)  # 如果不添这一句,在IDLE中执行窗口直接无响应。在命令行中执行的话,则是一闪而过。
cv2.destroyAllWindows()  # 最后释放窗口是个好习惯!

http://www.kler.cn/a/594212.html

相关文章:

  • Java EE 进阶:MyBatis
  • Ubuntu-server-16.04 设置多个ip和多个ipv6
  • 【动态规划】矩阵连乘问题 C++(附代码实例和复杂度分析)
  • Java集合基础知识
  • C++模版(进阶)
  • 0321美团实习面试——技能大致内容
  • Java使用FFmpegFrameGrabber进行视频拆帧,结合Thumbnails压缩图片保存到文件夹
  • C# ManualResetEvent‌的高级用法
  • python字符级差异分析并生成 Word 报告 自然语言处理断句
  • Qt6+QML实现Windows屏幕录制
  • 【软考-架构】8.4、信息化战略规划-CRO-SCM-应用集成-电子商务
  • 【STM32】I²CC通信外设硬件I²CC读写MPU6050(学习笔记)
  • 【go】Go语言设计模式:函数与方法的权衡
  • Oracle 19c更换临时表空间操作步骤
  • STM32学习-Day4-寄存器开发流程
  • 电源电路篇
  • 深入理解 Vue 3 项目结构与运行机制
  • MySQL数据库入门到大蛇尚硅谷宋红康老师笔记 高级篇 part10
  • redis的key是如何找到对应存储的数据原理
  • 微软产品的专有名词和官方视频教程