当前位置: 首页 > article >正文

Soft TeacherEnd-to-End Semi-Supervised Object Detection with Soft Teacher

Soft Teacher:End-to-End Semi-Supervised Object Detection with Soft Teacher

论文:End-to-End Semi-Supervised Object Detection with Soft Teacher

Abstract

​ 相较于之前更复杂的多阶段方法,本论文提出了一个端到端的半监督目标检测方法。这个端到端的训练在学习过程中逐渐提高伪标签的质量,而越来越准确的伪标签反过来又有利于目标检测的训练。在这个框架中我们还提出了两个简单且有效的技巧:一种是软教师机制,其中每个未标注边界框的分类损失由教师网络产生的分类得分加权;另一种是框抖动方法,用于选择可靠的伪框,以便进行框回归学习。在COCO基准测试中,所提出的方法在不同标注比例(即1%,5%,10%)下显著优于先前的方法。此外,当标注数据相对较多时,我们的方法也表现良好。例如:利用COCO的123k未标注图像,它可以将使用完整COCO训练集训练的40.9mAP基线检测器提升3.6mAP,达到44.5mAP。在基于最先进的Swin Transformer的目标检测器上(在test-dev上达到58.9mAP),它仍然可以显著提高检测准确性,增加1.5mAP,达到60.4mAP,并将实例分割准确性提高1.2mAP,达到52.4mAP。进一步与Object365预训练模型结合后,检测准确性达到了61.3mAP,实例分割准确性达到53.0mAP,创造了新的最先进水平。代码和模型在https://github.com/microsoft/SoftTeacher上公开发布。

在这里插入图片描述

1.Introduction

​ 有一些方法(如STAC)采用了多阶段训练框架,第一阶段使用标注数据训练初始检测器,随后进行未标注数据的伪标签生成过程,并基于伪标签的未标注数据进行再训练步骤。这些多阶段方法在准确性上取得了相对不错的效果,然而,最终性能受到由初始检测器生成的伪标签质量的限制,而该检测器是使用少量标注数据进行训练的,可能不够准确。

​ 为了解决这个问题,我们提出了一种端到端的伪标签基础半监督目标检测框架,该框架在每次迭代中同时对未标注图像进行伪标签生成,并利用这些伪标签和少量标注数据训练检测器。具体而言,带标注和未标注的图像以预设比例随机抽样,形成一个数据批次。这些图像上应用了两个模型,一个进行检测训练,另一个负责为未标注图像标注伪标签。前者被称为学生模型,后者是学生模型的指数移动平均(EMA),即教师模型。这种端到端的方法避免了复杂的多阶段训练方案。此外,它还实现了一种“飞轮效应”,使得伪标签生成和检测训练过程能够相互加强,从而随着训练的进行,两者的性能不断提升。

​ 这种端到端框架的另一个重要优势是,它允许更充分地利用教师模型来指导学生模型的训练,而不仅仅是像以前的方法 [27(STAC), 36] 那样提供“带有硬类别标签的一些生成伪框”。为实现这一观点,提出了一种软教师方法。在这种方法中,教师模型直接评估学生模型生成的所有框候选,而不是提供“伪框”来为这些学生生成的框候选分配类别标签和回归向量。对这些框候选的直接评估使得在学生模型训练中可以使用更多的监督信息。具体而言,我们首先根据检测分数将框候选分为前景/背景,使用高前景阈值以确保正伪标签的高精度,类似于[27(STAC)]中的做法。然而,这种高前景阈值导致许多正框候选者被错误地指定为背景。为了解决这个问题,我们提出使用可靠性度量来加权每个“背景”框候选的损失。我们通过实验证明,教师模型生成的简单检测分数可以很好地作为可靠性度量,并在我们的方法中使用。我们发现这种方法的性能显著优于以前的硬前景/背景分配方法(见表3和表4),因此我们称之为“软教师”。

​ 另一种实现这一思路的方法是通过框抖动技术选择可靠的边界框,用于学生模型的定位分支训练。该方法首先对伪前景框候选进行多次抖动。然后,这些抖动后的框根据教师模型的位置分支进行回归,回归框的方差被用作可靠性度量。具有足够高可靠性的框候选将被用于学生模型的定位分支训练。

2.Related works

图像分类中的半监督学习

​ 图像分类中的半监督学习可以大致分为两组:基于一致性和基于伪标签。基于一致性的方法利用未标记的图像来构造正则化损失,该正则化损失鼓励对同一图像的不同扰动来产生类似的预测。实现扰动的方法有几种,包括扰动模型 [1]、增强图像 [23] 或对抗训练 [19]。在 [11] 中,训练目标通过预测不同的训练步骤进行组合。在 [29] 中,他们通过对模型自身进行集成,而不是模型预测,来发展 [11],即学生模型的指数移动平均(EMA)。伪标签方法 [33, 7, 12](也称为自我训练)通过最初训练的分类模型为未标注图像标注伪标签,然后通过这些伪标注图像来细化检测器。与我们专注于目标检测的方法不同,伪标签在分类图像时不必解决前景/背景标签和框回归的问题。最近,一些研究 [32, 3, 2, 26] 探讨了数据增强在半监督学习中的重要性,这启发我们使用弱增强来生成伪标签,而使用强增强来学习检测模型。

目标检测中的半监督学习

​ 类似于图像分类中的半监督学习,半监督目标检测方法也分为两类:一致性方法 [10, 28] 和伪标签方法 [20, 36, 13, 27, 31]。我们的方法属于伪标签类别。在 [20, 36] 中,不同数据增强的预测结果被集成,以形成未标注图像的伪标签。在 [13] 中,训练了一个选择网络(SelectiveNet)来选择伪标签。在 [31] 中,在未标注图像上检测到的框被粘贴到标注图像上,并对粘贴后的标签图像进行定位一致性估计。由于图像本身被修改,因此在 [31] 中需要非常彻底的检测过程。而在我们的方法中,只处理轻量级的检测头。STAC [27] 提出了使用弱数据增强进行模型训练,同时对伪标签执行强数据增强。然而,与其他伪标签方法 [20, 36, 13, 27, 31] 一样,它也遵循多阶段训练方案。与之相比,我们的方法是一个端到端的伪标签生成框架,避免了复杂的训练过程,同时也实现了更好的性能

目标检测

​ 目标检测专注于设计高效且准确的检测框架。主要有两种主流方法:单阶段目标检测器 [17, 21, 30] 和双阶段目标检测器 [6, 22, 14, 34, 35]。这两类方法之间的主要区别在于是否使用级联来过滤大量的目标候选(提议)。理论上,我们的方法与这两种类型的方法都是兼容的。然而,为了与先前关于半监督目标检测的工作 [28, 27] 进行公平的比较,我们使用 Faster R-CNN [22] 作为我们的默认检测框架来说明我们的方法。

3.Methodology

在这里插入图片描述

​ 上图展示了我们端到端训练框架的概述。框架中有两个模型:学生模型和教师模型。学生模型通过标注图像和使用伪框的未标注图像上的检测损失进行学习。未标注图像有两组伪框,分别用于驱动分类分支和回归分支的训练。教师模型是学生模型的指数移动平均(EMA)。在这个端到端框架中,有两个关键设计:软教师和框抖动。

这些多阶段方法在准确性上取得了相当不错的效果,但最终性能受到初始且可能不准确的检测器生成的伪标签质量的限制,而该检测器是使用少量标注数据训练的。

End-to-End Pseudo-Labeling Framework

​ 我们首先介绍基于伪标签的半监督目标检测的端到端框架。我们的方法遵循教师-学生训练方案。在每次训练迭代中,按照数据采样比例 s r s_r sr 随机采样标注图像和未标注图像,以形成训练数据批次。教师模型用于在未标注图像上生成伪框,学生模型则在标注图像(带有真实标签)和未标注图像(将伪框视为真实标签)上进行训练。因此,总体损失定义为监督损失和无监督损失的加权和:

在这里插入图片描述

其中 L s L_s Ls L u L_u Lu 分别表示有标签图像的监督损失和无标签图像的无监督损失, α \alpha α 控制无监督损失的贡献。两者都通过训练数据批次中相应图像的数量进行了归一化:

在这里插入图片描述

​ 其中 I i l I_i^l Iil 表示第 i i i 个有标签图像, I i u I_i^u Iiu 表示第 i i i 个无标签图像, L cls L_{\text{cls}} Lcls 是分类损失, L reg L_{\text{reg}} Lreg 是边界框回归损失, N l N_l Nl N u N_u Nu 分别表示有标签图像和无标签图像的数量。

​ 在训练开始时,教师模型和学生模型均为随机初始化。随着训练的进行,教师模型会不断通过学生模型进行更新,我们遵循常见的做法 [29, 26],即采用指数移动平均(EMA)策略更新教师模型。与在图像分类中将简单概率分布作为伪标签不同,为目标检测创建伪标签更为复杂,因为一幅图像通常包含多个目标,其注释不仅包括位置还有类别。给定一幅未标注图像,教师模型用于检测对象并预测出数千个框候选。随后,进行非最大抑制(NMS)以消除冗余。虽然大多数冗余框被去除,但仍然会剩下一些非前景候选。因此,只有前景得分高于阈值的候选框会被保留作为伪框。

​ 为了生成高质量的伪框并促进学生模型的训练,我们借鉴了 FixMatch [26],这是半监督图像分类任务中的最新进展。对学生模型的检测训练使用强增强,而对教师模型的伪标签生成则使用弱增强。

​ 理论上,我们的框架适用于主流目标检测器,包括单阶段目标检测器 [15, 17, 21, 30] 和双阶段目标检测器 [22, 9, 5, 35, 34]。为了与先前的方法进行公平的比较,我们使用 Faster R-CNN [22] 作为我们的默认检测框架来说明我们的方法。

Soft Teacher

​ 检测器的性能依赖于伪标签的质量。在实际中,我们发现使用较高的前景得分阈值来过滤掉大多数低置信度的学生生成框候选可以取得比使用较低阈值更好的结果。如表9所示,当阈值设置为0.9时,性能最佳。然而,尽管严格的标准(较高的阈值)可以提高前景精度,但保留的框候选的召回率(正确识别出的所有正例占所有实际正例的比例)也会迅速下降。如图3(a)所示,当前景阈值设置为0.9时,召回率较低,仅为33%,而精度则达到了89%。在这种情况下,如果我们使用学生生成框候选与教师生成伪框之间的IoU来分配前景和背景标签(就像在提供真实框注释的情况下,通用目标检测框架所做的那样),一些前景框候选可能会被错误地分配为负样本,这可能会妨碍训练并影响性能(将真正的前景对象标记为背景会向模型提供错误的信息,这会导致模型学习到不正确的模式。模型可能会逐渐忽略这些实际是前景的对象,从而降低其检测性能)。

​ 为了解决这个问题,我们提出了一种软教师方法,它利用来自教师模型的更丰富的信息,这得益于端到端框架的灵活性。具体而言,我们评估每个学生生成框候选被认为是实际背景的可靠性,然后用其来加权背景分类损失。给定两个框集合 { b i f g } \{b_i^{fg}\} {bifg} { b i b g } \{b_i^{bg}\} {bibg},其中 { b i f g } \{b_i^{fg}\} {bifg}表示被分配为前景的框,而 { b i f g } \{b_i^{fg}\} {bifg}表示被分配为背景的框,带有可靠加权的未标注图像的分类损失定义为:

在这里插入图片描述
在这里插入图片描述

​ 其中 G cls G_{\text{cls}} Gcls 表示用于分类的(由教师模型生成的)伪框集合, l cls l_{\text{cls}} lcls 是框分类损失, r j r_j rj 是第 j j j 个背景框候选的可靠性评分, N b fg N_b^{\text{fg}} Nbfg N b bg N_b^{\text{bg}} Nbbg 分别是框集合 { b i fg } \{ b_i^{\text{fg}} \} {bifg} { b i bg } \{ b_i^{\text{bg}} \} {bibg} 中的候选框数量。

​ 估计可靠性评分 ( r ) 是一个挑战。我们通过实验证明,教师模型在弱增强图像上产生的背景分数可以很好地作为 ( r ) 的代理指标,并且在我们的端到端训练框架中很容易获得。具体来说,给定一个由学生模型生成的候选框,其背景分数可以通过使用教师模型(BG-T)通过其检测头处理该框来简单地获得。值得注意的是,这种方法不像广泛使用的硬负样本挖掘方法(例如OHEM [25] 或 Focal Loss [15]),更像是“简单的”负样本挖掘。为了比较,论文还检查了其他几种指标。这里不展开了。

Box Jittering

在这里插入图片描述

如图3(b)所示,候选框的位置精度和前景分数之间并没有显示出强烈的正相关性,这意味着具有高前景分数的框可能不会提供准确的位置信息。这表明根据前景分数选择教师生成的伪框并不适合用于边界框回归,需要一个更好的标准我们引入了一种直观的方法来通过测量回归预测的一致性来估计候选伪框的位置可靠性。具体来说,给定一个由教师模型生成的伪框候选 b i b_i bi,我们在 b i b_i bi 周围采样一个抖动框,并将这个抖动框输入教师模型以获得精炼后的框 b ^ i \hat{b}_i b^i,其公式如下:

在这里插入图片描述

上述过程重复多次以收集一组 N jitter N_{\text{jitter}} Njitter 个精炼后的抖动框 { b ^ i , j } \{ \hat{b}_{i,j} \} {b^i,j},我们定义位置可靠性为边界框回归的方差:

在这里插入图片描述

在这里插入图片描述

其中 σ k \sigma_k σk 是精炼抖动框集合 { b ^ i , j } \{\hat{b}_{i,j}\} {b^i,j} 中第 k k k 个坐标的标准差, σ ^ k \hat{\sigma}_k σ^k 是归一化的 σ k \sigma_k σk h ( b i ) h(b_i) h(bi) w ( b i ) w(b_i) w(bi) 分别表示候选框 b i b_i bi 的高度和宽度。较小的框回归方差表示更高的定位可靠性。然而,在训练过程中,计算所有伪框候选的框回归方差是不可承受的。因此,在实际操作中,我们仅计算前景分数大于 0.5 的框的可靠性。通过这种方式,需要估计的框的数量从平均数百个减少到每张图像约 17 个,从而计算成本几乎可以忽略不计。

在图3©中,我们展示了定位精度与我们的框回归方差之间的相关性。与前景分数相比,框回归方差能够更好地衡量定位精度。这促使我们选择那些框回归方差小于某个阈值的框候选作为伪标签,以训练无标签图像上的框回归分支。给定用于训练无标签数据上框回归的伪框 G r e g G_{reg} Greg,回归损失定义为:
L r e g u = 1 N f g b ∑ i = 1 N f g b l r e g ( b f g i , G r e g ) , (10) L_{reg}^u = \frac{1}{N_{fg}^b} \sum_{i=1}^{N_{fg}^b} l_{reg}(b_{fg}^i, G_{reg}), \tag{10} Lregu=Nfgb1i=1Nfgblreg(bfgi,Greg),(10)

其中 b f g i b_{fg}^i bfgi 是被指定为前景的第 i i i 个框, N f g b N_{fg}^b Nfgb 是前景框的总数, l r e g l_{reg} lreg 是框回归损失。 因此,将公式(4)和公式(10)代入公式(3),无标签图像的损失为:
L u = 1 N u ∑ i = 1 N u ( L c l s u ( I i u , G i c l s ) + L r e g u ( I i u , G i r e g ) ) . (11) L^u = \frac{1}{N^u} \sum_{i=1}^{N^u} (L_{cls}^u(I_i^u, G_i^{cls}) + L_{reg}^u(I_i^u, G_i^{reg})). \tag{11} Lu=Nu1i=1Nu(Lclsu(Iiu,Gicls)+Lregu(Iiu,Gireg)).(11)

这里我们使用伪框 G c l s G^{cls} Gcls G r e g G^{reg} Greg 作为损失函数的输入,强调了在我们的方法中用于分类和框回归的伪框是不同的。

4.Experiments

这里简单过一下,详细数据看原论文

4.1 数据集和评估协议

本研究在MS-COCO基准上验证了所提出的方法。使用了两个训练数据集,train2017包含118,000张标记图像,unlabeled2017包含123,000张未标记图像。此外,还提供了5,000张图像的val2017集用于验证。为了验证性能,研究采用了两种设置:

  • 部分标记数据:首先引入了1%、5%和10%的train2017图像作为标记训练数据,剩余未采样图像作为未标记数据。
  • 完全标记数据:使用全量的标记数据进行训练。
4.2 实验设置

研究采用了基于教师-学生模型的伪标签训练框架。教师模型通过指数移动平均(EMA)策略更新,学生模型则进行标记和未标记图像的训练。每次训练迭代随机采样标记和未标记图像,生成训练数据批次。

4.3 系统比较

本部分与以往的多阶段框架进行了比较。结果表明,转变为端到端框架后,性能提高了1.3点。采用EMA更新策略后,方法的性能进一步提升,达到31.2 mAP。

4.4 消融研究

验证了软教师和框选抖动的效果。实验结果显示,整合软教师后,性能提高了2.4点,应用框选抖动后性能达到34.2 mAP,比E2E+EMA高出3点。不同指标在软教师中的效果也进行了研究,结果表明背景分数的预测效果最佳。

这些实验表明,所提出的方法在多种标记比率下均显著超越了现有的最先进技术。

5.Conclusion

​ 本文提出了一种用于半监督对象检测的端到端训练框架,该框架摒弃了以前方法采用的复杂的多阶段模式。我们的方法通过利用学生模型进行检测训练,同时提高了检测器和伪标签,并利用学生模型通过指数移动平均策略不断更新的教师模型进行在线伪标签生成。在端到端训练中,我们提出了两种简单的技术,分别称为软教师和框抖动,以促进有效利用教师模型。所提出的框架在MS-COCO基准测试种,在部分标注数据和完全标注数据设置下,都远远优于最先进的方法。


http://www.kler.cn/a/376603.html

相关文章:

  • wireshark抓包查看langchain的ChatOpenAI接口发送和接收的数据
  • P3-2.【结构化程序设计】第二节——知识要点:多分支选择语句
  • 详解ARM64可执行程序的生成过程
  • 猫头虎分享Python 编码转换库:处理 JSONL 编码格式转换的最佳实践
  • 如何压缩pdf文件的大小?5分钟压缩pdf的方法推荐
  • kubeadm安装k8s
  • 计算机网络-总线型以太网(ethernet)-知识点小结
  • 基于STM32的智能宠物喂食系统设计
  • Discuz中的关键全局变量`$_G`
  • 快速上手 Windows 命令:简化你的工作流程
  • xlrd.biffh.XLRDError: Excel xlsx file; not supported
  • 你真的了解Canvas吗--解密十三【ZRender篇】
  • 简单了解前缀树/字典树(Trie树)C++代码
  • 三维重建:AI 根据图像信息还原物体三维形状的技术
  • postgresql14源码编译安装
  • 使用AMD GPU和ONNX Runtime高效生成图像与Stable Diffusion模型
  • 【前端】在 Next.js 开发服务器中应该如何配置 HTTPS?
  • 【前端】项目中遇到的问题汇总(长期更新)
  • 【Java】方法的使用 —— 语法要求、方法的重载和签名、方法递归
  • 无源元器件-磁珠选型参数总结