基于transformer的目标检测:DETR
目录
一、背景介绍
二、DETR的工作流程
三、DETR的架构
1. 损失函数
2. 网络框架讲解及举例
一、背景介绍
在深度学习和计算机视觉领域,目标检测一直是一个核心问题。传统方法依赖于复杂的流程和手工设计的组件,如非极大值抑制(nms)和锚点(anchor)生成,这些都需要对任务有深入的先验知识。然而,DETR的出现,为我们提供了一种全新的视角,将目标检测视为一个直接的集合预测问题,实现了端到端的检测。
DETR,即DEtection TRansformer,是一种创新的框架,它通过简化检测流程,消除了对许多手工设计组件的需求。这一方法的核心在于其集合基础的全局损失函数,它通过二分图匹配(bi-partite matching)强制实现唯一的预测结果,以及其Transformer编码器-解码器架构。DETR利用一组固定且学习得到的对象查询(object queries),推理对象之间的关系和全局图像上下文,直接并行输出最终的预测集合。
额外提一下,最初的DETR训练很慢,原作者训练了500个epoch。不过这些问题在后续都被慢慢解决了。
二、DETR的工作流程
根据原文的流程图,训练部分可以分成4部分:
-
特征提取(Feature Extraction):
在这个阶段,输入的图像通过一个卷积神经网络(CNN)来提取图像特征。这些特征是图像的高级表示,捕捉了图像中的重要信息,为后续的编码阶段提供基础。 -
编码(Encoding):
编码阶段使用Transformer的编码器部分。编码器接收CNN提取的图像特征,并将其转换为一系列特征表示。这些特征表示能够捕捉图像的全局信息,为解码阶段提供丰富的上下文。 -
解码(Decoding):
解码阶段由Transformer的解码器部分负责。解码器利用编码器提供的特征表示,并通过自注意力机制来预测目标的类别和位置。解码器的输出是一组边界框预测,每个预测对应一个可能的目标。 -
损失计算(Loss Calculation):
在损失计算阶段,DETR使用二分图匹配损失来优化模型。这种损失函数确保预测的边界框与真实边界框之间建立一对一的匹配关系。如果一个预测没有与任何真实边界框匹配,那么它应该被分类为“无对象”类别。
在推理过程中,解码器同样生成一组边界框预测。但是,与训练不同,推理过程中不会使用二分图匹配损失。相反,会应用一个置信度阈值(例如0.7),只保留那些置信度高于这个阈值的预测。这样可以过滤掉那些不太确定的预测,只保留模型认为更有可能正确的预测。
三、DETR的架构
1. 损失函数
在DETR中,不管输入什么图片,都会在decoder输出100个框,但是一张图片中正常只会有几个target。一个预测框对应一个真实目标,模型怎么知道哪个预测框对应真实的目标框呢?于是该问题转换为了一个二分图匹配问题。
举个例子,假设我们有一张图片,其中包含3个真实目标,而DETR的解码器输出了100个预测框。二分图匹配问题就是要在这100个预测框和3个真实目标之间找到最佳的匹配关系,使得匹配的总成本(预测框与真实目标之间的损失函数)最小化。
如图所示,预测框和真值的对应问题可以列成一个矩阵,矩阵中放的是损失值,然后利用现成的匈牙利算法解决这个二分图匹配问题。
这里的损失函数分为类别损失和定位损失。
每个预测框都会有一个类别损失,但只有那些通过二分图匹配与真实目标框匹配的预测框才会有位置损失。
2. 网络框架讲解及举例
根据DETR流程图,举个例子描述输入图像经过DETR模型处理的精炼详细步骤:
-
输入图像: 输入一张尺寸为 3×800×1066 的图像。
-
CNN骨干网络: 使用传统的CNN骨干网络提取图像特征,得到一个 2048×25×34 的特征图,25*34是输入图像大小的1/32。
-
特征降维: 通过1*1卷积将特征图降维到 256×25×34。
-
位置编码: 为降维后的特征图添加位置编码,位置编码的尺寸同样为 256×25×34。
-
特征图拉平: 将添加了位置编码的特征图拉平成一维向量,尺寸为 850×256,800是像素点个数。
-
Transformer编码器: 将拉平后的特征向量输入到Transformer编码器中进行编码。
-
Transformer解码器输入: 解码器的输入是一个固定大小的向量,尺寸为 100×256,这代表了模型将输出100个预测框。通过QKV的计算公式可知,decoder的输出与输入相同。
-
交叉注意力: 解码器中的每个输出嵌入与编码器的输出进行交叉注意力操作,以结合编码器的全局信息。
-
解码器输出: 交叉注意力操作后,解码器输出尺寸为 100×256 的向量。
-
预测头: 将解码器的输出通过一个共享的前馈网络(FFN),每个FFN预测一个类别(包括“无对象”类别)和边界框。
-
最终预测: FFN输出最终的100个预测框,每个预测包括类别和边界框信息。
这个流程展示了DETR如何通过Transformer架构实现端到端的目标检测,从输入图像到最终的预测框,整个过程不需要传统的目标检测组件,如锚点生成或非极大值抑制。