当前位置: 首页 > article >正文

【论文精读】LPT: Long-tailed prompt tuning for image classification

🌈 个人主页:十二月的猫-CSDN博客
🔥 系列专栏: 🏀论文精读_十二月的猫的博客-CSDN博客

💪🏻 十二月的寒冬阻挡不了春天的脚步,十二点的黑夜遮蔽不住黎明的曙光 

目录

1. 摘要

2. 介绍

3. 相关工作

3. 初步研究

3.1 VPT 性能调查

3.2 及时调整分析

4. 长尾提示调谐

4.1 阶段1:共享提示调优

4.2 阶段2:组提示调优

4.3 损失函数

调整作用


1. 摘要

       对于长尾分类任务,大多数研究通常在大规模(无标注)数据集上预训练一个大型模型,然后对整个预训练模型进行微调,以适应长尾数据。对整个预训练模型进行微调虽然很有前途,但往往存在计算成本高、为不同任务部署不同模型的问题,以及因过度拟合长尾数据的某些特征而削弱泛化能力的问题。为了缓解这些问题,我们针对长尾分类任务提出了一种有效的长尾提示调整(LPT)方法。LPT 将几个可训练的提示引入到冻结的预训练模型中,使其适应长尾数据。为了达到更好的效果,我们将提示分为两组:1)针对整个长尾数据集的共享提示,用于学习一般特征,并使预训练模型适应目标长尾领域;2)针对特定组的提示,用于收集具有相似特征的样本的特定组特征,同时赋予预训练模型细粒度的判别能力。然后,我们设计了一种两阶段训练范式来学习这些提示。在第一阶段,我们通过传统的监督提示调整来训练共享提示,使预先训练好的模型适应所需的长尾领域。在第二阶段,我们使用学习到的共享提示作为查询,从特定组提示集中为一组相似样本选择一个小的最佳匹配集,以挖掘这些相似样本的共同特征,然后使用双重采样策略和非对称高斯云化 Logit 损失优化这些提示。LPT 在固定预训练模型的基础上只对少数提示进行微调,通过存储少数提示来降低训练成本和部署成本,并享有预训练模型的强大泛化能力。实验表明,在各种长尾基准上,LPT 只需额外增加 ∼1.1% 的可训练参数,就能获得与之前的整体模型微调方法相当甚至更高的性能,而且对领域偏移具有更强的鲁棒性。

论文链接:

        1、Rethinking the Value of Labels for Improving Class-Imbalanced Learning(2020 

NeurIPS)

背景知识:   

        1、从背景知识的论文中已知无标注数据集上的预训练同样对模型处理长尾问题有很多好处。

        2、目前长尾问题都是基于预训练的大模型,再用具体的长尾数据在下游任务中去微调。但是微调的代价往往比较大。

核心思想:

        1、利用Prompt代替微调。将Prompt分为两组,一组是所有长尾数据集共享的提示,另一组是具有相似特征的样本的特定提示。

        2、两阶段训练来分别学习两组提示。第一阶段是学习通用的prompt,第二阶段是学习组内的prompt。

2. 介绍

       在深度学习时代,从长尾数据中学习(Cui et al, 2019; Kang et al, 2020; Zhang et al, 2021b)是一项非常具有挑战性的工作,因为由于多数类的训练样本数量过多,网络往往会过度过度拟合多数类,而忽略少数类。为了消除这种负面影响,以往的方法主要集中在三个方面: 1) 对长尾数据分布进行重采样(Kang 等人,2020;Li 等人,2022;2021a;Ren 等人,2020),以实现每个小批数据中所有类别之间的平衡;2) 对训练损失重新加权(Cui 等人,2019;Li 等人,2022; 3) 专门设计的解耦训练(Kang 等人,2020 年)、知识提炼(Li 等人,2021b)或集合学习(Zhou 等人,2020 年;Wang 等人,2020 年)。

       这些方法虽然在一定意义上缓解了长尾学习的负面影响,取得了更好的整体性能,但通常需要从头开始训练特征提取器和线性分类器,或者在大规模数据集(如 ImageNet,Deng 等人,2009 年)上预先训练模型,因此存在三个问题。首先,为了适应长尾数据,整个模型的微调需要更高的额外训练成本。其次,对整个模型进行微调也会损害预训练模型的泛化能力,因为在大规模数据集上训练的预训练模型往往数据丰富,对各种特征具有很强的判别能力。而微调往往会削弱这种泛化能力,因为微调会过度拟合长尾数据的某些特征,而且很难处理长尾学习中经常出现的域偏移或失分布数据。最后,微调也会导致不同的学习任务产生截然不同的模型,这就破坏了模型的兼容性,增加了实际部署成本。

问题:

        1、所有的解决长尾问题的方法都需要a、b中的一个:a. 头从训练特征提取器和分类器;b. 选用预训练好的特征提取器和分类器进行微调

        2、考虑到预训练的大规模无标注样本对长尾学习有很多好处,目前大多采用b方案。同时结合类重平衡、损失函数设计、解耦训练等方法去微调。

        3、但是微调存在三方面的问题:a. 大模型微调成本很高;b. 微调会削弱原模型的泛化能力,因为尾部类过少容易过拟合;c. 微调将使得不同下游任务出现不同模型,彼此不兼容,增加了实际部署的成本

       贡献:为了缓解上述问题,我们提出了一种新颖有效的长尾提示调整(LPT)方法。具体来说,LPT 建立在预训练模型的基础上,例如视觉转换器(ViT)(Dosovitskiy 等人,2021 年),并在该预训练模型中引入额外的可训练提示,最后只对这些提示进行微调,以使预训练模型适应手头的长尾数据。提示有两种:1)针对所有类别的共享提示,用于学习一般特征(知识),并使预训练模型适应目标领域;2)针对特定群体的提示,用于收集具有相似特征的样本的特定群体特征,并使预训练模型具有细粒度的区分能力。为实现有效训练,我们设计了一个两阶段训练框架来学习这两种提示。在第一阶段,LPT 在感兴趣的长尾训练数据集上优化共享提示和分类器。在这一阶段,其目标是:1)通过提示调整使预训练模型适应目标兴趣领域;2)增强预训练模型与训练分类器对训练数据的判别能力,这是学习特定群体提示的基础。在第二阶段,我们学习新添加的特定群体提示集,并进一步微调第一阶段使用的分类器。具体来说,在给定输入的情况下,LPT 会将其与学习到的共享提示一起输入到预训练模型中,并查看输入结果。

       这种 LPT 可以很好地缓解现有方法中存在的上述三个问题。在训练成本方面,LPT 只需对几个提示语进行微调,而这些提示语的大小远远小于预训练模型的大小,因此比起对整个预训练模型进行微调来适应,LPT 的训练成本要低得多。在泛化能力方面,LPT 只需在固定预训练模型的基础上对提示进行微调,因此可以享受预训练模型强大的泛化能力。在兼容性方面,LPT 针对不同的学习任务共享一个预训练模型,只需存储小尺寸的提示信息,大大提高了模型的兼容性,降低了实际部署成本。

       如图 1 所示,在各种长尾分类基准上,LPT 只需增加 ∼1.1% 的提示参数,就能获得与之前对整个预训练模型进行微调的方法相当甚至更高的性能。特别是,在仅使用基于视觉的数据进行训练和测试的情况下,LPT 在 PlacesLT 数据集上实现了 50.1% 的总体分类准确率和 46.9% 的少拍准确率(Zhou et al, 2017a),与之前仅使用视觉数据训练的方法相比,分别提高了 8.9% 和 11.6%。此外,更多的实验结果表明了 LPT 的优越性,以及它在长尾数据和域偏移数据上的泛化和鲁棒性。

3. 相关工作

       长尾图像分类:为解决高度不平衡数据分布带来的负面影响,以往的研究主要集中在三个不同的方面:数据重采样,利用手工采样器(Kang 等人,2020 年)、数据增强(Li 等人,2021a)或基于元学习的采样器(Ren 等人,2020 年)来平衡头部和尾部类别的训练数据;损失再加权(Cui 等人,2019 年;Menon 等人,2021 年;Li 等人,2022 年;Jamal 等人,2020 年;Tan 等人,2020 年),主要是在置信度分数中加入手工偏差(Menon 等人,2021 年;Li 等人,2022 年;Tan 等人,2020 年),通过手工加权重新缩放对数(Cui 等人,2019 年;Tan 等人,2020 年),或基于元学习的方法(Jamal 等人,2020 年); 李等人,2022 年)、通过手工创建的权重重新缩放对数(崔等人,2019 年;谭等人,2020 年)或基于元学习的方法(贾马尔等人,2020 年);以及解耦训练策略(康等人,2020 年;李等人,2021 年b)和集合学习方法(周等人,2020 年;王等人,2020 年)。最近,一些基于视觉语言的方法(Ma 等人,2021;Tian 等人,2022;Long 等人,2022)被提出,这些方法在训练和测试过程中引入额外的语言数据(Ma 等人,2021;Tian 等人,2022)或外部数据库(Long 等人,2022)来生成辅助置信度分数,最终在长尾数据上微调整个基于 CLIP 的模型。与上述对所有参数进行全面微调的方法不同,我们旨在利用预训练模型强大的无偏特征表示能力,构建一种及时调整的方法,从长尾数据中获得灵活而准确的分类器。

传统的长尾问题解决方案:

        针对于模型对长尾问题的处理和适应(提升处理长尾问题的能力)

       高效调整:高效调整方法(包括 prompt(Lester 等人,2021 年;Jia 等人,2022 年)、adapter(Houlsby 等人,2019 年;He 等人,2022 年;Nie 等人,2022 年;Chen 等人,2022 年)、LoRA(Hu 等人,2022 年)和其他方法(Frankle 等人,2021 年; Touvron等人,2022))的设计目的是利用预训练模型的表征能力,只对少数可训练参数进行微调,从而在下游任务中取得更好的性能(Zhai等人,2019;Lin等人,2014;Zhou等人,2017b)。在本文中,我们重点关注及时调整(Zhou 等人,2022a;Jia 等人,2022;Bahng 等人,2022)。具体来说,Jia 等人(2022 年)在 ImageNet(Deng 等人,2009 年)预训练的 ViT(Dosovitskiy 等人,2021 年)中引入了提示符;而 Bahng 等人(2022 年)则在图像边缘插入提示符并优化提示符。Wang 等人(2022 年)还在继续学习框架中引入了提示调整方法,使用多个可学习的提示来处理相应的任务。与上述研究不同的是,LPT 重点探索了在大规模且高度不平衡的训练数据中提示调整的迁移能力,从而实现了可比性和准确性。

基于预训练模型解决方案:

        针对于下游任务如何更好、更高效的利用预训练模型处理存在长尾现象的问题

目前长尾问题处理方式:

                                        预训练模型————下游根据问题调整模型

上面的方案一针对:下游用什么方法可以提高模型调整的效果

上面的方案二针对:下游使用预训练模型来调整要采取什么调整策略

3. 初步研究

3.1 VPT 性能调查

       以往研究(Zhou et al, 2022a;Jia et al, 2022)中的及时调优主要集中在对均衡分布的有限数据进行微调(Zhai et al, 2019),而其对大规模长尾数据的迁移学习能力(Zhou et al, 2017a;Van Horn et al, 2018)则未被探讨。为了启动我们的方法,我们首先定量评估了及时调整是否有利于长尾学习。为此,我们在大规模 Places-LT 数据集(Zhou 等人,2017a)上比较了线性探测和一种提示调谐方法(即 VPT,因其有效性而得名)的性能,从而研究了在 ImageNet-21k (Deng 等人,2009 年)上预训练的 ViT-B(Dosovitskiy 等人,2021 年)。具体来说,线性探测的目的是在预训练和固定特征提取器(如 ViT(Dosovitskiy 等人,2021 年))的基础上对线性分类器进行微调;而 VPT 通常是在预训练模型的基础上将输入标记与可学习的提示(标记)和线性分类器串联起来。在训练过程中,我们使用这两种方法在 20 个历时内独立优化其可学习参数,并使用经过良好调整的超参数,例如,SGD 的学习率为 0.02,权重衰减为 1e-4。

问题:

        1、及时调整包括线性探测、提示调谐两个方法。(意思为:快速微调)

        2、之前文章对及时调整的研究只在均衡数据中研究过,没有用于长尾数据集。

       表 1 总结了线性探测和 VPT 的量化结果。在不采用类平衡采样的情况下,VPT 的总体准确率为 37.52%,在多发/中发/少发准确率方面分别比线性探测高出 3.94%、3.33%、4.52%。特别是在引入类平衡采样(Kang 等人,2020 年),即先从训练集中随机采样类,然后在每次迭代中随机采样等量的输入后,VPT 的总体准确率达到了 44.17%,在少点准确率方面甚至超过了同行 8.67%。根据观察结果,我们得出以下结论:a) 及时调整能持续提高长尾分类的整体性能;b) 及时调整对长尾分布更稳健,对尾部类别的好处更大。不过,从表 1 中也可以看出,及时调整在长尾问题上的表现还不够充分,还远远落后于前沿技术。

3.2 及时调整分析

       然而,及时调谐能提高长尾学习任务性能的原因仍不清楚。为了定量和定性分析及时调谐,我们在 Places-LT 上进行了一系列实验(Zhou et al, 2017a)。我们首先采用线性判别分析(LDA)从领域适应的角度研究学习到的提示。具体来说,我们使用预训练的 ViT-B 和经过 3.1 节中 VPT 在 Places-LT 上微调的 ViT-B,分别提取 ImageNet val 集和 Places-LT val 集的特征,然后利用上述特征得到相应的 LDA 向量,用于可视化。

1、Vit预训练模型采用的是ImageNet数据集

2、实验微调采用Places-LT数据集

整体流程思考:

        1、利用预训练模型Vit-B(可以让它走一遍无监督模型先),经过VPT去微调(微调和训练的区别主要在于学习率的不同。微调学习率低,训练学习率高)。

        2、VPT微调同时更新图片特征提取器以及prompt(prompt针对的是类别)。

        3、再利用LDA去完成两个任务:降维、分类(主要还是方便于可视化

       从图 2 的定性结果中,我们不难发现:a) 对于预训练的 ViT-B,其从 ImageNet 提取的特征(红色聚类)与从 Places-LT 提取的特征(绿色聚类)相距甚远;b) 对于 VPT 微调后的 ViT-B,其从 ImageNet 提取的特征(黄色聚类)与从 Places-LT 提取的特征(蓝色聚类)对齐,且彼此接近。因此,这些观察结果表明:1)VPT 中的学习提示可以帮助微调数据分布(Places-LT)与预训练数据分布(ImageNet)保持一致,从而使预训练模型适应长尾学习任务的目标领域。

核心要点:

        1、VPT比VIT更容易对齐两个数据集

        2、说明VPT比Vit有更强的迁移学习能力

        3、Vit在预训练学习的基础知识和长尾微调学习的知识相关性低;VPT在预训练学习的基础知识和微调学习的知识能够在同一个特征空间,相关性高。

       接下来,我们从特定群组的角度来研究学习到的提示。具体来说,对于 Places-LT 中的每个类,我们将该类中的样本视为一个组(簇);然后,对于每个组 i(1 ≤ i ≤ C,数据集中共有 C 个类),我们计算每个样本与其对应的组中心之间的平均距离,并将此平均距离视为每个组的类内距离 Ri。此外,我们还将组间距离 D 定义为任意两个组中心之间的平均距离,然后计算组内距离 Ri 的平均值与组间距离 D 的比值 γ,即\gamma=\frac{1}{CD}\sum_{\mathrm{i}}R_{\mathrm{i}}.。直观地说,对于一个群体来说,内类距离 Ri 越小,群体就越紧凑。因此,如果 γ 越小,则组的可区分度越高。因此,我们用 γ 作为衡量所学特征是否具有可区分性的指标,并在表 2 中报告了统计结果。可以看出,与 vanilla 预训练模型(指的是任何模型中没有经过任何特定修改或增强的模型版本)相比,VPT 微调预训练模型中的特征获得的平均类内距离更小,比率 γ 也更小,这表明 VPT 中不同类的特征更容易区分。此外,我们还对预训练的 ViT-B 和 VPT 微调预训练的 ViT-B 进行了 K-NN 评估。表 2 显示,在 K-NN 准确率方面,VPT 比 vanilla 预训练的 ViT-B 高出 1.1%,这表明 VPT 微调模型具有更高的区分能力。因此,可以得出以下结论:2)学习到的提示可以进一步提高预训练模型的判别能力,从而有利于长尾分类问题。

核心思想:

        1、建立一个新的衡量标准γ,用来衡量模型的优劣(从特征学习角度)

        2、用这个数据证明VPT模型判别能力更好

4. 长尾提示调谐

       第3节的观察结果启发我们设计一种基于即时调整的高效长尾学习方法(Jia et al ., 2022)。然而,长尾学习中的 vanilla VPT仍然落后于最先进的方法(Tian et al ., 2022;Long et al ., 2022)。为了进一步提高长尾学习中提示调谐的整体性能,我们提出了一种有效的长尾提示调谐(LPT)方法,其框架和训练过程如图3所示。一般来说,LPT包括一个共享的提示,让所有类学习一般特征或知识,并将预训练的模型适应目标领域,同时赋予训练数据的判别能力;以及特定于组的提示,以收集特定于组的功能并进一步细化调优第一阶段使用的分类器以获得更高的性能。两组提示分别通过共享提示调优和组提示调优进行优化。我们的LPT介绍如下。

4.1 阶段1:共享提示调优

       对于图3中的共享提示调谐阶段,给定L层的预训练ViT (Dosovitskiy et al ., 2021),我们的目标是优化共享提示\mathbf{u}=[\mathbf{u}_{1},\ldots,\mathbf{u}_{\mathrm{L}}]和余弦分类器f(\cdot;\theta_{f}),其中u跟随VPT-Deep (Jia et al, 2022),由L个独立的可学习token序列组成。具体来说,给定输入图像I, LPT通过预训练的patch嵌入层获得初始patch令牌z0。然后,给定类令牌([CLS]) c0和预训练的变压器编码器,对于ViT中的第i层,其中1≤i≤L,则定义第i块中使用的查询为\mathbf{q_{i}^{attn}}=[\mathbf{c_{i-1}},\mathbf{z_{i-1}}],对应的键值\mathbf{k_{i}^{attn}}=\mathbf{v_{i}^{attn}}=[\mathbf{c_{i-1}},\mathbf{z_{i-1}},\mathbf{u_{i}}],然后更新(ci,zi)利用u 通过:

(\mathbf{c_i},\mathbf{z_i})=\mathrm{FFN_i(Attn_i(q_i^{attn},k_i^{attn},v_i^{attn})),}\quad(1)

       在[\cdot,\ldots,\cdot]表示沿令牌数方向的令牌拼接操作,Attni和FFNi是第i个预训练ViT块中的自注意层和前馈网络(Vaswani et al ., 2017)。然后,将最终的类令牌cL输入余弦分类器f以计算每类置信度分数\mathbf{s}=f(\mathbf{c}_{\mathrm{L}};\theta_{f})。最后,给定相应输入I的真值y,我们最小化L\mathcal{L}_{\mathrm{P}_{1}}=\mathcal{L}_{\mathrm{cls}}(\mathbf{s},\mathbf{y})在阶段1的训练期间优化u和θf,其中Lcls是两个阶段使用的分类损失,将在第4.3节中讨论。

关键点:

        1、Vit模型(用transformer架构全面代替CNN)

        2、这里的类型真值y,个人认为是热卡标签或软标签(软标签效果肯定更好)

4.2 阶段2:组提示调优

       降低长尾学习难度的一个直接解决方案是通过特征的相似性将训练数据分成多个组,从而在每组中共享组特有的知识,降低识别难度。基于这一动机,为了收集具有相似特征的样本的组特有特征,并赋予预训练模型细粒度的判别能力,我们的目标是使用不同的组提示来处理来自不同类别的样本,从而通过每个组提示收集组特有特征,有利于长尾分类。因此,我们引入了针对群体的提示和m个个人可学习的提示\mathcal{R} = \{(\mathbf{k}_{1},\mathbf{r}^{1}),\ldots,(\mathbf{k}_{\mathrm{m}},\mathbf{r}^{\mathrm{m}})\}其中ki为对应的第i组提示符r^i的键,每个r^i有L−K个可训练的令牌序列。为了减少计算成本和额外参数的数量,我们在前K个块中只使用共享提示符,并在最后L - K个块中引入组特定提示符集R。在本小节中,我们主要讨论分组提示调整的训练过程。具体来说,根据第 3.2 节中的观察结果(2),我们从第 1 阶段选择查询 q = cL,而不是像 Wang 等人(2022 年)那样使用预训练 ViT 的输出类标记,因为类标记 cL 通常具有更强的判别能力。给定查询 q 后,我们通过以下方式从 R 中自适应地选择最佳匹配提示:

[\mathrm{w}_1,\ldots,\mathrm{w}_k]=\mathrm{top-k}(\langle\mathbf{q},[\mathbf{k}_1,\ldots,\mathbf{k}_\mathrm{m}]\rangle,k)\quad(2)

问题:为什么需要为每一个组Prompt提示r增加一个索引标签k?

解答:1、梯度更新不是更新所有的r;2、同一个组可能同时和两个不同的Prompt产生联系(例如:一棵树r1、一只狗r2两个prompt完全没有联系,但是一张图片可能和这两个prompt都有关系。因此不能用q和r1、r2直接建立余弦相似度,因为r1、r2就不可能相似,因此一定不可能q匹配上多个r。因此为r另外建立一个k,这个k可以利用loss去学习,又没有r这种必然存在不相似的限制,因此完成Prompt本身性质和多方面匹配的一个解耦操作

       下面我们将讨论键的优化问题。直观地说,优化键的一种直接方法是强制同一类别的查询匹配某些键。然而,这种方法并不可行,因为很难准确解释哪些类可以匹配到某些提示。相反,我们更倾向于简单地最小化匹配查询和密钥之间的距离,从而自适应地优化这些密钥。我们就是从这个角度来设计这种查询函数的。正如第 3.2 节所述,微调阶段 1 生成的每个类别的特征集群都很紧凑。因此,对于来自同一类别的查询,如果我们随机选择一个查询 qi 和一个密钥 k',然后最小化 1-\langle\mathbf{q_{i}},\mathbf{k^{\prime}}\rangle,,k'和其他查询之间的距离自然最小,因为这些查询是固定的和足够紧凑的。因此,在训练过程中,每个键被学习到靠近附近的一个或多个聚类,最终引导相应的组提示符收集组特有的特征。此外,由于1)VPT (Jia et al, 2022)受益于即时集成,以及2)引入更多特定于群体的知识可能有利于识别尾部类别的样本。LPT不是只使用R中一个匹配的组提示符,而是将多个选择的提示符进行提示符集成,如下所示:

\mathbf{r}=\mathrm{sum}([\mathbf{r}^{\mathbf{w}_{1}},\ldots,\mathbf{r}^{\mathbf{w}_{k}}])/k,\quad(3)

       在给定 r 的情况下,LPT 会重新使用特征点(\mathbf{c_{K}},\mathbf{z_{K}})从阶段1作为(\mathbf{\hat{c}}_{\mathrm{K}},\mathbf{\hat{z}}_{\mathrm{K}})为了节省计算成本,将第 i 个区块中使用的查询定义为\mathbf{\hat{q}}_{\mathrm{i}}^{\mathrm{attn}}=(\mathbf{\hat{c}}_{\mathrm{K}},\mathbf{\hat{z}}_{\mathrm{K}})和键值为\mathbf{\hat{k}_{i}^{attn}}=\mathbf{\hat{v}_{i}^{attn}}=[\mathbf{\hat{c}_{i-1}},\mathbf{\hat{z}_{i-1}},\mathbf{u_{i}},\mathbf{r_{i-K}}],最终更新为(\mathbf{\hat{c}}_{\mathrm{K}},\mathbf{\hat{z}}_{\mathrm{K}})作为:

(\mathbf{\hat{c}_{i}},\mathbf{\hat{z}_{i}})=\mathrm{FFN_{i}(Attn_{i}(\hat{q}_{i}^{attn},\hat{k}_{i}^{attn},\hat{v}_{i}^{attn}))}

关键点:

        1、用前面第k个块的结果作为前置输入

        2、r^m_i:m表示第几组prompt,i表示第几个block的最终prompt

       其中,K+1 ≤ i ≤ L 表示 ViT 中最后 L - K 个预训练块的索引。接下来,将输出类标记 \mathbf{\hat{c}}_{\mathrm{L}}输入余弦分类器 f,并通过f(\mathbf{\hat{c}}_{\mathrm{L}};\theta_{f}).计算每类置信度得分。最后,在给定相应输入 I 的地面实况 y 的情况下,我们最小化 LP2,包括分类损失 Lcls 以及查询 q 与相应匹配密钥[\mathrm{k}_{\mathrm{w}_{1}},\ldots,\mathrm{k}_{\mathrm{w}_{k}}]之间的余弦相似度,如公式 5 所示:

\mathcal{L}_{\mathrm{P_{2}}}=\beta\mathcal{L}_{\mathrm{cls}}(\mathbf{\hat{s}},\mathbf{y})+(1-\frac{1}{k}\sum_{\mathrm{i\in w}}\langle\mathbf{q},\mathbf{k_{i}}\rangle),\quad(5)

 其中,β 是 Lcls 的比例因子,将在下文中讨论。

核心思想:

        1、这里本质上有两个优化目标:a)增加模型输出的标签和真实标签的相似度;b)增加查询q和k的匹配程度

       需要注意的是,天真地使用类平衡采样(Kang 等,2020 年)或实例平衡采样(Kang 等,2020 年)可能会分别导致尾类或头类的严重过拟合(Zhang 等,2021b)。为了平衡头部类和尾部类的性能,避免过度拟合,我们引入了双重采样策略。具体来说,在第二阶段的每次训练迭代中,LPT 会从实例平衡采样器中随机采样一个小批量 fIgins,以及从类平衡采样器中随机采样另一个小批量 fIgbal。对于 fIgbal 中的样本,我们只需设置 β = 1 来计算 LP2;而对于 fIgins 中的样本,我们设置 β = η(E - e)=E 其中,η 是 fIgins 的初始化权重,E 表示最大历时数,e 是当前历时数。

4.3 损失函数

       最后,我们介绍了两阶段训练中使用的分类损失 Lcls。虽然 LPT 可以使用多种分类损失来进一步提高 LPT 的性能,但我们采用了非对称 GCL 损失L_{A-GCL},用于根据训练数据的统计标签频率调整对数,以及在正负类之间重新加权梯度。在不失一般性的前提下,我们以 LPT 第二阶段计算的\mathbf{\hat{s}}=f(\mathbf{\hat{c}}_{\mathrm{L}};\theta_{f})为例来演示L_{A-GCL}。按照(Li et al, 2022)的方法,我们对第 i 个类别的置信度得分进行重新缩放:

\mathbf{v_{i}}=\alpha(\mathbf{\hat{s}_{i}}-(\log n_{\mathrm{max}}-\log n_{i})\|\epsilon\|)\quad(6)

       其中,α 为比例因子,为高斯分布随机变量,ni 和 nmax 分别指训练集中第 i 个类别的标签频率和最大标签频率。然后,我们通过以下方法计算每类概率\mathbf{p}=[\mathbf{p}_{1},\ldots,\mathbf{p}_{\mathbf{C}}]

[\mathbf{p}_{1},\ldots,\mathbf{p}_{\mathbf{C}}]=\mathrm{softmax}([\mathbf{v}_{1},\ldots,\mathbf{v}_{\mathbf{C}}]).\quad(7)

调整作用

  • 减少偏差: 通过从 s 中减去这一差值,可以调整 s 的值,使得对频率较少的类别进行一定的“惩罚”,从而使模型在训练过程中能够更加关注样本数较少的类别。
  • 增强平衡: 这种调整有助于在类别不平衡的情况下,使得模型的学习过程更加均衡,避免在频率较高的类别上过拟合。

       接下来,我们使用非对称再加权(Ridnik 等人,2021 年)来消除长尾学习中负梯度的影响。假设 j 是 I 的实际类,我们计算L_{A-GCL}为:

\mathcal{L}_{\mathrm{A-GCL}}=(1-\mathbf{p_{j}})^{\lambda_{+}}\log(\mathbf{p_{j}})+\sum_{1\leq i\leq C,i\neq j}(\mathbf{p_{i}})^{\lambda_{-}}\log(\mathbf{p_{i}}),\quad(8)

其中,λ+ 和 λ- 分别是地面实况类和负面类的聚焦参数(Lin 等人,2017 年)。最后,我们在 LPT 的两阶段训练中选择 L_{cls} = L_{A-GCL}

核心思想:
        1、CL:对比损失函数(将相似拉近,将不相识推远)

        2、GCL:在CL基础上增加对梯度信息的考虑(例如梯度大的减小,梯度小的扩大)

        3、A-GCL:A的意思就是自适应,说明这个对梯度信息的调整是一个超参数,动态的

对比学习的两种形式:

        1、除法形式:\mathcal{L}_{\mathrm{GCL}}=-\frac1N\sum_{i=1}^N\log\frac{\exp(\sin(x_i,x_j)/\tau)}{\sum_{k=1}^N\exp(\sin(x_i,x_k)/\tau)} 

        2、加法形式:\mathcal{L}_{\mathrm{A-GCL}}=(1-\mathbf{p_{j}})^{\lambda_{+}}\log(\mathbf{p_{j}})+\sum_{1\leq i\leq C,i\neq j}(\mathbf{p_{i}})^{\lambda_{-}}\log(\mathbf{p_{i}})

如果想要学习更多深度学习论文,大家可以点个关注并订阅,持续学习、天天进步

你的点赞就是我更新的动力,如果觉得对你有帮助,辛苦友友点个赞,收个藏呀~~~


http://www.kler.cn/a/379165.html

相关文章:

  • 数字小偷:2025年全面防护指南
  • WebSocket实现分布式的不同方案对比
  • 【2024年华为OD机试】 (C卷,100分)- 小明找位置(Java JS PythonC/C++)
  • 无公网IP 实现外网访问本地 Docker 部署 Navidrome
  • 鸿蒙-点击Notification通知并打开App的具体页面
  • 浅谈云计算20 | OpenStack管理模块(下)
  • 读书笔记-《Spring技术内幕》(四)事务
  • 【亚马逊云】基于 AWS 使用CloudFormation快速部署 VMClarity 环境
  • celery在django项目中实现并发任务和定时任务
  • SOLIDWORKS 2025用户体验新功能
  • NineData云原生智能数据管理平台新功能发布|2024年10月版
  • distrobox install in ubuntu 22.04 / 在 ubuntu 22.04 上安装 distrobox (***) OK
  • qt的c++环境配置和c++基础【正点原子】嵌入式Qt5 C++开发视频
  • Stable Diffusion Web UI 1.9.4常用插件扩展-WD14-tagger
  • Spring Boot技术:校园社团信息管理的创新解决方案
  • 123.WEB渗透测试-信息收集-ARL(14)
  • 初始计算机网络
  • sqlserver、达梦、mysql的差异
  • React 组件生命周期与 Hooks 简明指南
  • HTTP代理是什么?有什么用?
  • git pull遇到一个问题
  • 揭秘Scam-as-a-Service:警惕钓鱼攻击的产业化
  • centos7之LVS-DR模式传统部署
  • 21 Docker容器集群网络架构:四、Docker集群网络验证
  • 在k8s环境中如何在本地和pod之间同步文件?
  • 基于微信小程序的生签到系统设计与实现(lw+演示+源码+运行)