当前位置: 首页 > article >正文

DETR论文阅读

1. 动机

传统的目标检测任务需要大量的人工先验知识,例如预定义的先验anchor,NMS后处理策略等。这些人工先验知识引入了很多人为因素,且较难处理。如果能够端到端到直接生成目标检测结果,将会使问题变得很优雅。

2. 主要贡献

提出了一个新的目标函数,用二分图匹配的方式强制模型输出一个独一无二的目标框,避免了传统方法中的非极大值抑制。

首次把transformer引入到目标检测领域。

简化了检测流程,有效地消除了对许多人工设计组件的需求,如NMS或anchor生成。实现了端到端的目标检测。

3. 模型结构

DETR将目标检测看作一种set prediction问题,并提出了一个十分简洁的目标检测pipeline,即CNN提取基础特征,送入Transformer做关系建模,得到的输出通过二分图匹配算法与图片上的ground truth做匹配。

先使用CNN对图像进行特征提取,把得到的二维特征转换到一维,然后送入transformer的encoder-decoder结构之中。然后利用decoder的结果预测检测框的输出。

将目标检测看作set prediction问题
DETR算法流程

3.1. backbone

DETR基础版本的backbone使用torchvision上预训练过的ResNet-50,训练时冻结BN层参数。设输入img维度为(3, H, W),经过backbone后变为(2048, \frac{H}{32}, \frac{W}{32})。此外在后续实验阶段论文还使用了ResNet-101以及改进过的DC5版本。

3.2. transfomer

CNN提取的特征拉直(flatten)后加入位置编码(positional encoding)得到序列特征,作为Transformer encoder的输入。Transformer中的attention机制具有全局感受野,能够实现全局上下文的关系建模,其中encoder和decoder均由多个encoder、decoder层堆叠而成。每个encoder层中包含self-attention机制,每个decoder中包含self-attention和cross-attention。

3.3. object queries

输出100个目标框和它的分类。设计了一套全新的损失函数,能够在训练的时候把与ground truth匹配的框算作为正样本,不匹配的框作为负样本。在推理的时候使用一个阈值来选择结果,预测得分高于阈值的作为输出,低于阈值的忽略。

transformer解码器中的序列是object queries。每个query对应图像中的一个物体实例(包含背景实例 ϕ),它通过cross-attention从编码器输出的序列中对特定物体实例的特征做聚合,又通过self-attention建模该物体实例域其他物体实例之间的关系。最终,FFN基于特征聚合后的object queries做分类的检测框的回归。

值得一提的是,object queries是可学习的embedding,与当前输入图像的内容无关(不由当前图像内容计算得到)。论文中对不同object query在COCO数据集上输出检测框的位置做了统计(如上图所示),可以看不同object query是具有一定位置倾向性的。对object queries的理解可以有多个角度。首先,它随机初始化,并随着网络的训练而更新,因此隐式建模了整个训练集上的统计信息。其次,在目标检测中每个object query可以看作是一种可学习的动态anchor,可以发现,不同于Faster RCNN, RetinaNet等方法在特征的每个像素上构建稠密的anchor不同,detr只用少量稀疏的anchor(object queries)做预测,这也启发了后续的一系列工作。

3.4. 损失函数

DETR有两种损失:(1)二分图匹配阶段的损失,用于确定最优匹配。(2)在最优匹配下的模型损失。

3.4.1. 二分图匹配

我们知道DETR每次输出包含N=100个预测目标的集合,由于GT集合元素个数小于N,我们用\phi将GT集合元素个数填充至N个。那么预测集合与GT集合总的二分图匹配个数就有A_N^N个,我们所有匹配的集合设为\Sigma_N。我们要做的就是找到这个最优的匹配,公式如下图所示。

\hat{\sigma}=argmin\sum_i^NL_{match}\left(y_i,\hat{y}_{\sigma(i)}\right)

\hat{\sigma}即为最优匹配,y_i\hat{y}_{\sigma(i)}分别代表GT值和预测值。

以往的一些研究包括本论文都是使用的匈牙利算法 Hungarian algorithm 来计算最优匹配的。

3.4.2. 匹配后损失计算

有了最优的匹配\hat{\sigma}后,便要计算模型的损失,公式如下。

L_{Hungarian}\left(y,\hat{y}\right)=\sum_{i=1}^N\left[-log\:\hat{p}_{\hat{\sigma}(i)}\left(c_i\right)+1_{\left\{c_i\neq\phi\right\}}L_{box}\left(b_i,\hat{b}_{\hat{\sigma}(i)}\right)\right]$$ $$L_{box}\left(b_i,\hat{b}_{\hat{\sigma}(i)}\right)=\lambda_{iou}L_{iou}\left(b_i,\hat{b}_{\hat{\sigma}(i)}\right)+\lambda_{L1}\left|\left|b_i-\hat{b}_{\hat{\sigma}(i)}\right|\right|_1

其中y_i=(c_i,b_i),分别代表GT类别和bbox参数{x,y,w,h};在最优匹配\hat{\sigma}下,预测的类别分数和bbox参数分别为\hat{p}_{\hat{\sigma}(i)}\left(c_i\right)\hat{b}_{\hat{\sigma}}\left(i\right)

\lambda_{iou}\lambda_{L1}为超参数用于调节权重。

参考文献

End-to-end object detection with transformers

DETR目标检测新范式带来的思考 - 知乎

DETR(DEtection TRansformer)要点总结-CSDN博客

DETR 论文精读【论文精读】_哔哩哔哩_bilibili


http://www.kler.cn/a/507215.html

相关文章:

  • LabVIEW 程序中的 R6025 错误
  • 「刘一哥GIS」系列专栏《GRASS GIS零基础入门实验教程(配套案例数据)》专栏上线了
  • 软件设计大致步骤
  • HarmonyOS Next 实现登录注册页面(ARKTS) 并使用Springboot作为后端提供接口
  • Windows 正确配置android adb调试的方法
  • 【学习笔记】理解深度学习的基础:机器学习
  • openCV项目实战——信用卡数字识别
  • Vue 开发者的 React 实战指南:测试篇
  • CMake构建C#工程(protobuf)
  • Web 实时消息推送的七种实现方案
  • SpringBoot链接Kafka
  • 在 .NET 9 中使用 Scalar 替代 Swagger
  • 基于 Python 的财经数据接口库:AKShare
  • NFTScan | 01.06~01.12 NFT 市场热点汇总
  • 图论基础,如何快速上手图论?
  • Redis哨兵模式搭建示例(配置开机自启)
  • 代码随想录25 回溯算法
  • 78_Redis网络模型
  • K8S--边车容器
  • 如何Python机器学习、深度学习技术提升气象、海洋、水文?
  • 2025第3周 | json-server的基本使用
  • Linux下使用MySql数据库
  • 采用海豚调度器+Doris开发数仓保姆级教程(满满是踩坑干货细节,持续更新)
  • 浏览器中的Markdown编辑器
  • 【2024年华为OD机试】(B卷,100分)- 相对开音节 (Java JS PythonC/C++)
  • java常用开发工具类