当前位置: 首页 > article >正文

The Surprising Effectiveness of Test-Time Training for Abstract Reasoning

The Surprising Effectiveness of Test-Time Training for Abstract Reasoning

基本信息

博客贡献人

JokerLin

作者

Ekin Akyürek,Mehul Damani,Linlu Qiu,Han Guo,Yoon Kim,Jacob Andreas

标签

LMs,Test-time Training,Abstraction and Reasoning Corpus,Neural Network

摘要

​ 语言模型在其训练分布中的任务上表现出令人印象深刻的性能,但经常难以解决需要复杂推理的新问题。本文以 ARC 为基准,研究了测试时训练 (Test-Time Training,TTT) 的有效性——在推理过程中使用从输入数据得出的损失临时更新模型参数——作为提高模型推理能力的机制。通过系统实验,确定了成功 TTT 的三个关键组成部分:(1)对类似任务进行初始微调 (2)辅助任务格式和增强 (3)逐个实例训练。TTT 显著提高了 ARC 任务的性能,与基本微调模型相比,准确率提高了约 6 倍;将 TTT 应用于 8B 参数语言模型,在 ARC 的公共验证集上实现了 53% 的准确率,将纯神经方法的最新技术提高了近 25%。通过将本文的方法与最近的程序生成方法集成,获得了 61.9% 的 SoTA 公共验证准确率,与人类的平均分数相匹配。本文的研究结果表明,显式符号搜索并不是改进神经语言模型中抽象推理的唯一途径;将额外的测试时间应用于对 few-shot 示例的继续训练也可能非常有效。

引言

​ 大规模神经语言模型 (LMs) 擅长执行其训练数据中出现的任务,并且通常是这些任务的基本变体或组合。给定自然语言任务规范或少量示例,LMs 通常会成功推断出所需的任务并生成适当的输出。对于复杂和新颖的任务,通常很难仅通过从 LMs 中抽样来获得正确答案。

​ 然而近年来的一个重要发现是,可以通过额外的测试时计算来增强 LM 解码,从而显著提高 LM 性能。这类方法包括思维链提示、多数投票抽样、代码执行和搜索。最近受到关注的一种扩展策略是测试时训练 (TTT),其中模型通过基于测试时输入的显式梯度步骤进行更新。这种方法与标准微调不同,因为它在极低数据范围内运行。

​ 本文在 ARC 中评估了这些方法,这是一组极具挑战性的 few-shot 视觉推理问题。ARC 是测试 LM 泛化极限的理想基准,因为它以新颖的形式呈现新颖的任务,需要重要的搜索和推理功能。当前语言模型在 ARC 上的性能不佳,大多数成功的方法都依赖于程序合成技术。

在这里插入图片描述

​ 通过仔细选择这些组件,TTT 可以显著提高 ARC 上的 LM 性能——与 1B 模型相比,准确性提高了六倍,并使用 8B 模型在 ARC 任务上为已发布的纯神经模型获得最先进的结果。事实上,本文的结果表明,当配备测试时训练时,普通 LM 可以达到或超过 ARC 上许多神经符号方法的性能。

​ 本文的贡献如下:

  • 通过新颖的测试时训练数据生成和自洽组件,确定并系统地分析了 ARC 任务的测试时训练所需的关键组件
  • 在 ARC 验证集上已发布的神经方法中取得了最先进的结果:
    • 使用 8B 参数模型的公共验证集的准确率为 53%
    • 与程序综合方法集成时,准确率为 61.9%,与数据集上的平均人类表现相匹配
  • 证明以前只能通过程序综合解决的任务,可以用配备 TTT 框架的全神经方法来解决

​ 这些结果挑战了符号分量对于解决此类复杂任务是绝对必要的假设。

简介

ARC挑战

​ ARC 旨在通过语言模型解决视觉难题的能力来评估语言模型的抽象推理能力。每个任务由二维网格的输入输出对(最大 30 × 30 个)组成,其中包含由多达 10 种不同颜色组成的形状或图案,如图 1(b) 所示。每对的输出是通过应用直观且共享的转换规则或函数 y = f (x) 获得的。

​ ARC 中的每个任务都由训练和测试拆分组成,其中包含:

  • 训练示例表示为 ( x k t r a i n , y k t r a i n ) k = 1 K (x^{train}_{k},y^{train}_k)^K_{k=1} (xktrain,yktrain)k=1K(通常 K 范围为 2 到 7)。
  • 测试示例表示为 ( x m t e s t , y m t e s t ) m = 1 M (x^{test}_{m},y^{test}_m)^M_{m=1} (xmtest,ymtest)m=1M(通常 M 范围为 1 到 3)。

​ 给定这组训练示例,目标是通过推理底层转换来预测测试输入 x t e s t x_{test} xtest 的测试输出 y t e s t y_{test} ytest

​ 将任务表示为 d = ( x t r a i n , y t r a i n , x t e s t , y t e s t ) d = (x_{train}, y_{train}, x_{test}, y_{test}) d=(xtrain,ytrain,xtest,ytest),其中 d ∈ D A R C d \in D_{ARC} dDARC,即此类 ARC 任务的集合。ARC 数据集的原始训练集和验证集,分别为 D A R C t r a i n D^{train}_{ARC} DARCtrain D A R C v a l D^{val}_{ARC} DARCval,每个数据集由 400 个任务组成。 成功标准需要为所有测试输出生成完全匹配项。

​ 大多数 ARC 方法可以分为两大类:程序综合和完全神经。程序综合方法尝试先找到变换函数 f f f,然后再将其应用于测试示例。另一方面,完全神经方法尝试直接预测输出 y t e s t y_{test} ytest,仅隐含地推理潜在的转换。在这项工作中,本文使用一种完全神经的方法,使用 LM 来预测测试输出。

上下文学习

​ 在一定规模上,许多 LM 表现出适应新任务的能力,而无需通过简单地调节提供的输入示例或指令来更新其参数。给定一系列输入-输出对 ( x 1 , y 1 ) , . . . , ( x n , y n ) (x_1, y_1), ..., (x_n,y_n) (x1,y1),...,(xn,yn) 和一个新的输入 x n + 1 x_{n+1} xn+1,LM 可生成输出 y ^ n + 1 \widehat y_{n+1} y n+1
y ^ n + 1 = L M ( ⋅ ∣ x 1 , y 1 , . . . x n , y n , x n + 1 ) (1) \widehat y_{n+1} = LM(\cdot | x_1,y_1,...x_n,y_n,x_{n+1})\tag{1} y n+1=LM(x1,y1,...xn,yn,xn+1)(1)
​ 上下文学习与任何标准的机器学习算法都不同,并且它不能开箱即用地用于新任务。

测试时训练 TTT

​ 一般的 TTT 过程如下:从初始模型参数 θ 0 θ_0 θ0 开始,对于每个测试输入(或批量输入),首先从测试输入生成训练数据 D T T T D_{TTT} DTTT( d i n p u t d_{input} dinput)。然后优化这些参数以最小化损失函数 L ( D T T T ; θ ) \mathcal{L}(D_{TTT};\theta) L(DTTT;θ),生成临时更新的参数 θ d \theta _{d} θd 用于预测。生成预测后,模型将恢复为下一个实例或批次的原始参数 θ 0 \theta _{0} θ0 。因此,TTT 为每个测试输入训练一个专门的预测模型。

实验设置

​ 为了研究每个 TTT 组分的影响,本文通过改变一个组分,同时将其他组分保持在最佳值(在各自的部分中描述)来进行实验。我们在实验中的默认配置使用以下设置:

  • 模型架构与优化
    • 使用来自 Llama-3 模型的 8B 参数语言模型,以及来自 Llama-3.2 模型的 1B、3B。使用低秩适应 (LoRA) 进行参数高效的测试时训练。对于每个任务 d 初始化一组单独的 LoRA 参数,这些参数在数据集 D T T T D_{TTT} DTTT 上训练。LoRA 等级设置为 128,并将调整应用于 MLP、注意力和输出层。使用 AdamW 优化器训练模型,有 2 个周期,批量大小为 2。
  • 数据 & 格式
    • 为了有效评估,从 ARC 验证集中随机挑选了 80 个平衡的 ARC 任务,包括 20 个简单、20 个中等、20 个困难、20 个专家任务。出于效率原因,将 D T T T D_{TTT} DTTT 限制为每个任务最多 250 个示例,使用 NVIDIA-A100 GPU 时,100 个随机采样的验证任务的整个 TTT 和推理过程大约需要 12 小时。

TTT 数据集和损失

数据生成

​ 给定一个任务,采用一组训练输入-输出对 ( x k t r a i n , y k t r a i n ) k = 1 K (x^{train}_{k},y^{train}_k)^K_{k=1} (xktrain,yktrain)k=1K 并将它们转换为一组增强的测试时间-训练任务 D T T T D_{TTT} DTTT。本文使用两步过程获得 D T T T D_{TTT} DTTT:首先,从给定的训练输入-输出对中创建一组留一的上下文学习任务;其次,在这组数据集上使用基于规则的可逆转换来获得增强的数据集。图 2 总结了这个过程。

在这里插入图片描述

  • 步骤一:留一任务

​ 通过从训练示例中排除第 j 个示例对,可以创建以下合成任务:
d j I C L = ( ( x k , y k ) k ∈ { 1 , . . . , k } ∖ { j } ⏟ synthetic training examples , x j , y j ⏟ synthetic test example )   w h e r e j ∈ [ 1 , K ] (2) d^{ICL}_{j}=(\underbrace{(x_k,y_k)_{k \in \{1,...,k\} \setminus \{j\}}}_{\text{synthetic training examples}}, \underbrace{x_j,y_j}_{\text{synthetic test example}})\ where j \in [1, K] \tag{2} djICL=(synthetic training examples (xk,yk)k{1,...,k}{j},synthetic test example xj,yj) wherej[1,K](2)
​ 其中,第 j 个示例对的 d j d_j dj 合成训练任务被视为测试用例。可以生成 n 个不同的任务,每个任务包含 n − 1 个示例对,还包括两个随机排列的 d j d_j dj 版本。

  • 步骤二:基于规则的转换

​ 考虑一个可逆变换 t,使得 t − 1 ( t ( x ) ) = x t^{-1}(t(x))=x t1(t(x))=x。对于步骤 1 中获得的每个任务,可以使用 t 生成一个新的增强任务 t ( d j I C L ) t(d^{ICL}_j) t(djICL),其中 t 应用于任务中的每个网格。

​ 选择简单的转换,这些转换保留了基本关系,同时引入了受控的变化,例如旋转、翻转、颜色排列、示例排列、大小缩放等。最后得到:
D T T T − I C L = { t ( d j I C L ) }   f o r   a l l   t , j   p a i r s . (3) D_{TTT-ICL}=\{t(d^{ICL}_j)\}\ for\ all\ t,j\ pairs.\tag{3} DTTTICL={t(djICL)} for all t,j pairs.(3)
​ 为了与上述 “测试时上下文学习” 方法进行比较,本文还评估了 “测试时端到端学习” 方法。将每个输入-输出对视为一个独立的训练实例,直接从示例演示中创建一个监督数据集。与上下文学习设置不同,没有上下文用于预测:
d j E 2 E = ( x j , y j )   w h e r e   j ∈ [ 1 , K ] (4) d^{E2E}_j=(x_j,y_j)\ where\ j \in [1,K]\tag{4} djE2E=(xj,yj) where j[1,K](4)
​ 这相当于在 ICL 设置中设置 leave-(n − 1)-out 任务,因为没有提供训练示例作为上下文。与 ICL 案例类似,可以应用基于规则的转换来增强数据集:
D T T T − E 2 E = { t ( d j E 2 E ) }   f o r   a l l   t , j   p a i r s . (5) D_{TTT-E2E}=\{t(d^{E2E}_j)\}\ for\ all\ t,j\ pairs.\tag{5} DTTTE2E={t(djE2E)} for all t,j pairs.(5)
​ 这种方法在计算上效率更高,因为它直接学习输入-输出映射,而无需管理上下文(即 few-shot 提示)的开销。

优化目标

​ 在测试时训练期间,本文使用 LoRA,同时保持大部分基本模型冻结。这种方法允许在保持模型的一般功能的同时进行计算高效的适应。

​ 训练目标:给定任务的测试时训练数据集 D T T T i D^i_{TTT} DTTTi,将演示和测试输出的标准语言建模损失降至最低:
L i ( D T T T i ; θ i ) = ∑ d ∈ D T T T i ( ∑ n = 2 K L L M ( y n ∣ x 1 , y 1 , . . . , x n ; θ i ) + L L M ( y t e s t ∣ x 1 , y 1 , . . . , x K , y K , x t e s t ; θ i ) ) (6) \mathcal{L}_i(D^i_{TTT};\theta_i)=\sum_{d\in D^i_{TTT}}(\sum^{K}_{n=2}\mathcal{L}_{LM}(y_n|x_1,y_1,...,x_n;\theta_i)+\mathcal{L}_{LM}(y_{test}|x_1,y_1,...,x_K,y_K,x_{test};\theta_i))\tag{6} Li(DTTTi;θi)=dDTTTi(n=2KLLM(ynx1,y1,...,xn;θi)+LLM(ytestx1,y1,...,xK,yK,xtest;θi))(6)
​ 其中 L L M \mathcal{L}_{LM} LLM 是语言建模的标准交叉熵损失。从第二个示例 (n = 2) 开始为演示包括损失项。通过这样做,鼓励模型从第二个演示对本身开始推理转换模式。

​ 特定于任务的参数:不是为测试集中的所有任务学习单个 LoRA 适配器,而是为每个任务学习一个单独的任务特定的 LoRA 适配器。也就是获得了 N 个不同的 LoRA 适配器,其中 N 是测试任务的数量。

结果

​ 将该方法的主要实现与以下消融进行了比较:

  • FT (No TTT):原版基线,其中 TTT 被消融,而使用微调模型
  • No Transformations:无基于转换的数据增强
  • End-to-End (E2E) Data:使用端到端任务公式,而不是标准的上下文任务设置
  • Shared TTT:与学习特定于任务的 LoRA 适配器相比,单个 LoRA 适配器是使用所有任务的聚合数据集学习的
  • No Demonstration Loss:数据的训练输出中的演示不会丢失。也就是说,TTT 损失只是:

L i ( D T T T i ; θ i ) = ∑ d ∈ D T T T i L L M ( y t e s t ∣ x 1 , y 1 , . . . , x K , y K , x t e s t ; θ i ) ) (7) \mathcal{L}_i(D^i_{TTT};\theta_i)=\sum_{d\in D^i_{TTT}}\mathcal{L}_{LM}(y_{test}|x_1,y_1,...,x_K,y_K,x_{test};\theta_i))\tag{7} Li(DTTTi;θi)=dDTTTiLLM(ytestx1,y1,...,xK,yK,xtest;θi))(7)

  • QLoRA:不是全精度的基本模型更新,而是为每个任务学习量化的 LoRA 适配器,这是为内存效率考虑的 LoRA 的替代方案。

在这里插入图片描述

​ 结果如图 3 所示。本文的 TTT 方法很有效,将微调模型的准确率提高了大约 6倍(5 → 29)。辅助任务的结构显着影响 TTT 的有效性。使用上下文学习任务的性能明显优于使用端到端任务,在相同条件下,任务的相对性能下降了 11 个(减少 38%)。这可能仅仅是由于训练的参数较少。放弃应用于增强数据的转换会损害 16 个任务(减少 55%)。

​ 接下来,本文消融 TTT 优化的多个组件,以分析它们对性能的贡献。在所有任务中学习单个 LoRA 适配器会降低 7 个任务的性能(降低 24%)。这是意料之中的,因为学习专用适配器可以为每个任务训练更多参数。其次,通过承担输出演示的损失而做出的决策略微提高了性能 (26 → 29),因为本文认为这会迫使模型在处理演示时对转换进行推理。最后观察到,使用量化 LoRA (QLoRA) 只会导致性能略有下降 (29 → 26) — 在内存受限的情况下,使用 QLoRA 可能是可行的。

增强推理策略

增强推理

​ 最近的研究表明,扩展测试时计算可以显著提高 LM 的性能。执行此操作的最常见技术之一是对多个响应进行采样,然后使用排名程序选择最佳响应。然而采样在具有多种可能的解决方案(代码中的程序)或通往最终答案(数学)的多种可能路径的领域中非常有效,但在直接生成答案时可能是有害的,因为无法在确保样本内一致性的同时确保执行样本之间的多样性。作为另一种推理时间扩展,本文使用增强推理策略,通过使用几何变换并结合贪婪解码方案来生成多个预测候选者。

​ 对于具有训练示例 ( x k , y k ) k = 1 K (x_k,y_k)^K_{k=1} (xk,yk)k=1K 和测试输入 x t e s t x_{test} xtest 的给定任务,使用可逆几何变换来生成任务的等效变换版本。设 T \mathcal{T} T 为一组可逆几何变换(例如,旋转和反射)。对于每个转换 t ∈ T t \in \mathcal{T} tT ,将 t 应用于所有训练演示和测试输入,并使用这些转换后的输入运行本文模型。

​ 然后,应用逆变换来获得该变换的最终预测。
y ~ ∼ L M ( t ( d i n p u t ) ) : = [ t ( x 1 ) , t ( y 1 ) , . . . , t ( x t e s t ) ] (8) \widetilde{y} \sim LM(t(d_{input})):=[t(x_1),t(y1),...,t(x_{test})]\tag{8} y LM(t(dinput)):=[t(x1),t(y1),...,t(xtest)](8)

y t = t − 1 ( y ~ ) (9) y_t=t^{-1}(\widetilde{y})\tag{9} yt=t1(y )(9)

集成预测(投票策略-贪婪解码策略)

​ 采用分层投票策略来确定候选人集 { y } i − 1 n ⋅ ∣ T ∣ \{y\}^{n \cdot |\mathcal{T}|}_{i-1} {y}i1nT。这种方法涉及两个投票阶段,以逐步缩小最佳候选者的范围:首先,选择每个转换中最频繁的预测,然后对特定于转换的候选者进行总体投票,以确定前 2 个最频繁的预测。每个阶段的详细信息如下:

  • 转换内投票:按相应的转换 t 对预测进行分组,并选择每组中前 3 个最频繁的预测。如果一个组中存在的唯一预测少于 3 个,将通过通过以下方式计算其他预测来补充候选项:
    • 基于行的多数:对于预测输出网格中的每一行,采用转换组中的所有预测中最频繁的行值。
    • 基于列的众数:同样对于预测输出网格中的每一列,在转换组中的所有预测中采用最频繁的列值。
  • 全局投票:使用从转换内投票中获得的特定于转换的候选函数,进行总体投票,以选出前 2 个最常见的预测进行提交。如果出现平局,则优先使用身份转换进行预测。
结果

​ 为了分析增强推理和投票的影响,本文进行了以下消融:

  • Vanilla:此基线遵循标准推理方法,没有任何增强的推理或投票。它从模型中为任务的 2 种排列生成 2 个预测。此设置可作为评估增强推理和投票策略优势的参考点
  • Transformed Inference (Rotate/Transpose/Flip):测量仅从任务的特定转换版本生成预测时的性能,如图 5 所示。这将评估单独应用的每个转换的单个有效性。请注意,Vanilla 也可以被视为此类别的一部分,其中转换是 identity function
  • Hierarchical Voting:包括增强推理和投票
  • Flattened Voting:没有使用分层投票策略,而是对完整的 n ⋅ ∣ T ∣ n \cdot |\mathcal{T}| nT 预测中确定 2 个最常见的预测
  • 如果正确答案存在于 n ⋅ ∣ T ∣ n \cdot |\mathcal{T}| nT 的集合中,则 oracle 会选择正确答案。如果投票过程完美,则 oracle 会提供可能的最佳性能上限。

在这里插入图片描述

​ 结果总结在图 5 中。如图所示,特定转换版本的单个性能通常很差,转置转换的准确率最差。但是,通过投票程序聚合这些转换会带来显著的改进。这表明,一些任务在其转换版本中可能更容易解决,并且使用自洽性(投票)进行聚合通常是有益的,这一发现在以前的工作中也观察到。此外,虽然扁平化投票程序提高了准确性,但分层投票程序优于它。事实上,分层过程与 oracle 相当,表明分层聚合可以有效地选择正确答案(如果存在),并且准确率很高。

TTT 前微调

​ 虽然测试时训练有助于特定任务的适应,但基本模型的功能会影响最终性能。本文开发了几种生成合成训练数据的方法,通过微调来增强基础模型的抽象推理能力。在本节中将详细介绍微调数据生成策略,并分析不同数据源和模型大小对最终性能的影响。

准备微调数据

​ Hodel 等人提供了域特定语言 (DSL)REARC 以及解决 t a s k − i task-i taski 的转换 f i f_i fi,以及在该DSL中为 D A R C t r a i n D^{train}_{ARC} DARCtrain 数据集中的每个训练任务的数据生成函数 g i g_i gi。这些函数支持对保持相同底层转换原理的新输入-输出对进行采样:
d = ( x , y ) ∼ e v a l ( g i ) (10) d=(x,y) \sim eval(g_i)\tag{10} d=(x,y)eval(gi)(10)
​ 其中 d 表示新生成的输入-输出对,可以使用与原始 t a s k − i task-i taski 相同的变换函数 f i f_i fi 进行求解。

  • 使用现有生成器

​ REARC 中的生成器函数 g s gs gs 已经通过生成相同任务的不同实例提供了一种有效的数据增强工具。本文通过多次运行这些代码并将这些新样本( d ∼ e v a l ( g i ) d \sim eval(g_i) deval(gi))随机拆分为一组训练和测试样本,从这些训练任务中生成额外的样本。

  • Few-shot 提示 LLM

​ 使用了几种方法使用 LM(GPT4 和 GPT4-o 的集成)生成新任务。最简单的方法使用少数样本生成新的任务生成器:
g ′ ∼ L M ( g 1 , g 2 , . . . , g m ) (11) g' \sim LM(g_1,g_2,...,g_m)\tag{11} gLM(g1,g2,...,gm)(11)
​ 其中 g ′ g' g 是新的生成器函数, g 1 , . . . , g m g_1,...,g_m g1,...,gm是现有的生成器函数(如图 6 所示)。本文从现有的训练集中统一抽样不同的 m 个例子,多次重复此过程以获得大量任务。

在这里插入图片描述

​ 使用任务描述来增强生成器函数,并共同生成描述和生成器:
( s ′ , g ′ ) ∼ L M ( s 1 , g 1 , s 2 , g 2 , . . . , s m , g m ) (12) (s',g') \sim LM(s_1,g_1,s_2,g_2,...,s_m,g_m)\tag{12} (s,g)LM(s1,g1,s2,g2,...,sm,gm)(12)
​ 其中 s i s_i si 表示任务 i i i 的描述。

​ 为了获得任务描述,本文为 10 个训练任务手动创建了描述。然后,这些描述用于通过小样本提示为训练和验证任务生成描述。为了增加任务的多样性,使用带有分层字段(类别、摘要和描述)的任务描述。

​ 本文没有共同生成任务描述和函数生成,而是额外部署了如下两阶段方法:
s ′ ∼ L M ( s 1 , s 2 , . . . , s m ) (13) s' \sim LM(s_1,s_2,...,s_m)\tag{13} sLM(s1,s2,...,sm)(13)

g ′ ∼ L M ( s 1 , g 1 , s 2 , g 2 , . . . , s m , g m , s ′ ) (14) g' \sim LM(s_1,g_1,s_2,g_2,...,s_m,g_m,s')\tag{14} gLM(s1,g1,s2,g2,...,sm,gm,s)(14)

​ 这种方法首先生成一个任务描述 s ′ s' s,然后在现有任务对和新描述上设置生成器的创建条件。本文总共收集了 6426 个使用这些基于 LLM 的方法的生成器。

  • 几何变换

​ 合成任务通过各种几何变换得到增强,例如基本变换(旋转、反射、随机偏移和大小缩放)、图案操作(随机修补、平铺和重复)、颜色排列和涉及多个基本变换顺序应用的复合变换。这些转换以三种方式应用:

​ 1)变化输入: ( x , y ) → ( t ( x ) , y ) (x,y) \rightarrow (t(x), y) (x,y)(t(x),y)

​ 2)变化输出: ( x , y ) → ( x , t ( y ) ) (x,y) \rightarrow (x, t(y)) (x,y)(x,t(y))

​ 3)同时变化输入输出: ( x , y ) → ( t ( x ) , t ( y ) ) (x,y) \rightarrow (t(x), t(y)) (x,y)(t(x),t(y))

​ 这些转换在 30% 的时间内随机应用于任务的变体。

结果

​ 使用增强数据执行完全微调 1B、3B Llama 3.2 指令调整和 8B Llama 3 指令调整。对增强数据进行以下分析:

  • No FT:原始的 Llama 3 指令调整模型,没有任何微调
  • All:使用描述的所有方法,包括 REARC、基于规则的增强和 LM 生成
  • No-Geom:从所有任务中删除几何变换
  • No-LM:只使用 REARC 和基于规则的增强,不包括 LM 生成的任务

在这里插入图片描述

FT 数据如何影响 TTT 在图 7 左侧中使用不同的微调数据比较模型。发现在 REARC 上用基于规则的增强训练的模型取得了最强的性能。包括 LM 生成的任务会损害 5% 的性能,这表明当前基于 LM 的任务生成方法可能需要更复杂的过滤机制。

TTT 中的模型大小和缩放 在图 7 右侧中显示了使用不同模型大小的结果。增加模型大小可以持续提高 FT 性能,8B 模型实现了 36% 的最高准确率。还观察到 TTT 有效地缩小了较小模型的性能差距,1B 和 3B 模型在 TTT 后实现了相似的准确性。

ARC 基准测试和与其他系统的比较

​ 在本文对 80 项任务进行开发实验之后,展示了完整的 ARC 公开评估集的综合结果,并将本文的系统与现有方法进行了比较。分析集中在三个关键方面:本文的 TTT 方法的影响、将本文方法与现有方法相结合的好处以及全神经和程序合成方法之间的差异。

  • TTT 的影响

​ 将 TTT 和推理程序应用于本文的基本微调模型。TTT 将准确率从 39.3% 提高到 47.1%,超过了现有的端到端神经模型结果。

  • 与现有方法集成

​ BARC通过结合神经和程序合成方法实现了 54.4% 的准确率——这是以前公开可用的最高结果。虽然全神经方法与本文的系统有相似之处,但本文的 TTT 和推理有几个额外的组件可以提高性能。特别是测试时训练包括每个任务的 LoRA 和更大的增强集,而本文预测包括可逆转换下的增强推理和分层自洽投票方案。为了验证本文的改进,将 TTT 应用于 BARC 的全神经模型,实现了 53% 的准确率,比原始 TTT 方法提高了 35%。

在这里插入图片描述

​ 基于这些结果,探索了本文的方法与 BARC 组件的各种组合:

​ 1)将 TTT 和神经模型与 BARC 的合成器相结合,将准确率提高到 58.5%。

​ 2)将 TTT 与 BARC 的神经模型和合成器相结合,将准确率提高到 61.9%。

​ 这种最终配置在 ARC 公共评估集上建立了一个新的最先进的技术,可与 60.2% 的平均人类表现相媲美。虽然这代表了重大进步,但与人类最佳表现 97.8% 仍有很大差距,这表明还有进一步改进的空间。

  • 比较程序生成和端到端建模

​ ARC 的程序合成和完全神经预测器是高度互补的,即使在相同的任务上训练也是如此。端到端神经模型只能解决程序合成模型解决的任务的 42.2%。然而当配备本文的 TTT 时,BARC 的微调全神经模型解决了 73.5% 的程序合成模型解决的任务,这表明本文 TTT 显著提高了神经模型学习系统推理模式的能力。

总结

局限
  • 基础 LM 模型可能在预训练就接触过相关测试数据集
启发
  • TTT 的有效性及潜力,证明了在数据量极低的环境中,在测试时训练更新模型参数具有提升模型性能的潜力,能显著提高语言模型在抽象与推理语料库(ARC)上的准确率
  • 将TTT方法与程序生成方法相结合,达到了与人类平均表现相当的性能,表明神经模型可以通过适当的测试时计算资源分配来解决复杂的推理问题,而不仅仅依赖于符号搜索
  • 基于几何变换的数据增强策略,并结合了自洽性的投票机制来选择最佳预测,提高了模型在面对不同变体时的鲁棒性和准确性

相关知识链接

  • 论文链接
  • BibTex
@misc{akyürek2024surprisingeffectivenesstesttimetraining,
      title={The Surprising Effectiveness of Test-Time Training for Abstract Reasoning}, 
      author={Ekin Akyürek and Mehul Damani and Linlu Qiu and Han Guo and Yoon Kim and Jacob Andreas},
      year={2024},
      eprint={2411.07279},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2411.07279}, 
}

http://www.kler.cn/a/515801.html

相关文章:

  • 自动化实现的思路变化
  • ChatGPT大模型极简应用开发-CH2-深入了解 GPT-4 和 ChatGPT 的 API
  • 简识JVM栈帧中的操作数栈
  • SQL-leetcode—1141. 查询近30天活跃用户数
  • Linux(Centos 7.6)命令详解:iconv
  • 【线性代数】列主元法求矩阵的逆
  • 【设计模式-行为型】状态模式
  • React 中hooks之useInsertionEffect用法总结
  • 99.11 金融难点通俗解释:净资产收益率(ROE)VS投资资本回报率(ROIC)VS总资产收益率(ROA)
  • 【算法笔记】力扣热题100(LeetCode hot-100)438. 找到字符串中所有字母异位词 滑动窗口
  • 安卓程序作为web服务端的技术实现:AndServer 实现登录权限拦截
  • js为table列宽度添加拖拽调整
  • 原生toFixed的bug
  • 多版本并发控制:MVCC的作用和基本原理
  • javaweb之HTML
  • 【spring】集成JWT实现登录验证
  • Grafana系列之面板接入Prometheus Alertmanager
  • C#树图显示目录下所有文件以及文件大小
  • 深入探究 YOLOv5:从优势到模型导出全方位解析
  • 简识JVM的栈帧优化共享技术
  • SQL-leetcode—1174. 即时食物配送 II
  • 【设计模式-行为型】观察者模式
  • Git报错:refusing to merge unrelated histories
  • 基于ESP32-IDF驱动GPIO输出控制LED
  • ChatGPT大模型极简应用开发-CH2-深入了解 GPT-4 和 ChatGPT 的 API
  • linux CentOS 创建账号,并设置权限