【论文阅读】Adversarial gradient-based meta learning with metric-based test
基于对抗梯度的元学习与基于度量的测试
引用:Zhang Y, Wang C, Shi Q, et al. Adversarial gradient-based meta learning with metric-based test[J]. Knowledge-Based Systems, 2023, 263: 110312.
论文地址:[下载地址]
Abstract
基于梯度的元学习及其近似算法已广泛应用于小样本学习场景中。在实际应用中,训练好的元模型通常在不同任务中采用统一的梯度下降步骤设置。然而,元模型可能会对某些任务产生偏差,导致一些任务在几步内就能收敛,而其他任务则在整个内循环过程中无法接近最优解。这个偏差问题可能导致训练好的元模型在某些任务上表现良好,但在其他任务上却有意外的差强人意的表现,从而影响元模型的泛化能力。为了解决这个问题,本文正式建立了度量基策略与梯度下降在元测试中的近似关系。通过直接计算相似性进行数据分类,训练好的元模型避免了收敛问题。我们指出,度量基方法在元测试中能够紧密地近似梯度下降,前提是保证在元训练过程中衍生特征的表示能力和内循环的收敛性。在此基础上,我们提出了一种新的元学习模型GMT2(基于梯度的元训练与基于度量的元测试),通过结合元训练中的梯度下降与元测试中的度量基方法来解决这个问题。GMT2采用了一种新的一阶近似方案,使用对抗更新策略,这不仅增强了内层特征表示,还能够在不计算二阶导数的情况下,允许足够的内循环梯度步骤。实验表明,GMT2在效率和准确性上均优于流行的元学习模型。
1 Introduction
经典的机器学习方法通常在大量的监督样本上训练模型。通过深度神经网络1 2 3和大数据集4 5,机器学习方法能够取得优秀的结果。然而,在现实生活中,收集大量的监督数据是一个沉重的负担。此外,由于隐私、安全和伦理问题,现有的监督数据有限。例如,在药物发现领域6,临床候选者的生物记录不足以支持传统深度神经网络在这一任务中的应用。另一方面,人类可以通过少量的示例和先前的经验迅速学习,我们希望将这种能力赋予我们的机器学习模型。这些问题促使了机器学习的一个特殊分支——小样本学习7。
最近,许多元学习方法在小样本设置中被提出8 9 10 11 12 13 14。其中,MAML9 是迄今为止最受欢迎的一种。与传统机器学习不同,MAML通过包含支持数据集(训练数据集)和查询数据集(测试数据集)的多种任务来训练元模型。其算法由内外优化循环组成。内循环使用基于梯度的方法来使元模型适应特定任务的支持数据集;外循环则通过反向传播来更新元参数,以适应在查询数据集上训练的模型损失。整个过程被称为元训练。同样,在元模型完全训练之后,元测试过程再次使用梯度下降方法使其适应新任务并评估其性能。
然而,在元测试过程中,MAML及其近似方法9 10 11 12 13 15 16 面临一个共同的收敛问题,即一些任务可能在几个步骤内收敛,而其他任务则无法在整个过程中接近最优解。这是因为它们对所有新任务设置相同的梯度步数,而训练好的元模型通常会对某些任务产生偏差14。我们在图1中展示了这一问题的示意图。由于不同任务需要不同的梯度步数才能收敛,图1(a)中的元模型就面临了收敛问题。由于偏差问题,尽管训练良好的元模型在某些任务上表现良好,但它可能在其他任务上表现不佳。这损害了元模型的泛化能力。我们可以增加梯度步数以强制收敛,如图1(b)所示,但这会带来额外的计算成本。
图 1. 在元测试过程中为新任务设置适当的梯度步骤数的挑战的简单示意图。我们使用一个经过良好训练的元模型,展示了6个新任务的适应循环,这些任务分别分配了5个梯度步骤(图 1(a))和30个梯度步骤(图 1(b))。这里我们通过“总梯度步骤”轴连接它们的结果。图 1(a) 中,第4和第5个任务在5个步骤下的损失表明它们未能收敛。尽管所有任务在30个梯度步骤下均已收敛,如图 1(b) 所示,但这会带来较高的计算成本,因为大多数梯度步骤是多余的。显然,收敛性和计算成本之间的平衡是任务特定的。
据我们所知,目前尚无正式的研究通过改变元测试方法来解决偏差问题。TAML14 在元训练过程中加入了正则化项,以获得无偏的元模型。ANIL17 提出了使用冻结的元初始化表示的余弦相似度来执行元测试中的未见任务。尽管其主要目的是通过去除元测试中的头部来节省计算成本,但它恰好避免了收敛问题,且性能下降较小。然而,ANIL并未对元测试中基于度量的方法的有效性进行理论分析,且该新方法的进一步优化仍需探索。
在本文中,我们正式建立了梯度下降与基于度量的策略在元测试中的近似关系。具体来说,如果在元训练过程中能够保证内循环的收敛性和衍生特征的表示能力,那么基于度量的方法可以紧密地近似于元测试中的梯度下降。基于这种有条件的近似关系,我们提出了一种新的元学习算法GMT2,结合了基于梯度的元训练与基于度量的元测试。它采用对抗更新来增强特征表示,并在元训练阶段使用一阶更新来减少计算成本。因此,GMT2在大幅降低计算成本的同时,超越了现有的基于梯度的元学习方法。对多个数据集的实验验证了我们提出方法的有效性和效率。我们总结了本工作的贡献如下:
- 我们提出了一种新颖的一阶对抗元学习算法GMT2,能够以更低的计算成本高效地增强元模型的特征表示。
- 我们正式建立了梯度下降与元测试中基于度量策略的近似关系。
- 多个数据集上的实验表明,GMT2不仅比原始梯度下降更快,而且在与流行的小样本方法比较时,能够实现具有竞争力的准确性。
2. Related work
有很多场景面临小样本问题,如药物发现6 和无人机18。为了解决这些问题,提出了许多流行的小样本方法,包括基于度量的方法和基于梯度的方法。基于度量的方法旨在学习任务之间的共同特征,从而在特定的特征空间中公平地比较数据相似性。它们有一个共同的目标,即将同一类别的数据点保持在一起,而将不同类别的数据点分开19。距离或相似性规则被用来分类查询样本20。例如,原型网络21计算每个类别的支持样本的均值嵌入作为类别中心,并通过最近的类别中心来分类查询样本。与现有的基于度量的方法不同,我们的方法是基于度量和基于梯度的结合体,其中在GMT2中使用梯度下降来训练元模型,而基于度量的方法则用于元测试中的样本分类。
基于梯度的方法通常维护一个元模型,并通过梯度下降将其适应于每个特定任务。梯度下降还被用于外循环中来更新元参数。外循环中的梯度更新依赖于内循环中的梯度下降结果,因此引入了二阶导数。如果使用大量的内循环梯度步骤,那么二阶导数的计算成本将变得不可承受。为避免在MAML算法中计算二阶导数,已经提出了一些近似方法9 10 11 15 16。例如,FOMAML9是MAML的一级近似方法,其中省略了二阶导数。据报道,FOMAML大约能加速33%的网络计算。ANIL14只更新外层的二阶梯度,而忽略内层的二阶梯度。这些近似方法的目标是减少计算成本,但它们也继承了MAML中的偏差问题。
由于二阶更新和短视偏差22,通常会限制内步数的数量,以减少计算成本。与此同时,一些研究者23 24认为,更多的内步数有利于模型性能,并确保每个任务的模型收敛。更重要的是,训练的元模型可能会对某些任务产生偏差,这在分类任务14和强化学习任务25中都可以看到,这表明不同任务中的梯度步数是不同的。如图1所示,确定实际任务中适当的梯度步数仍然是一个挑战,需要领域专业知识和足够的计算能力。这些方法通常调整元训练过程,如添加正则化项,以生成无偏的元模型。与这些方法不同,GMT2通过改变元测试方法来避免偏差问题,并减少计算资源的消耗。
一些元学习方法13 17 26在内循环梯度下降过程中固定内层,只训练基础神经网络的最后一层。它们将内层视为特征提取器,最后一层则作为分类器。R2D212将岭回归作为其分类器,而MetaOptNet13使用SVM作为分类器,以获得闭式解。它们都不需要计算二阶导数。尽管这些方法也避免了在内循环中进行梯度下降,但它们的闭式解仍然存在一些复杂的矩阵运算,如任务特定模型的矩阵逆运算(见图2)。
图 2. 我们方法的示意图。这里我们使用3-way 1-shot任务,包含3个类别,每个类别有1个支持样本。在元训练期间,我们通过基于梯度的方法对特征提取器 f φ f_\varphi fφ进行元训练。具体而言,我们使用多个内部梯度步骤在内部循环中学习最优分类器 W ′ W' W′,同时固定内层 f φ f_\varphi fφ。在外部循环中,我们根据元损失使用 W W W的对抗更新来更新 f φ f_\varphi fφ。在元测试阶段,我们使用度量方法对每个查询样本进行分类。具体地,我们使用训练好的特征提取器 f φ f_\varphi fφ来获取查询数据的嵌入表示。基于低维特征,我们计算支持数据集中的所有类别中心,并将查询样本分配给最近中心的对应类别。
3. 基于对抗梯度的元训练与基于度量的元测试
3.1. Preliminary
与传统的机器学习不同,小样本学习通过N-way K-shot任务来训练模型。N-way表示任务中有N个类别,K-shot表示每个类别有K个支持(训练)样本。除了支持样本外,任务中还有查询(测试)样本,用来衡量训练模型的性能。给定支持数据集 D i s u p p o r t D_i^{support} Disupport和查询数据集 D i q u e r y D_i^{query} Diquery,我们将N-way K-shot任务表示为 T i = ( D i s u p p o r t , D i q u e r y ) T_i = (D_i^{support}, D_i^{query}) Ti=(Disupport,Diquery)。每个数据集由数据点 x x x和对应的标签 y y y组成。
这里以MAML为例介绍基于梯度的方法。MAML9维护一个基础模型(元模型) f θ f_{\theta} fθ,其元参数为 θ \theta θ。对于特定任务 T i T_i Ti,基础模型 f θ f_{\theta} fθ使用其支持数据集 D i s u p p o r t D_i^{support} Disupport通过梯度下降快速适应成任务特定的模型 f θ i ′ f_{\theta_i'} fθi′,这称为内梯度下降或内循环,如公式(1)所示:
θ i ′ = θ − λ ∇ θ L T i ( D i s u p p o r t ; f θ ) (1) \theta_i' = \theta - \lambda \nabla_{\theta} L_{T_i} (D_i^{support}; f_{\theta}) \tag{1} θi′=θ−λ∇θLTi(Disupport;fθ)(1)
其中, L ( D ; f θ ) = ∑ ( x , y ) ∈ D l ( f θ ( x ) , y ) L(D; f_{\theta}) = \sum_{(x,y) \in D} l(f_{\theta}(x), y) L(D;fθ)=∑(x,y)∈Dl(fθ(x),y)计算 f θ f_{\theta} fθ在数据集 D D D上的损失。接下来,MAML使用训练好的模型 f θ i ′ f_{\theta_i'} fθi′计算任务 T i T_i Ti的查询数据集上的损失,并收集这些任务的一个批次来形成最终的元损失,如公式(2)所示:
min θ ∑ T i ∼ p ( T ) L T i ( D i q u e r y ; f θ i ′ ) (2) \min_{\theta} \sum_{T_i \sim p(T)} L_{T_i}(D_i^{query}; f_{\theta_i'}) \tag{2} θminTi∼p(T)∑LTi(Diquery;fθi′)(2)
通过元损失,MAML再次使用梯度下降来更新元参数 θ \theta θ,这称为外梯度下降或外循环。它与其他任务批次重复上述过程,直到收敛。
原型网络(ProtoNet)21是小样本学习中的经典基于度量的方法。它首先为任务的支持数据集 D s u p p o r t D^{support} Dsupport中的每个类别 C n ⊆ D s u p p o r t C_n \subseteq D^{support} Cn⊆Dsupport找到所有的中心点 c n = 1 ∣ C n ∣ ∑ x ∈ C n x c_n = \frac{1}{|C_n|} \sum_{x \in C_n} x cn=∣Cn∣1∑x∈Cnx。然后,如公式(3)所示,它计算查询数据与每个类别中心之间的距离,并将查询数据标记为与最近中心的类别相同。
min ∥ f θ ( c n ) − f θ ( x q u e r y ) ∥ 2 2 (3) \min \|f_{\theta}(c_n) - f_{\theta}(x_{query})\|_2^2 \tag{3} min∥fθ(cn)−fθ(xquery)∥22(3)
与MAML不同,ProtoNet在任务之间进行元训练,以使同一类别中的数据点更接近,而不同类别的点更远。具体来说,ProtoNet在元训练期间更新特征提取器 f θ f_{\theta} fθ以最小化公式(3),并在元测试中选择最小化公式(3)的类别 n n n。表1列出了本文中广泛使用的符号。
3.2. Approximate gradient-based meta-test with metric methods
卷积神经网络(CNN)1 是图像分类中最常用的神经网络架构。其内部层(即除最后一层之外的层)使用卷积操作来学习特征,最后一层则是一个全连接层,作为线性分类器。我们也可以将内部层的输出看作是由 MAML 学到的特征。接下来,我们将元参数 θ \theta θ 分为 φ \varphi φ 和 W W W。我们将 CNN 的所有内部层表示为 f φ f_\varphi fφ,其中 φ \varphi φ 是其参数。同时,我们使用 W ∈ R F × N W \in \mathbb{R}^{F \times N} W∈RF×N 来表示最后一层(全连接层)及其参数,其中 F F F 表示 f φ f_\varphi fφ 输出向量的长度, N N N 是任务中的类别数。因此,我们可以重新表述基模型 f θ f_\theta fθ 的表示,如公式(4)所示:
f θ ( x ) = W T f φ ( x ) (4) f_\theta(x) = W^T f_\varphi(x) \tag{4} fθ(x)=WTfφ(x)(4)
假设我们有一个训练良好的元模型,我们将其调整为特定任务的模型,用于元测试过程。为方便起见,我们省略了上标和下标 i i i。在分类场景下, softmax ( f θ ( x ) ) \text{softmax}(f_\theta(x)) softmax(fθ(x)) 输出的是 x x x 属于第 n n n 类的概率,其中 n n n 表示类别的位置。当 x x x 属于第 n n n 类时,标签 y y y 的值是一个单位向量 e n e_n en,其第 n n n 位为 1,其他位置为 0。我们也可以将任务的目标函数写成 L2 范数形式,如公式(5)所示:
min f θ ∥ y − f θ ( x ) ∥ 2 2 (5) \min_{f_\theta} \|y - f_\theta(x)\|_2^2 \tag{5} fθmin∥y−fθ(x)∥22(5)
如我们所见,适应目标是使得 f θ ( x ) f_\theta(x) fθ(x) 尽可能接近其标签 y y y。因此,如果我们有足够的梯度下降,可以得到如下近似式,如公式(6)所示:
y ≈ f θ ′ ( x ) (6) y \approx f_\theta'(x) \tag{6} y≈fθ′(x)(6)
对于测试,原始过程是选择 f θ ′ ( x test ) f_\theta'(x_{\text{test}}) fθ′(xtest) 的最大概率,并预测其类别标签。实际上,这相当于选择一个单位向量 e n e_n en 作为 y y y,使其最小化 ∥ y − f θ ′ ( x test ) ∥ 2 2 \|y - f_\theta'(x_{\text{test}})\|_2^2 ∥y−fθ′(xtest)∥22。该目标与公式(5)相似,尽管我们现在的目标是寻找一个最佳的类别标签 y y y,而不是最佳的模型 f θ ′ f_\theta' fθ′。我们将测试目标展开如下:
min y ∥ y − f θ ′ ( x test ) ∥ 2 2 = ( y − f θ ′ ( x test ) ) T ( y − f θ ′ ( x test ) ) = y T y − 2 y T f θ ′ ( x test ) + f θ ′ ( x test ) T f θ ′ ( x test ) (7) \min_y \|y - f_\theta'(x_{\text{test}})\|_2^2 = (y - f_\theta'(x_{\text{test}}))^T (y - f_\theta'(x_{\text{test}})) = y^T y - 2y^T f_\theta'(x_{\text{test}}) + f_\theta'(x_{\text{test}})^T f_\theta'(x_{\text{test}}) \tag{7} ymin∥y−fθ′(xtest)∥22=(y−fθ′(xtest))T(y−fθ′(xtest))=yTy−2yTfθ′(xtest)+fθ′(xtest)Tfθ′(xtest)(7)
我们去掉第一个项 y T y = 1 y^T y = 1 yTy=1 和第三个项,因为它们是常数。剩下的第二项,实际上是在计算 f θ ′ ( x test ) f_\theta'(x_{\text{test}}) fθ′(xtest) 和 y y y 之间的余弦相似度。根据公式(6),我们知道如果模型 f θ ′ f_\theta' fθ′ 训练得很好, f θ ′ ( x ) f_\theta'(x) fθ′(x) 可以近似其标签 y y y。为了消除统计误差,我们使用 f θ ′ ( c n ) f_\theta'(c_n) fθ′(cn) 来表示类别标签 y y y,其中 c n c_n cn 是带有类别标签 y y y 的支持数据的类别中心。因此,我们可以将测试目标重新写为公式(8):
max n f θ ′ ( c n ) T f θ ′ ( x test ) (8) \max_n f_\theta'(c_n)^T f_\theta'(x_{\text{test}}) \tag{8} nmaxfθ′(cn)Tfθ′(xtest)(8)
我们可以进一步将 f θ ′ f_\theta' fθ′ 分解为 f φ ′ f_\varphi' fφ′ 和 W ′ W' W′,如公式(4)所示。给定训练好的元模型,先前的工作表明,在元测试过程中,内部层的参数几乎保持不变 17,这意味着 f φ ≈ f φ ′ f_\varphi \approx f_\varphi' fφ≈fφ′。然后我们得到公式(9)中的以下变换:
f θ ′ ( c n ) T f θ ′ ( x test ) = f φ ′ ( c n ) T W ′ W ′ T f φ ′ ( x test ) ≈ f φ ( c n ) T W ′ W ′ T f φ ( x test ) ≈ sim ( f φ ( c n ) , f φ ( x test ) ) (9) f_\theta'(c_n)^T f_\theta'(x_{\text{test}}) = f_\varphi'(c_n)^T W' W'^T f_\varphi'(x_{\text{test}}) \approx f_\varphi(c_n)^T W' W'^T f_\varphi(x_{\text{test}}) \approx \text{sim}(f_\varphi(c_n), f_\varphi(x_{\text{test}})) \tag{9} fθ′(cn)Tfθ′(xtest)=fφ′(cn)TW′W′Tfφ′(xtest)≈fφ(cn)TW′W′Tfφ(xtest)≈sim(fφ(cn),fφ(xtest))(9)
在这里,我们也可以使用其他度量方法来计算相似度 sim ( f φ ( c n ) , f φ ( x test ) ) \text{sim}(f_\varphi(c_n), f_\varphi(x_{\text{test}})) sim(fφ(cn),fφ(xtest)),而不是在 W ′ W ′ T W' W'^T W′W′T 度量空间中的余弦相似度 27。从公式(7)到(9),元测试的目标最终转化为查询数据 f φ ( x test ) f_\varphi(x_{\text{test}}) fφ(xtest) 与每个类别中心 f φ ( c n ) f_\varphi(c_n) fφ(cn) 在 W ′ W ′ T W' W'^T W′W′T 度量空间中的相似度。我们给出了 NIL 17 为什么通过加权其倒数第二层表示与支持数据集的相似度来对查询数据进行分类的分析原因。显然,使用 f φ ( c n ) T f φ ( x test ) f_\varphi(c_n)^T f_\varphi(x_{\text{test}}) fφ(cn)Tfφ(xtest) 可以跳过 W W W 的梯度下降过程,这是 NIL 的策略,但这可能会影响分类性能。尽管如此,我们注意到,如果内部层学习到更好的特征表示,则 f φ ( c n ) T f φ ( x test ) f_\varphi(c_n)^T f_\varphi(x_{\text{test}}) fφ(cn)Tfφ(xtest) 的结果会更接近于 f φ ( c n ) T W ′ W ′ T f φ ( x test ) f_\varphi(c_n)^T W' W'^T f_\varphi(x_{\text{test}}) fφ(cn)TW′W′Tfφ(xtest),且受 W ′ W ′ T W' W'^T W′W′T 的影响较小,因为同一类别的对象已经在 f φ f_\varphi fφ 的特征空间中紧密聚类。
因此,我们得出结论,基于度量的方法可以用来近似元测试过程中的梯度下降过程。现在我们不再需要在元测试过程中运行梯度下降,也不必担心梯度步骤的数量是否足够得到最佳的任务特定模型。显然,度量方法比梯度下降运行得更快。实验 4.3 中将展示准确率和运行时间的比较。此外,基于以上分析,我们在下一小节中提出了一种新的方法,用以增强内部层的特征表示。
3.3. Adversarial first-order training strategy
根据第 3.2 节的分析,我们通过将元训练中的梯度下降与元测试中的度量学习相结合,提出了一种新的元学习模型 GMT2。我们观察到,如果内部层的特征表示更强,则 Eq. (9) 中的近似更为紧密。为了更好地满足梯度下降的近似,GMT2 采用了一种对抗更新策略来增强内部层的特征表示,同时在最后一层进行更新。此外,GMT2 还采用了一种新的一阶训练策略,以允许足够的梯度步骤来确保内循环的收敛。这两者都使得如 Eq. (6) 所示的近似条件得以满足。
在元训练过程中,我们将内部层视为特征提取器,而将最后一层视为分类器。特征提取器在任务之间进行元训练,而分类器则适应每个任务。这一特性非常适合我们的分析,因为我们正是使用特征提取器在元测试过程中计算相似度,如 Eq. (9) 所示。在内循环中,我们固定内部层,仅更新最后一层,这遵循 ANIL 的训练策略。在此,我们设置更多的内循环步骤,以满足 Eq. (6) 的近似条件。在外循环中,我们更新内部层和最后一层的元初始化。具体来说,我们更新内部层 f ϕ f_{\phi} fϕ 以减少 Eq. (10) 中的损失,使用训练好的最后一层 W ′ W' W′。
min θ ∑ T i ∼ p ( T ) L T i ( D query , W ′ ; f ϕ ) (10) \min_{\theta} \sum_{T_i \sim p(T)} L_{T_i} (D_{\text{query}}, W'; f_{\phi}) \tag{10} θminTi∼p(T)∑LTi(Dquery,W′;fϕ)(10)
至于最后一层 W W W,虽然它在我们的元测试中没有起作用,但我们仍然可以利用其元初始化,并模拟 GAN 28 范式来帮助训练内部层的特征表示。具体来说,我们更新 W W W 以扰乱分类过程。如上所述,Eqs. (5) 和 (10) 的目标是使 f θ ′ f_{\theta'} fθ′ 的输出尽可能接近相应的 y y y,这也使得其熵非常小 14。相反,如果我们增加熵,就可以扰乱分类过程。在这里,我们采用最大化 f θ f_{\theta} fθ 在 D query D_{\text{query}} Dquery 上熵的形式,如 Eq. (11) 所示。
max W G ( D query , f ϕ ; W ) = − ∑ x ∈ D query W T f ϕ ( x ) log W T f ϕ ( x ) (11) \max_W G(D_{\text{query}}, f_{\phi}; W) = - \sum_{x \in D_{\text{query}}} W^T f_{\phi}(x) \log W^T f_{\phi}(x) \tag{11} WmaxG(Dquery,fϕ;W)=−x∈Dquery∑WTfϕ(x)logWTfϕ(x)(11)
通过对 W W W 的对抗更新,GMT2 可以学习到更加鲁棒和富有表现力的 f ϕ f_{\phi} fϕ。请注意, W ′ W' W′ 在我们的算法中并不参与 W W W 的更新,这与 ANIL 中的二阶更新不同。因此,GMT2 是一种一阶近似方法,它允许足够的内循环步骤来确保内循环的收敛。
根据度量学习 27,在 Eq. (9) 中度量矩阵 W ′ W ′ T W'W'^T W′W′T 的作用是使同类数据点更加接近,而不同类的数据点则更远。因此,我们提出了 Eq. (12) 作为我们的聚类正则化,增强每个类的聚类效果。
min R cluster = 1 N ∑ i sim ( f ϕ ( c i ) , f ϕ ( c ) ) − 1 N ∑ i ∣ C i ∣ ∑ j sim ( f ϕ ( c i ) , f ϕ ( x i , j ) ) (12) \min R_{\text{cluster}} = \frac{1}{N} \sum_i \text{sim}(f_{\phi}(c_i), f_{\phi}(c)) - \frac{1}{N} \sum_i |C_i| \sum_j \text{sim}(f_{\phi}(c_i), f_{\phi}(x_{i,j})) \tag{12} minRcluster=N1i∑sim(fϕ(ci),fϕ(c))−N1i∑∣Ci∣j∑sim(fϕ(ci),fϕ(xi,j))(12)
其中 sim ( ⋅ , ⋅ ) \text{sim}(\cdot, \cdot) sim(⋅,⋅) 是与 Eq. (9) 中使用的相同的相似度度量。
3.4. Algorithm
最后,我们将讨论汇总,得出最终的元损失函数 Eq. (13),
min ϕ , W J = ∑ T i ∼ p ( T ) L ( D query , W ′ ; f ϕ ) − α G ( D query , f ϕ ; W ) + β R cluster (13) \min_{\phi, W} J = \sum_{T_i \sim p(T)} L(D_{\text{query}}, W'; f_{\phi}) - \alpha G(D_{\text{query}}, f_{\phi}; W) + \beta R_{\text{cluster}} \tag{13} ϕ,WminJ=Ti∼p(T)∑L(Dquery,W′;fϕ)−αG(Dquery,fϕ;W)+βRcluster(13)
其中 α \alpha α 和 β \beta β 分别是对抗正则化和聚类正则化的系数。GMT2 的整个过程如图 2 所示。
在元训练期间,GMT2 首先在某个任务的支持数据集上学习一个最优的任务特定模型 f θ ′ f_{\theta'} fθ′,并通过 Eq. (13) 在查询数据集上计算元损失。我们在外循环中对一批任务进行梯度下降,并重复上述过程直到收敛,如算法 1 所示。对于 Eq. (13) 的第一项,我们将基础模型的内层与任务特定模型的最后一层结合,以元训练特征提取器 f ϕ f_{\phi} fϕ。在外循环中,我们对 f ϕ f_{\phi} fϕ 和 W W W 进行对抗性更新,以学习更好的特征提取器。至于元测试,GMT2 首先计算支持数据集中的每个类的类中心,然后通过 Eq. (9) 将查询数据分配到最近的中心对应的类。元测试过程如算法 2 所示。
4. Experiments
4.1. Experimental setup
4.1.1. Datasets
我们在实验中使用了 mini-Imagenet、tiered-Imagenet 和 CIFAR-FS 数据集。mini-Imagenet 数据集 29 是 ImageNet 数据集 4 的一个子集。总共有 100 个类,每个类包含 600 张 84 × 84 的彩色图像。这 100 个类被划分为 64、16 和 20 个类,分别用于元训练、元验证和元测试任务的采样。tiered-Imagenet 数据集 30 是 ImageNet 数据集的一个较大子集,包含 608 个类(779,165 张 84 × 84 的彩色图像),并在人工策划的层次结构中分为 34 个更高层次的类别。这些类别被分为 20 个训练类别、6 个验证类别和 8 个测试类别。CIFAR-FS(CIFAR100 few-shots)[12] 是从 CIFAR100 数据集 5 随机采样的。该数据集包含 100 个类,划分为 64 个训练类、16 个验证类和 20 个测试类。
4.1.2. Implementation details
我们的 GMT2 方法是 MAML 的一种近似方法,并且对 ANIL 进行了改进。为了方便比较,我们遵循它们的 conv4 架构。conv4 架构是一个浅层神经网络,包含四个卷积块和一个全连接层。每个卷积块按顺序包含 3 × 3 卷积(步幅 = 1,填充 = 1)、批量归一化、ReLU 非线性激活和 2 × 2 最大池化。在我们的方法中,所有四个卷积层都有 32 个过滤器,并且所有数据集使用相同的神经网络架构。
我们在实验中使用 5-way 1-shot 和 5-way 5-shot 任务设置,每个任务包含 15 个查询样本进行测试。如算法 1 所示,在元训练期间我们采样一批任务。我们为所有数据集设置任务的批量大小为 4,称之为一个 episode。元训练过程中总共有 60,000 个 episodes。在外循环中,我们使用 Adam 31 优化器进行元训练,学习率为 0.001。对于 mini-Imagenet 的 5-way 1-shot 任务,在完成 34,000 个 episodes 后,学习率衰减为原来的 0.1;对于 mini-Imagenet 的 5-way 5-shot 任务,在完成 28,000 个 episodes 后,学习率衰减为原来的 0.1。其余任务设置保持不变的学习率。我们每 200 个 episodes 进行一次元验证,并记录最佳验证模型用于元测试。
在内循环中,我们使用学习率为 0.01 的 SGD 来训练基础模型。内梯度步骤的数量为所有数据集均为 30。对于 mini-Imagenet 和 tiered-Imagenet 的 5-way 1-shot 任务,我们设置正则化系数 α = 1.0,β = 0.1;而对于其余任务设置,我们将 α 设置为 0.1,β 设置为 0.01。我们使用余弦相似度作为相似度度量 sim(·, ·)。实验在一台配备 Intel® Xeon® Gold 6254 CPU @ 3.10 GHz 和 GeForce RTX 2080 Ti 的服务器上进行。
4.2. Few-shot classification
在本节中,我们将我们的方法与其他基于梯度的方法(如 MAML 9)和基于度量的方法(如 Matching Nets 29)进行图像分类准确度的比较。我们还记录了近似方法(如 Reptile 16)的结果。
比较少样本学习方法的一个障碍是不同方法使用了不同的实验设置 [^32, ^33],例如不同的基础神经网络架构。为了公平比较,我们将所有比较方法设置为相同的 conv4 架构,只有滤波器的数量有所不同。我们在 mini-Imagenet、tiered-Imagenet 和 CIFAR-FS 数据集上的结果分别展示在 表 2、表 3 和 表 4 中。我们使用几个新任务来衡量模型的性能,并记录具有 95% 置信区间的平均分类准确度。
我们可以看到,GMT2 超越了大多数方法,并且与最好的方法相比具有竞争力。我们的方法的表现优于 NIL,并且与基于度量的方法具有竞争力,这证实了我们学习到的特征表示是有效的。另一方面,我们还记录了不使用 Rcluster 的结果。使用和不使用 Rcluster 的结果非常接近,但当样本数量较少时,Rcluster 确实能提供一些帮助。
最后,我们在 ResNet-12 网络上评估了 GMT2,并将其与当前流行的少样本方法进行了比较,结果见 表 5 和 表 6。正如我们所看到的,使用类似的骨干网络,GMT2 的表现与当前流行的少样本方法(如 MetaOptNet)具有竞争力。因此,扩大基础网络的容量适用于 GMT2,并且是提升其性能的有效方式。
4.3. Comparisons of metric methods and gradient descent
在本实验中,我们将评估在元测试过程中用基于度量的方法替代梯度下降时,准确性提升和时间减少的效果。我们使用30个梯度步骤来保证梯度下降在元测试任务中的收敛性,结果如表7所示。这里我们使用的是带有Rcluster的方法。
从表7中可以看出,基于度量的元测试不仅比基于梯度的元测试运行速度快,而且在两个数据集上都取得了更好的分类准确性。显然,计算梯度比直接计算相似度要复杂得多,尤其是当我们采用更多的梯度步骤时。此外,尽管我们采取了许多梯度步骤,但梯度下降可能并没有使损失函数收敛并达到最小值。因此,这些结果表明,在元测试过程中,使用相似度度量比梯度下降更加高效。
4.4. Varying inner steps in meta-training
在本实验中,我们通过改变元训练中的内梯度步骤数,来展示GMT2的运行时间和GPU内存使用情况。这里我们记录了200个episodes的运行时间,并将MAML作为基准进行比较。
运行时间和GPU内存使用情况的对比结果分别呈现在图3(a)和图3(b)中。我们可以看到,GMT2在内梯度步骤从10步增加到60步时,时间增量非常小,而MAML在60个内梯度步骤下的时间几乎是10步时的8倍。至于GPU内存使用情况,GMT2的内存使用一直保持不变,而MAML在内梯度步骤从10步增加到60步时,内存使用量大幅增加。增加内梯度步骤对GMT2没有额外的GPU内存消耗,因为它在元训练过程中不会计算二阶导数,而增加内梯度步骤会导致MAML需要更多的二阶导数计算。这些结果解释了为什么许多基于梯度的方法,如MAML,倾向于使用较小的内梯度步骤,而不是较大的步骤,尽管增加内步骤的数量确实有助于损失函数的收敛,并使得模型表现更好。根据以上结果,我们可以确认GMT2允许设置足够的内步骤来保证内循环的收敛。
图 3. MAML与我们方法在内梯度步骤从10到60的比较。图 3(a) 显示了在200个回合中,MAML和我们方法的元训练运行时间。图 3(b) 显示了MAML和我们方法的GPU内存使用情况。
4.5. Feature representation before and after meta-training
如第3.2节和第3.3节所述,GMT2增强了内层的特征表示,并利用其在元测试中进行样本分类。如果我们从度量方法的角度回顾我们的方法,可以发现它实际上训练了一个很好的特征提取器。在这里,我们通过视觉方式深入分析其特征表示。
我们使用线性判别分析(LDA)32、主成分分析(PCA)33和t-SNE34来投影从GMT2内层提取的特征。元测试任务从mini-ImageNet、tiered-ImageNet和CIFAR-FS数据集中随机采样,如图4–6所示。为了进行比较,我们还绘制了在元训练前的初始模型和MAML的结果。如我们所见,初始模型和MAML无法很好地区分不同类别的点。然而,GMT2的样本特征表示较好,来自同一类的样本在特征空间 f ϕ f_{\phi} fϕ中趋向于聚集在一起,而不同类别的样本则彼此远离。这个特征对于度量方法在元测试中正确分类样本非常有帮助。此外,观察结果告诉我们,通过GMT2的对抗策略训练的内层能够学习到非常好的特征表示。
另一方面,对于MAML来说,在元测试中使用度量方法并不合适,因为根据可视化结果,它的内层特征表示并不强。这表明我们方法的局限性:只有当内层具有较强的特征表示时,度量方法才能在元测试中近似梯度下降。如果其他现有的基于梯度的少样本方法希望在元测试中采用度量策略,必须首先检查其内层特征表示的强度。我们将在未来的工作中进行探讨。
因此,从可视化结果中我们可以得出结论:GMT2的基于梯度的部分确实有能力元训练一个良好的特征提取器,并且直接使用度量方法在元测试中分类样本是可行的。
图 4. 在 mini-ImageNet 数据集上提取的由内层
f
ϕ
f_{\phi}
fϕ 提取的特征。我们在这里使用 LDA、PCA 和 t-SNE 方法来选择第一和第二主成分,并将它们绘制在 2D 图中。不同的颜色代表不同的类别。图 4(a)、4(d) 和 4(g) 显示了初始模型的结果;图 4(b)、4(e) 和 4(h) 显示了 MAML 的结果;图 4©、4(f) 和 4(i) 显示了 GMT2 的结果。
图 5. 在 tiered-ImageNet 数据集上提取的由内层
f
ϕ
f_{\phi}
fϕ 提取的特征。我们在这里使用 LDA、PCA 和 t-SNE 方法来选择第一和第二主成分,并将它们绘制在 2D 图中。不同的颜色代表不同的类别。图 5(a)、5(d) 和 5(g) 显示了初始模型的结果;图 5(b)、5(e) 和 5(h) 显示了 MAML 的结果;图 5©、5(f) 和 5(i) 显示了 GMT2 的结果。
图 6. 在 CIFAR-FS 数据集上提取的由内层
f
ϕ
f_{\phi}
fϕ 提取的特征。我们在这里使用 LDA、PCA 和 t-SNE 方法来选择第一和第二主成分,并将它们绘制在 2D 图中。不同的颜色代表不同的类别。图 6(a)、6(d) 和 6(g) 显示了初始模型的结果;图 6(b)、6(e) 和 6(h) 显示了 MAML 的结果;图 6©、6(f) 和 6(i) 显示了 GMT2 的结果。
4.6. Numerically measuring feature representation
我们在元训练前后计算RFC35,以数值化地衡量GMT2的特征表示。RFC使用类内方差与类间方差的比率来度量特征的聚类情况,如公式(14)所示。
R F C ( θ , x i , j ) = C N ∑ i , j ∥ f θ ( x i , j − μ i ) ∥ 2 2 / ∑ i ∥ μ i − μ ∥ 2 2 RFC (\theta, x_{i,j}) = \frac{C}{N} \sum_{i,j} \left\| f_{\theta} (x_{i,j} - \mu_i) \right\|_2^2 \Bigg/ \sum_i \left\| \mu_i - \mu \right\|_2^2 RFC(θ,xi,j)=NCi,j∑∥fθ(xi,j−μi)∥22/i∑∥μi−μ∥22
其中, x i , j x_{i,j} xi,j是类 i i i中的特征向量, μ i \mu_i μi是类 i i i中所有特征向量的均值, μ \mu μ是所有特征向量的均值, C C C是类别的数量, N N N是每个类别的样本数量。RFC值越低,意味着类间分离越好。从表8中的结果可以看出,RFC在元训练后显著降低。
4.7. Comparing similarity metrics
GMT2可以在其元测试中选择其他度量方法。这里我们测试了欧几里得距离是否比余弦相似度在元测试中表现更好。我们将结果列于表9。
我们可以发现,使用余弦相似度的分类结果更好。此外,我们还观察到,欧几里得距离在从5-way 1-shot到5-way 5-shot任务中有较大的准确度提升。由于在5-way 1-shot任务中只有一个数据点来计算类别中心,这会导致较大的统计误差。因此,当我们使用更多样本(5-shot)时,欧几里得距离的表现变得与余弦相似度相竞争。另一方面,欧几里得距离在公式(16)中的操作比余弦相似度在公式(15)中的操作少,因此欧几里得距离的运行时间更短。可能经过一些调整后,欧几里得距离会取得更好的性能,但目前余弦相似度仍然是更好的选择。
4.8. Ablation study
有人可能会发现,GMT2的距离正则化与ProtoNet的损失函数类似21。它们都通过更新特征提取器使同一类别的数据点更接近。为了衡量这一影响,在这里我们去除了除了Rcluster以外的所有项,并将其调整为度量基方法。它将作为基准,展示在元训练过程中梯度下降方法学习的特征效果。我们使用余弦相似度和欧几里得距离作为相似度度量,分别如公式(15)和公式(16)所示。
sim cos ( a , b ) = a ⋅ b ∥ a ∥ ∥ b ∥ \text{sim}_{\text{cos}}(a, b) = \frac{a \cdot b}{\|a\|\|b\|} simcos(a,b)=∥a∥∥b∥a⋅b
sim Euc ( a , b ) = − ∥ a − b ∥ 2 \text{sim}_{\text{Euc}}(a, b) = -\|a - b\|^2 simEuc(a,b)=−∥a−b∥2
在表10中展示了在tiered-ImageNet和CIFAR-FS上的对比结果。结合原始结果,表10中的结果显示,正是在元训练过程中通过梯度下降学习的特征,使得度量基方法在元测试中表现更好,而不是Rcluster学习的特征。
5. Conclusion
在本文中,我们提出了一种新颖的一阶对抗元学习算法GMT2,该算法结合了基于梯度的元训练和基于度量的元测试。 (1) 在元测试中,我们正式建立了基于度量的策略与梯度下降之间的近似关系。它要求内循环应该有足够的步骤来收敛,并且在元训练过程中内层应该具有表达性的特征表示。 (2) 此外,我们设计了一种一阶对抗训练策略,以帮助满足元测试中的近似条件。最后一层的对抗更新阻止了内层的正确分类,而内层则试图学习能够实现更清晰类别区分的特征表示。 (3) 实验结果表明,我们的方法不仅完成一个少样本任务所需时间更少,而且在相似的实验设置下与流行的少样本方法表现出竞争力。
至于未来的工作,我们可以研究公式(9)中更好的度量方法,以提高元测试性能。此外,可能存在更好的特征提取器训练策略。此外,正如在视觉实验中所示,使用基于度量的方法来逼近梯度下降依赖于内层的强大特征表示。如果我们在未来的工作中克服了这一限制,使用我们的方法来提高几乎所有基于梯度的元学习算法的性能是很有前景的。
A. Krizhevsky, I. Sutskever, G.E. Hinton, Imagenet classification with deep convolutional neural networks, Adv. Neural Inf. Process. Syst. 25 (2012) 1097–1105. ↩︎ ↩︎
K. Simonyan, A. Zisserman, Very deep convolutional networks for large-scale image recognition, 2014, arXiv preprint arXiv:1409.1556. ↩︎
K. He, X. Zhang, S. Ren, J. Sun, Deep residual learning for image recognition, in: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2016, pp. 770–778. ↩︎
J. Deng, W. Dong, R. Socher, L.-J. Li, K. Li, L. Fei-Fei, Imagenet: A large-scale hierarchical image database, in: 2009 IEEE Conference on Computer Vision and Pattern Recognition, 2009, pp. 248–255. ↩︎ ↩︎
A. Krizhevsky, G. Hinton, Learning Multiple Layers of Features from Tiny Images, Citeseer, 2009. ↩︎ ↩︎
H. Altae-Tran, B. Ramsundar, A.S. Pappu, V. Pande, Low data drug discovery with one-shot learning, ACS Central Sci. 3 (4) (2017) 283–293. ↩︎ ↩︎
Y. Wang, Q. Yao, J.T. Kwok, L.M. Ni, Generalizing from a few examples: A survey on few-shot learning, ACM Comput. Surv. 53 (3) (2020) 1–34. ↩︎
S. Ravi, H. Larochelle, Optimization as a model for few-shot learning, in: 5th International Conference on Learning Representations, ICLR 2017, Toulon, France, April 24–26, 2017, Conference Track Proceedings. ↩︎
C. Finn, P. Abbeel, S. Levine, Model-agnostic meta-learning for fast adaptation of deep networks, in: Proceedings of the 34th International Conference on Machine Learning, vol. 70, 2017, pp. 1126–1135. ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎
P. Zhou, X. Yuan, H. Xu, S. Yan, J. Feng, Efficient meta learning via minibatch proximal update, Adv. Neural Inf. Process. Syst. 32 (2019) 1534–1544. ↩︎ ↩︎ ↩︎
M.A. Jamal, L. Wang, B. Gong, A lazy approach to long-horizon gradientbased meta-learning, in: Proceedings of the IEEE/CVF International Conference on Computer Vision, 2021, pp. 6577–6586. ↩︎ ↩︎ ↩︎
L. Bertinetto, J.F. Henriques, P.H.S. Torr, A. Vedaldi, Meta-learning with differentiable closed-form solvers, in: International Conference on Learning Representations, 2019. ↩︎ ↩︎ ↩︎
K. Lee, S. Maji, A. Ravichandran, S. Soatto, Meta-learning with differentiable convex optimization, in: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2019, pp. 10657–10665. ↩︎ ↩︎ ↩︎ ↩︎
M.A. Jamal, G.-J. Qi, Task agnostic meta-learning for few-shot learning, in: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2019, pp. 11719–11727. ↩︎ ↩︎ ↩︎ ↩︎ ↩︎ ↩︎
A. Rajeswaran, C. Finn, S. Kakade, S. Levine, Meta-learning with implicit gradients, Adv. Neural Inf. Process. Syst. (2019). ↩︎ ↩︎
A. Nichol, J. Achiam, J. Schulman, On first-order meta-learning algorithms, 2018, arXiv preprint arXiv:1803.02999. ↩︎ ↩︎ ↩︎
A. Raghu, M. Raghu, S. Bengio, O. Vinyals, Rapid learning or feature reuse? towards understanding the effectiveness of MAML, in: International Conference on Learning Representations, 2020. ↩︎ ↩︎ ↩︎ ↩︎
D. Yuan, X. Chang, Z. Li, Z. He, Learning adaptive spatial-temporal contextaware correlation filters for UAV tracking, ACM Trans. Multimed. Comput. Commun. Appl. 18 (3) (2022) 1–18. ↩︎
G. Koch, R. Zemel, R. Salakhutdinov, Siamese neural networks for one-shot image recognition, in: ICML Deep Learning Workshop, vol. 2, 2015. ↩︎
D. Yuan, W. Kang, Z. He, Robust visual tracking with correlation filters and metric learning, Knowl.-Based Syst. 195 (2020) 105697. ↩︎
J. Snell, K. Swersky, R. Zemel, Prototypical networks for few-shot learning, Adv. Neural Inf. Process. Syst. 30 (2017) 4077–4087. ↩︎ ↩︎ ↩︎
Y. Wu, M. Ren, R. Liao, R. Grosse, Understanding short-horizon bias in stochastic meta-optimization, in: International Conference on Learning Representations, 2018. ↩︎
H.-J. Ye, W.-L. Chao, How to train your MAML to excel in few-shot classification, in: International Conference on Learning Representations, 2022. ↩︎
J. Shin, H.B. Lee, B. Gong, S.J. Hwang, Large-scale meta-learning with continual trajectory shifting, in: International Conference on Machine Learning, 2021, pp. 9603–9613. ↩︎
H. Liu, R. Socher, C. Xiong, Taming MAML: Efficient unbiased metareinforcement learning, in: International Conference on Machine Learning, 2019, pp. 4061–4071. ↩︎
L. Bertinetto, J.F. Henriques, J. Valmadre, P. Torr, A. Vedaldi, Learning feed-forward one-shot learners, Adv. Neural Inf. Process. Syst. (2016) 523–531. ↩︎
B. Kulis, Metric learning: A survey, Found. Trends Mach. Learn. 5 (4) (2012) 287–364. ↩︎ ↩︎
I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, Y. Bengio, Generative adversarial nets, Adv. Neural Inf. Process. Syst. 27 (2014). ↩︎
O. Vinyals, C. Blundell, T. Lillicrap, D. Wierstra, Matching networks for one shot learning, Adv. Neural Inf. Process. Syst. 29 (2016) 3630–3638. ↩︎ ↩︎
M. Ren, E. Triantafillou, S. Ravi, J. Snell, K. Swersky, J.B. Tenenbaum, H. Larochelle, R.S. Zemel, Meta-learning for semi-supervised few-shot classification, in: International Conference on Learning Representations, 2018. ↩︎
D.P. Kingma, J. Ba, Adam: A method for stochastic optimization, in: International Conference on Learning Representations, 2014. ↩︎
R.A. Fisher, The use of multiple measurements in taxonomic problems, Ann. Eugen. 7 (2) (1936) 179–188. ↩︎
K. Pearson, LIII. On lines and planes of closest fit to systems of points in space, Lond. Edinb. Dublin Philos. Mag. J. Sci. 2 (11) (1901) 559–572. ↩︎
L. Van der Maaten, G. Hinton, Visualizing data using T-SNE, J. Mach. Learn. Res. 9 (11) (2008). ↩︎
M. Goldblum, S. Reich, L. Fowl, R. Ni, V. Cherepanova, T. Goldstein, Unraveling meta-learning: Understanding feature representations for few-shot tasks, in: International Conference on Machine Learning, 2020, pp. 3607–3616. ↩︎