当前位置: 首页 > article >正文

【顶刊TPAMI 2025】多头编码(MHE)之极限分类 Part 3:算法实现

目录

  • 1 三种多头编码(MHE)实现
    • 1.1 多头乘积(MHP)
    • 1.2 多头级联(MHC)
    • 1.3 多头采样(MHS)
    • 1.4 标签分解策略

论文:Multi-Head Encoding for Extreme Label Classification
作者:Daojun Liang, Haixia Zhang, Dongfeng Yuan and Minggao Zhang
单位:山东大学
代码:https://github.com/Anoise/MHE

论文地址:Online,ArXiv,GItHub

背景动机参见 【顶刊TPAMI 2025】多头编码(MHE)之极限分类 Part 1
基础知识参见 【顶刊TPAMI 2025】多头编码(MHE)之极限分类 Part 2
算法实现参见 【顶刊TPAMI 2025】多头编码(MHE)之极限分类 Part 3
表示能力参见 【顶刊TPAMI 2025】多头编码(MHE)之极限分类 Part 4
实验结果参见 【顶刊TPAMI 2025】多头编码(MHE)之极限分类 Part 5
无需预处理见 【顶刊TPAMI 2025】多头编码(MHE)之极限分类 Part 6

请各位同学给我点赞,激励我创作更好、更多、更优质的内容!^_^

关注微信公众号,获取更多资讯
在这里插入图片描述

1 三种多头编码(MHE)实现

现在,我们考虑具体化的算法实现,以使MHE适用于各种XLC任务。具体地说,将MHP应用于xsl以实现多头并行加速。在XMLC中使用MHC来防止多个类别之间的混淆,在模型预训练中使用MHS来有效地提取特征,因为该任务不需要分类器。然后,我们提供了一种策略来确定头像的数量和长度。
在这里插入图片描述

图 3 : XLC 任务的三个基于 MHE 的训练和测试流程。红色虚线框表示的部分是为了便于理解,实际中并不需要。

1.1 多头乘积(MHP)

根据推论1,输出可以分解为头部的乘积,这为使用MHP代替普通分类器来训练模型铺平了道路。
如3-a所示,在训练过程中,需要将全局标签 Y i Y_i Yi分配给每个头部,进行局部损失计算。因此,我们首先对 Y i Y_i Yi执行OHE,然后根据头部的长度将其重塑为 H H H阶张量 Y i 1 , . . . , H \mathcal{Y}_i^{1,...,H} Yi1,...,H。最后,将 Y i 1 , . . . , H \mathcal{Y}_i^{1,...,H} Yi1,...,H分解为每个头部上的本地标签 { Y i h } h = 1 H \{Y_i^h\}_{h=1}^H {Yih}h=1H。由于one-hot编码 Y i Y_i Yi的分解仅取决于头部的数量和顺序,因此可以递归地计算为
在这里插入图片描述
其中 j j j k k k为分类头的索引。
在测试期间,必须从局部预测中恢复全局预测。如图3-a所示,我们首先对每个头部执行 I Λ \mathbb{I}_{\varLambda} IΛ,以获得局部预测的标签。然后,通过对每个头部执行乘积并对最终输出应用Argmax来获得全局预测 Y ~ i \tilde{Y}_i Y~i。为了加快这一过程,根据定理1,我们从局部预测和后续正面的长度计算 Y ~ i \tilde{Y}_i Y~i,为
Y ~ i = ∑ k = 1 H − 1 Λ ( O k ) ∏ j = k + 1 H ∣ O j ∣ + Λ ( O H ) . ( 11 ) \tilde{Y}_i = \sum_{k=1}^{H-1} \varLambda(\bm{O}^k) \prod_{j=k+1}^H |\bm{O}^j| + \varLambda(\bm{O}^H). \qquad (11) Y~i=k=1H1Λ(Ok)j=k+1HOj+Λ(OH).(11)
MHP的算法伪代码见附录E-1。它可以用于许多xsl任务,如图像分类、人脸识别等。

1.2 多头级联(MHC)

对于XMLC,每个示例 X i \bm{X}_i Xi对应于多个标签 Y ˉ i ∈ { 0 , 1 } C \bar{\bm{Y}}_i \in \{0,1\}^{C} Yˉi{0,1}C,因此分类器的输出需要执行多热编码和Top- K K K选择,如 Y ~ i = Top- K ( O ˉ ) \tilde{Y}_i = \text{Top-}K(\bar{\bm{O}}) Y~i=Top-K(Oˉ)。在XMLC中不能直接采用MHP。这是因为MHP中的每个头只预测一个标签。如果用于多标签预测,则在计算局部预测的乘积时将导致不匹配。为了解决多标签场景下MHP的不匹配问题,提出了MHC,它将多个头部级联用于模型训练和测试。

如图3-b所示,在训练过程中,MHC的标签分解过程与MHP相同。在测试期间,选择输出的顶级 K K K激活。然后,通过预定义的候选集 C 1 \mathbb{C}^1 C1获得该头部的局部预测,并采用该候选集表示后续头部的标签集,方便检索,减少计算量。 h h h\text{-}头的最终输出 O ~ h \tilde{\bm {O}}^h O~h由嵌入的 Y ~ h − 1 \tilde{\bm{Y}}^{h-1} Y~h1和当前输出 O h \bm{O}^h Oh的乘积得到。然后,根据 O ~ h \tilde{\bm {O}}^h O~h的前 K K K激活项,从 C h \mathbb{C}^h Ch中选择 Y ~ h \tilde{\bm {Y}}^h Y~h。重复此过程,直到获得 Y ~ H \tilde{\bm {Y}}^H Y~H的标签为
在这里插入图片描述
其中   i h = ∏ j = 1 h ∣ O j ∣ \ i_h = \prod_{j=1}^h |\bm{O}^j|  ih=j=1hOj, E h \mathbb{E}^h Eh h h h\text{-}头的嵌入层, C [ 1 , . . . , i h + 1 ] ( i h , ∣ O h + 1 ∣ ) \mathbb{C}_{[1,...,i_{h+1}]}^{(i_h,|\bm{O}^{h+1}|)} C[1,...,ih+1](ih,Oh+1)为元素为 1 1 1 i h + 1 i_{h+1} ih+1,形状为 ( i h , ∣ O h + 1 ∣ ) (i_h, |\bm{O}^{h+1}|) (ih,Oh+1)的索引矩阵。由公式12可知,MHC是一种由粗到精的分层预测方法,它依次从前一个头部中选择Top- K K K候选标签。请注意,MHC仅依赖于Eq. 10进行标签分解,不需要HLT或标签聚类等预处理技术。MHC的算法伪代码见附录E-2。

1.3 多头采样(MHS)

对于模型预训练任务,训练完成后丢弃香草分类器,只采用模型提取的特征 F \bm F F对下游任务进行微调。因此,需要训练分类器中权值的所有参数来提取更多的判别特征,但是训练权值的所有参数计算开销很大。因此,提出MHS通过选择地面真值标签所在的头部来更新模型参数。

如图3-c所示,MHS将原始分类器平均分为 H H H组,使 O = ∑ h H ∣ O h ∣ \bm{O} = \sum_h^H |\bm{O}^h| O=hHOh。训练时,选择标签 Y i Y_i Yi所在的头部进行模型训练,称为正头部。当然,我们也可以随机选择几个负头像一起训练模型,从而使模型具有更多的负样本信息。 O h {\bm O}^{h} Oh的MHS正演过程可表示为
O h = O h ∪ { O j } = W h F ∪ { W j } F ,      ( 13 a ) Y h = Y h ∪ { 0 } = Y [ ∣ O h − 1 ∣ : ∣ O h ∣ ] ∪ { 0 } , ( 13 b ) \bm{O}^h = \bm{O}^h \cup \{\bm{O}^j\} = \mathcal{W}^h\bm{F} \cup \{\mathcal{W}^j\}\bm{F}, \qquad \quad \ \ \ \ (13a) \\ \bm{Y}^h = \bm{Y}^h \cup \{0\} = \bm{Y}[|\bm{O}^{h-1}|:|\bm{O}^{h}|] \cup \{\bm{0}\}, \qquad (13b) Oh=Oh{Oj}=WhF{Wj}F,    (13a)Yh=Yh{0}=Y[Oh1:Oh]{0},(13b)
其中 { O j } \{\bm{O}^j\} {Oj} { W j } \{\mathcal{W}^j\} {Wj}分别表示负头的输出和权重集, ∪ \cup 表示串联操作。等式13-b表示用 0 \bm 0 0 s填充 Y h \bm{Y}^h Yh以对齐 O h \bm{O}^h Oh的长度,其中 ∣ O h ∣ = 0 |\bm{O}^h|=0 Oh=0表示 h = 0 h=0 h=0

式13中的方法可以表示为MHS- S S S,其中 S S S为所选头像的个数。我们的实验表明MHS- 1 1 1(仅正样本)在模型预训练上取得了很好的效果。对于 S = 2 S=2 S=2, MHS近似或优于香草分类器。为了加快MHS的速度,在同一批次中选择含有其他样品标签的头作为阴性头。MHS的算法伪代码见附录E-3。

1.4 标签分解策略

到目前为止,我们已经介绍了三种MHE算法,其实现取决于头的数量和长度。因此,在本小节中,我们引入误差积累和混淆度的概念来衡量头部数量和长度对基于mhe的算法性能的影响。

头的数量: 带 H H H头的MHE的近似过程可表示为
O ≈ O 1 ⊗ O ~ 2 ≈ O 1 ⊗ O 2 ⊗ O 3 ~ ⏟ ≈ O ~ 2 ≈ O 1 ⊗ O 2 ⊗ ⋯ ⊗ O H ⏟ ≈ O ~ H − 1 . ( 14 ) {\bm O} \approx {\bm O}^1 \otimes \tilde{\bm O}^2 \approx {\bm O}^1 \otimes \underbrace{{\bm O}^2 \otimes \tilde{{\bm O}^3}}_{\approx \tilde{\bm O}^2} \approx {\bm O}^1 \otimes {\bm O}^2 \otimes \underbrace{\cdots \otimes {\bm O}^H}_{\approx \tilde{\bm O}^{H-1}}. \qquad (14) OO1O~2O1O~2 O2O3~O1O2O~H1 OH.(14)
如等式14所示,增加一个头部相当于又累积了一个时间误差。虽然增加头的数量会显著减少分类器的参数和计算量,但也会导致更大的累积误差。因此,在计算资源和运行速度允许的情况下,应尽量减少分类头的数量。

头的长度:混淆度是当采用MHE来近似原始标签空间时由共享组件引起的不匹配的度量。它与近似误差成正比,如下所示
D = m a x π ( O 1 , ⋯   , O H ) ( ∏ h = 2 H ∏ k = h H ∣ O k ∣ ∣ O k − 1 ∣ ) ,    ( H ≥ 2 ) , ( 15 ) D = \mathop{max}\limits_{\pi({\bm O}^1, \cdots, {\bm O}^H)} \left ( \prod_{h=2}^H \frac{\prod_{k=h}^H |{\bm O}^k|}{|{\bm O}^{k-1}|}\right ), \ \ (H \ge 2), \qquad (15) D=π(O1,,OH)max(h=2HOk1k=hHOk),  (H2),(15)
其中 π \pi π是正面的排列策略。希望 D D D的值尽可能小。由于 π \pi π依赖于具体的分解过程,我们详细分析了MHE不同算法的混淆程度。因此,

  • 对于MHP,由于磁头是平行的,需要组合,因此混淆程度与磁头的排列无关。也就是说,等式15中的 m a x max max可以在头的长度按升序排列时删除。因此,我们得出结论,MHP中每个头部的长度应尽可能一致,以最小化 D D D,即 ∣ O h ∣ ≈ C H |{\bm O}^h| \approx \sqrt[H]{C} OhHC
  • 对于MHC,由于头部是顺序级联的,我们可以选择一个更好的策略 π \pi π最小化 D D D。显然,当 π \pi π按降序排列( ∣ O 1 ∣ ≥ ⋯ ≥ O H |{\bm O}^1| \ge \cdots \ge {\bm O}^H O1OH)时, D D D是最小的。
  • 对于MHS,这些多个磁头是相互关联的,需要组合(与 m a x max max操作无关)。也就是说,我们可以选择与MHC相同的策略来最小化D。

背景动机参见 【顶刊TPAMI 2025】多头编码(MHE)之极限分类 Part 1
基础知识参见 【顶刊TPAMI 2025】多头编码(MHE)之极限分类 Part 2
算法实现参见 【顶刊TPAMI 2025】多头编码(MHE)之极限分类 Part 3
表示能力参见 【顶刊TPAMI 2025】多头编码(MHE)之极限分类 Part 4
实验结果参见 【顶刊TPAMI 2025】多头编码(MHE)之极限分类 Part 5
无需预处理见 【顶刊TPAMI 2025】多头编码(MHE)之极限分类 Part 6


http://www.kler.cn/a/469158.html

相关文章:

  • tcpdump-命令详解
  • Excel | 空格分隔的行怎么导入excel?
  • Allure 集成 pytest
  • MOE怎样划分不同专家:K-Means聚类算法来实现将神经元特征聚类划分
  • 【vue】晋升路线图、蛇形进度条
  • Linux 基础 4.文件IO 通用的IO模型
  • MFC读写文件实例
  • asp.net core中的 Cookie 和 Session
  • CSS——7.CSS注释
  • 信号的产生、处理
  • S32K144 UDSdoCAN 升级刷写实现笔记
  • 【动手学电机驱动】STM32-MBD(3)Simulink 状态机模型的部署
  • Qt之屏幕录制设计(十六)
  • 系统架构师考试-ABSD基于架构的设计方法
  • python 实现贪心算法(Greedy Algorithm)
  • 2025 年前端新技术如何塑造未来开发生态?
  • 解决CentOS 8 YUM源更新后报错问题:无法下载AppStream仓库元数据
  • SMMU软件指南之使用案例(Stage 2使用场景)
  • MySQL第四弹----数据库约束和数据库设计
  • 【连续学习之LwM算法】2019年CVPR顶会论文:Learning without memorizing
  • STM32拓展 低功耗案例1:睡眠模式 (register)
  • JavaScript系列(8)-- Array高级操作
  • javaEE-网络编程-3 UDP
  • LabVIEW 实现自动对焦的开发
  • 编译与汇编
  • kubelet状态错误报错