当前位置: 首页 > article >正文

论文解读:Prompt-aligned Gradient for Prompt Tuning

摘要

得益于CLIP等大型预训练的视觉语言模型VLM,我们可以通过离散的提示设计构建Zero-shot分类器,例如,利用图像与提示语句" a photo of a [ CLASS ] "之间的相似度,可以获得图像属于某个类别的置信度分数。此外,如果我们使用少量样本对软提示进行微调,提示调优表现出VLMs快速适应下游任务的巨大潜力。

然而,我们发现一个常见的错误,不适当的微调或极少样本的学习甚至会导致zero-shot预测性能低下。现有方法仍然通过使用提前终止和数据增强等传统的抗过拟合技术来解决这一问题,缺乏针对提示的原则性解决方案。

在本文中,我们提出了ProGrad (即提示对齐梯度),以防止提示调优遗忘从VLMs中学习到的一般知识。特别是,ProGrad只更新那些梯度与一般知识一致的提示,这种一般知识是通过预先定义的提示预测所提供的优化方向来表示的。简单来说,ProGrad是一种机制,它只更新那些其改进方向与现有知识库一致的提示,以此来保持知识更新的一致性和避免冲突。

在少样本学习、领域泛化、基类-新类泛化和跨数据集迁移设置下的大量实验表明,ProGrad的少样本泛化能力强于当前最先进的即时调优方法。

Introduction

在学习无数的图像-文本对之后,大规模视觉-语言模型( VLM ) 可以学习相匹配的图像文本对。得益于VLMs强大的语言建模能力,我们可以用自然语言(即提示)建立一个用来查询通用知识的交流通道。提示桥接了预训练过程和下游任务之间的连接差距,而不需要额外的微调。例如,我们可以为zero-shot图像分类制作一个提示语句' a photo of a [ CLASS ] ':通过使用视觉语言模型CLIP ,我们将图像输入到视觉分支,提示语句输入到语言分支,然后获得一个视觉语言相似度作为将图像分类为' [ CLASS ] '的置信度分数。

在实际中,基于提示的zero-shot图像分类是不准确的,因为手工设计的提示可能不是机器最喜欢的(例如, ' this is a Picture of '可能在VLM训练中语法上更占优势),或者不特定于下游域(例如, "某人正在做某事的照片"在动作识别上更好) 。最近,提示调整或前缀调整被提出,用一组可学习的单词嵌入向量来代替手工提示,这些单词嵌入向量不必翻译回人类可读的单词。然而,提示调优仍然像传统的微调一样棘手:随着训练的继续,泛化能力可能会下降,甚至低于zero-shot基线。

解释图1 :

图1 ( a & b )所示,提示调优方法CoOp通过提前停止达到了最好的效果,当训练继续进行时,其准确率最多下降了4 %。此外,图1 ( c & d )显示,CoOp在没有增加或没有足够的下游任务样本的情况下,几乎无法改善zero-shot的CLIP。据我们所知,现有的方法仍然依赖于传统的抗过拟合技术,如提前停止和数据增强,这缺乏对提示调优本质的原则性解决方案。

解释图2 :

此外,Grad - CAM可视化结果表明,微调后的提示会误导VLM忘记分类至少应该关注前景对象而不是背景的通用知识。对比CoOp (图2 b)和零样本学习CLIP (图2c),我们发现CoOp模型分散了对前景的注意力,而CLIP主要关注前景物体。这些结果说明了现有的提示调优策略存在过拟合风险,尤其是当训练样本数量极其有限时。

为此,本文提出了一种名为Prompt - alignment Gradient ( ProGrad )的提示调优方法,来克服CLIP中不恰当的调优。ProGrad的原则是对每个调优步骤进行规则化处理,使其不与原始提示提供的通用知识发生冲突,例如零样本CLIP预测。具体来说,我们使用zero-shot的CLIP和少样本微调预测之间KL散度来衡量通用知识的方向Gg,称为通用方向。类似地,我们利用真实数据和少样本微调模型之间的交叉熵梯度计算特定领域的知识方向Gd,称为特定领域方向。

我们将特定领域的方向Gd分解为:1 )一个与一般方向正交的向量G⊥,它表示不冲突的特定领域知识;2 )另一个与一般方向平行的向量G∥,表示通用知识。由于任何两个正交的向量都可以转化为两个互不冲突的基向量,所以第一个梯度分量G⊥不会覆盖总的方向。对于第二个成分,它必须是以下两个方向之一:1 )与一般方向相同,这表明更新与通用知识一致;2 )与一般方向相反,这表明冲突的更新应该被丢弃,以避免遗忘。

继续解释图2:

总体而言,在每次迭代中,ProGrad只更新与大方向呈锐角的提示对齐方向上的参数。与CoOp和CLIP相比,Gg和G⊥(图2 ( d & e ) )都有助于正则化模型聚焦于前景,ProGrad (图2 ( f ) )进一步提高了视觉响应。

学习CLIP、CoOp和CoCoOp,我们在15个图像分类基准下,在少样本学习、领域泛化、基类到新类泛化和跨数据集迁移的设置下对ProGrad进行评估,包括通用对象分类、细粒度图像识别、动作分类。综上所述,本文的ProGrad实现了:1 )与CoOp相比,在所有11个数据集上都有明显的提升;2 )与CoOp和CoCoOp相比,在所有11个数据集上的基类和新类精度的调和平均值都有明显的提高;3 )在领域泛化的源数据集和目标数据集上都有明显的提高。

Method

3.1. premiliminaries

Contrastive language-image pre-training (CLIP)

对抗性语言-图像预训练模型( CLIP ) 采用对比语言-图像预训练范式,对大量图像文本对进行训练。对于对比学习,关联的图像和句子被视为正样本,而非关联的图像文本对则被视为负样本。对比目标最大化正对的相似度,最小化负对的相似度。

Zero-shot transfer inference

zero-shot迁移推理将预训练的CLIP模型适应于下游任务,而无需对模型进行微调。以图像分类为例,通过将分类任务表述为图像-文本匹配问题来实现zero-shot迁移,其中文本是使用类似于"A photo of [ CLASS ]"的模板来扩展" [ CLASS ] "获得。基于图像特征f和类别扩展的文本特征wi之间的余弦相似度来衡量图像-类别匹配得分。图像编码器提取图像x的图像特征f,而第i类的文本特征wi则通过将提示描述输入到t中得到:

Prompt-based learning

基于提示(Prompt-based)的学习进一步增强了CLIP模型的迁移能力,并通过自动从下游任务中学习少量样本来避免提示工程(prompt engineering)。与使用固定手工制作的提示的零样本迁移不同,CoOp [54] 构建并微调一组M个连续的上下文向量v = {v1, v2, ..., vM}作为可转动的提示。具体来说,提示ti = {v1, v2, ..., vM, ci}结合了可学习的上下文向量v和类别标记嵌入ci,并输入到文本编码器g(·)。CoOp通过最小化真实标记的负对数似然来优化静态上下文向量v。

3.2. Prompt-aligned Gradient

CoOp面临着一个挑战,当注释数目有限时(例如每个类别一个注释),其迁移性能就会显著下降,甚至可能会低于zero-shot迁移效果。此外,CoOp也过度地依赖于早期停止和数据增强等抗过拟合的技术。为了克服过拟合的挑战,本文提出一种高效的微调范式ProGrad,将下游任务中的少样本知识与大规模的通用知识进行对齐。

受到知识蒸馏在知识迁移中的成功启发,本文利用CLIP的zero-shot预测作为通用知识,然后将微调后的预测与通用知识进行比较,从而调节梯度方向。

具体来说,我们通过Eq2计算模型预测p(ti|x)与真实值y之间的交叉熵损失Lce(v)来得到领域特定方向;根据调优模型预测p(ti|x)与CLIP的zero-shot预测pzs(wi|x)之间的KL散度得到通用知识的方向。

本文使用Gg = ∇vLkl(v) 和 Gd = ∇vLce(v)分别代表Lkl(v)和Lce(v)的梯度。

解释图3:Gd和Gg之间给关系具有两种可能性:(1)如图3(a):两者之间的夹角小于90度,这表明下游任务中少样本中知识的优化方法与通用知识并不冲突。在这种情况下,我们可以将更新后的梯度方法Gprograd设置为Gd。(2)如图3(b):两者之间的夹角大于90度,这表明下游任务少样本知识与通用知识相冲突。换句话说,优化Gd之后的上下文向量会导致遗忘预训练通用知识。在这种情况下,我们将Gd投影到Gg的正交方向来优化模型进行分类,避免了增加KL损失。

ProGrad的策略为:

其中λ表示通用知识的指导强度。当λ=1表示Gd投影到Gg的正交方向上,而λ=0表示退化为CoOp。

图3(c)表示了本文提出的ProGrad的流水线。与CoOp中使用Gd(特定领域方向)更新上下文向量不同,本文使用Gprograd来优化上下文向量,防止梯度的方向对下游任务中少样本的过拟合。

Experiments

Conclusion

本文指出了现有的针对少样本泛化的提示调优方法的过拟合问题,这些方法严重依赖于早期停止和数据增强。本文提出了一种提示调优方法ProGrad,对每个调优步骤进行规则化处理,不与手工提示的一般知识相冲突。在11个数据集上的小样本分类、基-新泛化、领域泛化和跨数据集迁移实验证明了ProGrad的有效性和高效性。在未来的工作中,我们将探索如何将ProGrad应用于目标检测和分割等其他任务。


http://www.kler.cn/a/289328.html

相关文章:

  • CSRF攻击XSS攻击
  • 西门子【Library of Basic Controls (LBC)基本控制库”(LBC) 提供基本控制功能】
  • Linux提权-02 sudo提权
  • Web自动化:Cypress 测试框架概述
  • 将Docker运行中的容器保存为镜像并导出导入
  • 【服务治理中间件】consul介绍和基本原理
  • 论文《Improving your graph neural networks:A High-Frequency Booster》笔记
  • 构造+模拟,CF 873D - Merge Sort
  • 水平垂直居中的方式
  • Nginx - Rewirte
  • 【GPT】Coze使用开放平台接口-【5】API 调用
  • 15、Django Admin添加自定义字段功能
  • 宠物勺子秤芯片解决方案CSU8RP1186
  • 机器学习(五) -- 监督学习(8) --神经网络2
  • 苹果系统中如何安装Python和PyCharm
  • 低代码用户中心的构建与应用
  • 计算机毕业设计PySpark深度学习动漫推荐系统 动漫视频推荐系统 机器学习 协同过滤推荐算法 bilibili动漫爬虫 数据可视化 数据分析 大数据毕业设计
  • Vue3 数据通信
  • 计算机网络 第1章 概述
  • AI预测体彩排3采取888=3策略+和值012路或胆码测试9月3日升级新模型预测第71弹
  • 大数据-114 Flink DataStreamAPI 程序输入源 自定义输入源 Rich并行源 RichParallelSourceFunction
  • Meshy-4:AI驱动3D建模的革命性工具,解锁虚拟创作新高度
  • AIGC与数据分析融合,引领商业智能新变革(TOP企业实践)
  • 摄像头进行视频捕获并定时截取屏幕图像
  • 【前端面试】设计循环双端队列javascript
  • C#通过ACE OLEDB驱动程序访问 Access和 Excel