当前位置: 首页 > article >正文

STaR: Bootstrapping Reasoning With Reasoning

STaR: Bootstrapping Reasoning With Reasoning

基本信息

博客贡献人

燕青

作者

Eric Zelikman, Yuhuai Wu, Jesse Mu, et al. from Stanford University and Google Research

标签

Large Language Model, Chain-of-thought, Fine-tuning

摘要

生成逐步的“思维链”逻辑依据(rationale)可以提高语言模型在数学或常识问答等复杂推理任务上的性能。然而,诱导语言模型进行逻辑依据生成需要构建大量逻辑依据数据集,或者仅使用few-shot推理来牺牲准确性。本文提出了一种技术来迭代地利用少量的逻辑依据示例和没有逻辑依据的大型数据集,以引导连续执行更复杂的推理的能力。这种技术称为“自学推理机”(STaR),依赖于一个简单的循环:生成回答许多问题的逻辑依据,并用一些逻辑依据示例进行提示;如果生成的答案是错误的,请在给出正确答案的情况下再次尝试生成理由;对最终产生正确答案的所有逻辑依据进行微调;重复上述过程。本文表明,与经过微调以直接预测最终答案的模型相比,STaR 显着提高了多个数据集上的性能,并且其性能与在 CommensenseQA 上微调 30 倍大的最先进语言模型相当。因此,STAR 让模型通过学习自己生成的推理来改进自身。

问题定义

人类的决策通常是思想链延伸的结果。最近的工作表明,显式中间推理(“逻辑依据”)也可以提高大语言模型(LLM)的性能。在给出最终答案之前生成明确的逻辑依据(rationale generation,逻辑依据生成)对于LLM在数学推理、常识推理、代码评估、社会偏见推理和自然语言推理等不同任务中很有价值。然而,诱导基本原理生成的两种主要方法都有严重的缺点。

逻辑依据生成的一种方法是构建一个微调的逻辑依据数据集,要么由人工标注人员手动标注,要么通过手工制作的模板自动标注。人工方法成本高昂,且无法为每个不同领域的问题构建这样的数据集。同时,基于模板的方法依赖于自动生成逻辑依据,只有在模版匹配的特定条件下才能发挥作用。

另一种方法是通过在语言模型prompt中只包含少量的逻辑依据示例(few-shot)来利用语境学习。与无逻辑依据的prompt相比,其被证明可以提高数学和符号推理任务的准确性。然而,虽然带有逻辑依据推理的few-shot技术往往优于不进行推理的对应技术,但它们的表现通常远不如使用较大数据集微调后直接预测答案的模型。

本文采用了一种不同的方法:通过利用LLM预先存在的推理能力,迭代地引导LLM生成高质量逻辑依据的能力。具体来说,本文通过小样本提示LLM来自我生成逻辑依据并通过微调那些导致正确答案的理由来进一步完善模型的能力。重复这个过程,每次使用改进的模型生成下一个训练集。这是一个协同的过程,逻辑依据生成的改进改善了训练数据,训练数据的改进又进一步改善了逻辑依据生成。

然而,本文发现这个循环最终没有解决训练集中的任何新问题,因为它没有接收到它未能解决的问题的直接训练信号。为了克服这个问题,本文提出逻辑依据化(rationalization):对于模型不能正确回答的每一个问题,通过提供模型正确的答案来生成新的逻辑依据。这让模型反向推理——给出正确答案,模型就能更容易地生成有用的逻辑依据。然后收集这些逻辑依据作为训练数据的一部分,这往往会提高整体的准确性。

因此,本文开发了自学习推理机(STaR,如图1所示)方法,这是一种可扩展的自学方法,允许模型学习生成自己的逻辑依据,同时也学习解决越来越困难的问题。在本文的方法中,重复以下过程:在每次迭代中,首先通过尝试使用当前模型的基本原理生成能力来求解数据集**(能够正确解决的问题),从而构建一个微调数据集;然后,使用合理化方法扩充该数据集(不能正确解决的问题)**,以证明模型未能解决的问题的真实答案;最后,在组合数据集上微调大语言模型。

1
图1. STaR概述及逻辑依据化原理
问题和答案来自原始数据集,而逻辑依据使用STaR生成。

方法

逻辑依据生成引导

给定一个预训练的LLM M M M 和一个问题 x x x 的初始数据集,其答案为 y : D = { ( x i , y i ) } i = 1 D y:\mathcal{D} = \{ ( x_i , y_i) \}^D_{i=1} yD={(xi,yi)}i=1D 。本文的技术从具有逻辑依据 r r r 的例子的一个few-shot提示集合 P P P 开始: P = ( x i p , r i p , y i p ) i = 1 P \mathcal{P} = { ( x^p_i , r^p_i , y^p_i) }^P_{i = 1} P=(xip,rip,yip)i=1P,其中 P ≪ D P \ll D PD ( $ e.g. P = 10$ )。类似于标准的few-shot提示,本文把这个提示集合串接到 D D D 中的每一个例子,即 x i = ( x 1 p , r 1 p , y 1 p , . . . , x P p , r P p , y P p , x i ) xi = ( x^p_1 , r^p_ 1 , y^p_1 , ... , x^p_P , r^p_P , y^p_P , x_i) xi=(x1p,r1p,y1p,...,xPp,rPp,yPp,xi),这鼓励模型产生 x i x_i xi 的理由 r ^ ( i ) \hat r ( i ) r^(i) 和答案 y ^ ( i ) \hat y ( i ) y^(i)。假设导致正确答案的理由比导致错误答案的逻辑依据具有更好的质量。因此,本文过滤了生成的逻辑依据,使其只包含导致正确答案 ( y ^ i = y i ) ( \hat y_i = y_i ) (y^i=yi) 的逻辑依据。本文在这个筛选后的数据集上对基础模型 M M M 进行微调,然后通过新微调后的模型生成新的逻辑依据来迭代这个过程。不断重复这个过程直到达到性能瓶颈。在这个过程中,一旦收集到一个新的数据集,就对初始的预训练模型 M M M 进行训练,而不是继续训练一个模型,以避免过拟合。该过程如算法1所示。

1
算法1. STaR
蓝色代表应用逻辑依据化(rationalization)的部分

算法1描述了完整的算法,用蓝色表示的部分对应逻辑依据化。如果没有这些部分,算法1对应的是没有逻辑依据化的STaR。对逻辑依据化产生的数据集进行微调,对于将模型暴露在其微调数据集中本来不会出现的困难问题上有着至关重要的好处。这可以理解为挑战模型对其失败的问题进行"跳出盒子外的思考"。逻辑依据化的第二个好处是数据集规模的增加。

逻辑依据化(rationalization)

逻辑依据生成引导算法具有一定的局限性。由于模型只在其正确回答的实例上进行训练,当模型在训练集中不能解决新问题时,改进就结束了。这在根本上是由于算法无法从失败的样本中获得任何训练信号。本文提出了一种被称之为"逻辑依据化(rationalization)"的技术。具体来说,提供答案作为对模型的提示,并要求模型按照与前一个理据生成步骤相同的风格生成理据。

在给定答案的情况下,模型能够反向推理,从而更容易产生导致正确答案的理由。例如,在图2中,在提示语中提供了"(b) 购物车是正确答案"的提示,以生成理由。对于逻辑依据生成模型未能解决的问题,采用逻辑依据化处理。当在数据集中添加一个逻辑依据化产生的逻辑依据时,并不会在其相应的prompt中包含暗示,就像模型在没有暗示的情况下提出了理由一样。过滤后,在先前生成的数据集上结合合理化生成的数据集进行微调。

实验

实现细节

本文选择了6B参数模型 GPT-J 作为基础语言模型,因为检查点和微调代码是公开可用的,并且该模型足够大,可以产生高质量的逻辑依据。

数据集
  • Arithmetic

    算术任务,目的是计算两个n位整数之和。该数据集用于评估模型的符号推理能力。

  • CommonsenseQA

    多选项常识推理任务常识QA (CQA) 是由概念网( ConceptNet )构建的,概念及其关系的语义图有超过百万个节点。该数据集有12,247个问题,每个问题有5个选择。

  • Grade School Math (GSM8K)

    本文还在年级数学( Grade School Math, GSM8K )数据集上进行了评估,该数据集包含7,473个训练和1,319个测试的年级-学校-等级的单词问题示例。这些数学问题都是用自然语言提出的,需要经过2到8个计算步骤才能得出最终答案。该数据集结合了算术和常识推理所需的技能。

实验结果
符号推理 (Arithmetic)

图2给出了每次循环迭代时模型在1~5位上的精度。运行STaR进行16次迭代后,总体准确率为89.5 %。作为参考,在10,000个没有逻辑依据的样本上训练5,000步的基线达到76.3 %的准确性。值得注意的是,算术问题的few-shot准确率很低,即使有逻辑依据:2位数加法的准确率也低于1%,更多位数的准确率接近零。

伴随逻辑依据化,准确性能够快速提高。在对模型生成的新数据集进行一次微调迭代后,2位加法的准确率从不足1 %提高到32 %。在没有逻辑依据化的情况下,性能提升是分阶段的:直到模型在(n-1)位加法上取得良好性能时,在n位加法上的性能才能显著提升。在逻辑依据化的情况下,模型可以同时学习多个位数。

2
图2. Arithmetic实验结果
对STaR算法每次迭代的n位求和精度进行了可视化,并对算法进行了逻辑依据化和非逻辑依据化处理。每一条曲线对应两个n位数字求和的精度。

自然语言推理(CQA)

CommonsesenseQA ( CQA )数据集引入了几个新的挑战。在算术任务中,推理步骤中的不正确暂存器(scratchpad),以及程度较低的逻辑依据化步骤,极有可能导致不正确的答案。另一方面,CQA问题是5道选择题,因此无论推理的质量如何,在大约20%的问题上,会随机得到正确的答案。此外,一些简单的启发式算法(例如语义相似性)可以在没有任何推理的情况下有意义地将其提高到≈30 %。

如表1所示,在整个数据集上,无逻辑依据化的STaR尽管在更少的数据上进行训练,但在最终答案上的表现优于直接微调的GPT-J。逻辑依据化的加入使这一性能提高到72.5 %,更接近GPT-3的73 %(其参数是GPT-J的30倍大)。正如预期的那样,可以看到STaR超过了few-shot基线,包括更大的137B LaMDA模型。本文预期如果将STaR应用到少样本表现更好的模型中,准确率会进一步提高。

3
表1. CQA实验结果

数学推理(GSM8K)

本文在GSM8K上再次发现,STaR在有逻辑依据的情况下显著超过了few-shot和训练直接预测答案(无逻辑依据)的表现,如表2所示。需要注意的是,在训练时,需要对第30次迭代的训练步数(经过7912步)进行限制,以防止训练过程变得过长。STaR的最终结果是在没有逻辑依据化的情况下迭代36次,及在逻辑依据化的情况下再迭代10次的结果。

3
表2. GSM8K实验结果

相关知识链接

论文原文

STaR: Self-Taught Reasoner Bootstrapping reasoning with reasoning

总结

  • 本文提出了一种从具有逻辑依据的少量初始例子(few-shot)中迭代生成逻辑依据数据集,而不需要检查生成逻辑依据的正确性的引导机制。
  • 本文用逻辑依据化来补充逻辑依据生成,对于模型不能解决的问题,用正确答案进行暗示,然后用回答正确的问题对应的逻辑依据扩充数据集,进行后续迭代训练。该方法加速并优化了模型自学过程。
  • 本文首次提出了允许预训练LLM迭代地使用其语言建模能力来改进自身的技术。

[局限]

  • 针对不同领域的下游任务或数据集,需要进行单独的微调(后续Quiet-STaR改进)

[启发]

  • 本文用到了类似强化学习的技术,将模型是否正确回答问题类比奖励函数,对模型进行后续的改进
  • 有效利用了模型自身的语言建模能力来改进其自身

BibTex

@article{zelikman2022star,
  title={Star: Bootstrapping reasoning with reasoning},
  author={Zelikman, Eric and Wu, Yuhuai and Mu, Jesse and Goodman, Noah},
  journal={Advances in Neural Information Processing Systems},
  volume={35},
  pages={15476--15488},
  year={2022}
}

http://www.kler.cn/a/316470.html

相关文章:

  • ubuntu cmake CPack将第三方库进行打包
  • 智享AI 无人自动直播的崛起 ,引领智能互动与自动带货新潮流!
  • 给查询业务添加redis缓存和缓存更新策略
  • 信捷 PLC C语言 POU 指示灯交替灭0.5秒亮0.5秒(保持型定时器)
  • LabVIEW 实现 find_nearest_neighbors 功能(二维平面上的最近邻查找)
  • Day 63 || 拓扑排序、dijkstra
  • Android 如何实现搜索功能:本地搜索?数据模型如何设计?数据如何展示和保存?
  • 基于YOLOv8+LSTM的商超扶梯场景下行人安全行为姿态检测识别
  • Python3使用websocket调用http代理IP
  • IP包头分析
  • 【SSM-Day2】创建SpringBoot项目
  • 基于丹摩智算平台-手把手拿下经典目标检测模型 Faster-Rcnn
  • H5响应式的文化传媒娱乐公司HTML网站模板源码
  • Linux入门学习:深刻理解计算机硬件与OS体系
  • saltstack配置管理
  • 深度学习基础案例5--VGG16人脸识别(体验学习的痛苦与乐趣)
  • C++第七节课 运算符重载
  • 航拍房屋检测系统源码分享
  • 对条件语言模型(Conditional Language Model)的目标函数的理解
  • C语言编译四大阶段
  • EasyExcel的基本使用——Java导入Excel数据
  • [C#]winform 使用opencvsharp实现玉米粒计数
  • 基于windows的mysql5.7安装配置教程
  • Vue 实现高级穿梭框 Transfer 封装
  • Qt 模型视图(四):代理类QAbstractItemDelegate
  • 【数字组合】