当前位置: 首页 > article >正文

使用PyTorch实现GPT-2直接偏好优化训练:DPO方法改进及其与监督微调的效果对比

基于人类反馈的强化学习(RLHF)已成为大型语言模型(LLM)训练流程中的关键环节,并持续获得研究界的广泛关注。

本文将探讨RLHF技术,特别聚焦于直接偏好优化(Direct Preference Optimization, DPO)方法,并详细阐述了一项实验研究:通过DPO对GPT-2 124M模型进行调优,同时与传统监督微调(Supervised Fine-tuning, SFT)方法进行对比分析。

本文将系统阐述DPO的工作原理、实现机制,以及其与传统RLHF和SFT方法的本质区别。

RLHF的基本原理

RLHF在LLM训练的后期阶段发挥关键作用,其核心目标是使模型与难以明确定义的微妙人类偏好达成更好的一致性。以下将详细分析其必要性。

现代LLM的训练通常包含多个阶段:

预训练阶段是第一阶段,模型通过在互联网等来源的海量文本数据上优化交叉熵目标进行训练。对于规模最大的模型,预训练数据集可能包含数万亿个非结构化token。该阶段使模型掌握基本的语言结构和事实性知识,形成一个能够准确完成句子和获取事实的"基础"模型,但其输出往往缺乏对话的自然性。

监督微调是第二阶段,模型在精心构建的问答对数据集上进行训练,这些数据集明确定义了特定上下文的最优响应。这些最优响应通常由领域专家编写,确保其格式规范、长度适当且信息充分。

RLHF构成第三阶段。该阶段旨在优化模型在那些难以精确定义但易于判断的行为场景中的表现(例如当AI公司期望其模型展现出顺从和友善的特质时),虽然创建数千个符合这些标准的最优响应成本高昂且耗时,但对已有响应进行评判则相对容易。这表明RLHF在优化那些难以生成标准答案但易于评估的行为方面具有独特优势

经典RLHF:奖励模型与PPO的结合

首先,我们需要理解RLHF最经典的实现方法,该方法最初由OpenAI在2019年的研究中提出。为简明起见,本节将从较高层面进行阐述。

RLHF的基础是偏好数据集的构建。这通常通过向人类评估者展示提示和可能的完成结果集合,并让其选择最佳完成结果来实现。RLHF的核心目标是训练模型生成那些能够被人类评估者选为最佳回应的内容。

从强化学习的角度,这个问题可以进行形式化描述。我们可以将LLM视为参数化策略,其中环境状态由输入LLM的上下文表示,模型通过生成新token来执行动作。一个完整的生成序列即构成一个轨迹。

关键观察在于,我们的偏好数据集本质上是一组轨迹及其对应的"质量"评分。其中的技术要点在于我们可以将偏好数据转化为分配给每个完成结果的标量奖励,使得最受偏好的完成结果获得更高的奖励值。这样就能将轨迹与其最终奖励建立关联,这就意味着最大化策略价值等同于最大化所有轨迹的期望奖励,也就是我们的优化目标。因此可以采用PPO等策略梯度方法来优化策略,提高产生高奖励轨迹的概率,同时确保LLM在RLHF过程中不会发生过大的偏离。

但是由于LLM生成的多样性,我们会面临着极其庞大的状态空间,策略梯度方法本质上是"在线"的,仅依赖固定的离线轨迹集合难以有效执行PPO。

所以我们不直接使用偏好数据作为轨迹,而是利用它训练一个奖励模型作为人类评估的代理。奖励模型本身通常是一个LLM,接收来自actor LLM的文本输入并输出标量奖励。奖励模型的损失函数定义为:

其中r表示奖励模型,x为提示,y为候选完成结果集合,b为首选响应的索引。该模型通过学习最大化首选响应的奖励同时抑制其他响应的奖励值。这种"对比"机制至关重要——与仅展示高质量响应的SFT相比,RLHF教会LLM在生成高质量响应的同时主动规避低质量响应。

具体的训练过程是在线进行的:向模型提供提示,采样其生成结果,使用奖励模型预测奖励,并执行PPO更新。为防止分布偏移影响训练过程中奖励模型的准确性,需要定期使用actor模型的最新生成结果重新训练奖励模型。

这自然引发了一个关键问题:RLHF为何能够实现有效泛化?对此存在两种可能的解释。

悲观的解释认为这并非真正的泛化——部分研究表明奖励模型存在过拟合现象,RLHF调优的模型仅是直接复制数据集中的偏好。然而在某些场景下,RLHF展现出卓越的效果,模型在各个方面都实现了显著提升。

乐观的解释则认为,由于奖励模型本身就是一个LLM,它在分配奖励时会利用预训练阶段获得的内部知识。因此奖励函数实际上是多维的,隐含了预训练阶段的知识,这意味着它传递的信息远超显式的偏好。虽然这仅是一个工作假设,但它可能有助于解释RLHF强大的泛化能力。

DPO:一种创新方法

直接偏好优化(Direct Preference Optimization, DPO)由Rafailov等人于2023年提出,该方法在不需要实际强化学习的情况下实现了RLHF的目标。DPO方法消除了对奖励模型的依赖,转而利用固定的离线偏好数据集进行模型调优。

图1:DPO与传统RLHF的架构对比。DPO无需奖励模型。

DPO论文的作者指出传统RLHF流程存在时间和计算资源消耗过高的问题,主要体现在两个方面:

首先,与典型的强化学习问题类似,训练过程需要从actor LLM不断生成新的完成结果——这意味着每个训练步骤都涉及LLM的多次前向传递计算,导致巨大的计算开销。这源于传统RLHF采用在线学习方法的本质特性。

其次,训练流程需要额外训练一个完整的LLM(即奖励模型)用于显式奖励计算。

DPO在这两个方面都采取了不同的策略——它是一种离线方法,并通过偏好数据来隐式定义奖励。如果将RLHF视为在线强化学习,SFT视为纯行为克隆,那么DPO则代表了两者之间的离线强化学习方法。

DPO数据集由响应对构成,其中对于每个提示,都包含一个"被选择"的响应和一个"被拒绝"的响应。DPO训练涉及actor模型和一个固定的、不可训练的参考模型(通常初始化为与actor模型相同的参数)。在训练过程中对actor模型进行优化,同时通过正则化确保其不会过度偏离参考模型。

这导致了如下DPO目标函数,其中y_w表示被选择的响应,y_l表示被拒绝的响应。本质上该损失函数促使actor和参考模型之间被选择响应的概率比相对于被拒绝响应的概率比增长。当这种情况发生时,sigmoid内的量为正值,使得sigmoid值趋近于0,从而产生较小幅度的负对数值(导致较低的损失)。

在初次接触DPO时,可以会产生以下一些问题问题:

该损失函数与典型的交叉熵损失加KL散度项结构相似。其独特之处究竟在哪里?将拒绝的响应作为负样本的实际效果如何?

考虑到模仿学习和离线强化学习固有的探索限制——使用DPO训练的LLM在遇到数据集外的提示时是否会失效?

DPO相比于针对选定响应的纯SFT是否具有实质性优势?

为了获得更深入的理解,我们将使用PyTorch中从零实现DPO,并将其应用于参数量为1.24亿的最小规模GPT-2模型。同时,我实现了SFT以进行对比分析。

数据集构建

DPO方法的基础是偏好数据集的构建,每个样本包含一个提示、一个"被选择"的响应和一个"被拒绝"的响应。被选择的响应体现了我们期望模型展现的行为特征,而被拒绝的响应则代表了相反的特征。

在实际应用场景中,被选择和被拒绝的响应往往只存在细微的差异,比如语气或表达方式的微小变化。然而,为了便于在实验中清晰观察模型行为的变化,本文构建了一个具有显著对比特征的数据集,设定了一个明确的行为目标:要求模型在每个响应中都提及宾夕法尼亚大学(Penn)。

数据集的构建采用了基于GPT-4的合成方法。首先使用GPT-4生成一系列提示,然后对每个提示生成两种响应:一个包含Penn相关内容(作为被选择的响应),另一个不包含Penn相关内容(作为被拒绝的响应)。这种方法效果显著。GPT-4能够生成多样化的提示,并能自然地将Penn融入被选择的响应中。以下展示了通过该方法生成的典型示例:

表1:合成DPO数据集中的示例样

通过上述方法,我们最终构建了包含约1,300个训练样本和125个测试样本的数据集,其中测试样本被完全隔离用于模型评估。

HuggingFace基准实现

本文的主要目标是从零开始实现DPO,所以首先使用HuggingFace提供的DPOTrainer作为性能基准。这一基准将作为评估自主实现的DPO效果的参照标准。我们采用默认超参数配置进行了训练,并获得了稳定的收敛结果(如图2所示)。

图2. 使用HuggingFace DPOTrainer训练过程中的损失函数变化曲线。

在保留测试集上对调优后的模型进行评估时,我们观察到模型行为发生了显著变化。模型的输出表现出高度一致性,几乎所有响应都自然地包含了对宾夕法尼亚大学的引用。表2展示了测试集中的两个典型示例:

表2. 基础GPT-2模型与经HuggingFace DPOTrainer调优后的GPT-2模型输出对比。

DPO实现与训练过程

接下来,我们使用纯PyTorch实现了DPO训练框架,以深入理解DPO的内部机制。完整的实现代码在我本文最后提供。

从架构层面来看,DPO损失函数需要四组概率输入——分别是actor模型和参考模型对被选择响应和被拒绝响应的概率评估。响应的整体概率并非模型的直接输出——模型在每个时间步只预测下一个token的概率分布,完整响应的对数概率可以通过各个token对数概率的和来计算:

首先将响应输入模型,通过softmax函数将原始logits转换为概率分布,再取对数得到对数概率。然后使用响应序列作为标签来提取我们期望模型生成的token序列的对数概率。这些对数概率的和即为该模型对整个响应序列的对数概率评估。下图3详细说明了这一过程对被选择响应的具体实现:

图3. 从模型中提取特定响应对数概率的计算流程。

获取全部四个对数概率后,将其代入DPO损失函数进行优化训练。

在训练过程中,损失函数呈现出平稳的下降趋势,表明模型训练过程符合预期:

图4. DPO训练过程中的损失函数变化。

但是我们发现经过训练的模型输出出现了不连贯和重复的问题。以下是模型在保留测试集上的几个典型输出样例:

表3:使用naive DPO损失训练后的模型输出样例。可以观察到输出存在明显的不连贯性和重复性问题。

问题诊断与分析

通过深入分析DPO损失函数的特性,我们注意到最小化损失等价于最大化sigmoid函数内部的量,该量与对数概率的差值成正比:

进一步分析发现,即使右侧括号中的量变得比左侧括号更快地趋向负值,整体量仍会增长。这意味着即使actor模型对被选择响应的对数概率相对于参考模型降低,DPO损失仍可能继续下降。这种特性很可能是导致模型过拟合和性能崩溃的根本原因。

实验数据证实了这一推测(如图5所示)。在DPO论文中,括号内的每个量被定义为"奖励",这也反映在图例的标注中。从直观上理解,这种现象可能源于我们设定的对齐目标过于特定。在我们的实验设置中,被拒绝的响应从客观标准(如语法正确性和语义连贯性)来看仍是合格的响应——只是不符合我们特定的行为要求。因此对齐目标与响应质量在某种程度上是正交的,这导致了学习过程的不稳定性。

图5. DPO训练过程中的奖励变化(actor和参考模型之间的对数概率比)。可以观察到被选择响应的奖励实际呈现下降趋势。

这一发现表明我们需要对损失函数进行改进,以确保actor模型对被选择响应的对数概率始终高于参考模型。

改进的损失函数与DPOP

探索多个候选损失函数的过程,我们发现了来自Abacus.AI研究团队的一篇论文,该论文恰好针对我们观察到的问题提出了解决方案。他们提出的改进方法称为DPOP,其核心是通过损失函数强制保持actor模型对被选择响应的对数概率比始终为正。

max函数内部的项在actor的对数概率低于参考模型时为正值。此时sigmoid内部的量减小,导致sigmoid值趋近于0,从而产生较大的损失值,有效地惩罚了这种不期望的情况。

实现这一改进后的训练结果显著优于原始方法。被选择和被拒绝响应之间的奖励差距在训练过程中稳步增加,更重要的是,被选择响应的奖励始终保持为正值

图6. DPOP训练过程中的奖励变化。值得注意的是,被选择响应的奖励始终保持在正值区间。

改进后的模型在输出质量上也实现了显著提升,生成的内容更加连贯、符合语法规范且自然流畅:

表4. DPOP与DPO模型输出的质量对比。

这些实验结果清晰地表明,原始DPO训练可能导致模型性能在训练过程中发生崩溃,而DPOP通过改进的损失函数有效地解决了这一问题。

通过在保留测试集上比较DPOP生成结果与HuggingFace实现的输出,并使用GPT-4从连贯性、语法正确性和相关性三个维度进行评估。结果显示原始DPO的胜率接近0%,而DPOP达到了51%的胜率。

图7. DPO和DPOP调优模型与HuggingFace DPOTrainer训练模型的胜率对比。

与SFT的比较研究

对DPO方法的一个主要疑虑是其与SFT在本质上的相似性。所以我们以DPO数据集中的被选择响应作为专家标注,实现并评估了基于SFT的模型训练方法。

这里设计了两种不同的损失函数:

第一种称为"SFT损失",本质上是对完整响应序列的标准交叉熵损失。

第二种称为"KL-SFT损失",在SFT损失的基础上增加了一个惩罚项,用于限制actor模型与参考模型之间的偏差——这可以视为一种伪KL散度度量。需要注意的是,该惩罚项对双向偏差都进行惩罚。

我们使用这两种损失函数对GPT-2进行微调,并将结果与DPO和DPOP方法进行系统对比。

首先可以观察到部分SFT模型的输出完全没有提及Penn。为此我们引入了"对齐准确率"这一指标来评估所有训练模型的表现——即在保留测试集上包含Penn引用的输出比例。实验结果显示,DPO和DPOP调优的模型在对齐目标上保持了100%的准确率,而SFT模型的准确率分别为94%和88%。这种差异可能源于SFT模型缺乏负样本,从而导致其输出表现出更大的多样性。这凸显了负样本在抑制模型生成那些在技术上可行但不符合期望的输出方面的重要作用,这正是偏好优化方法的核心优势所在。

图8. 对齐准确率反映了模型在保留测试集中提及Penn的输出比例。DPO和DPOP接近100%,而SFT约为90%。

其次,使用GPT-4对SFT、KL-SFT、DPOP、DPO和HuggingFace DPOTrainer模型之间进行了全面的胜率评估(如下图所示)。出人意料的是,SFT模型获得了最高的胜率,这一结果通过对响应的人工分析得到了进一步验证。相比DPOP和HuggingFace的输出,SFT模型的生成结果展现出更强的创造性和更好的语法规范性。

图9. GPT-4评估的各模型对间胜率对比。SFT模型表现最佳,对DPOP的胜率达到71%。

这些实验结果表明,在特定行为模式的优化方面,SFT可能与DPO具有相当的效果。尽管DPO通过引入被拒绝样本提高了对对齐目标的遵循度,但这种机制也可能导致过拟合和不稳定性,从而在生成质量上略逊于SFT模型。

研究局限性

首先,使用的数据集与真实世界的RLHF数据集有显著差异。本文聚焦于优化一个特定的行为模式:在响应中提及Penn。而真实场景中的数据集通常反映更为微妙的偏好特征,被选择和被拒绝响应之间的差异更加细微。这意味着在真实数据集上,RLHF可能由于其更灵活的响应质量建模能力而显著优于SFT。

其次,本文使用了参数量最小的GPT-2模型,这在某种程度上简化了优化过程。这可能导致我们的评估结果存在偏差——由于模型本身知识储备有限,很难准确评估模型是否发生了知识遗忘或性能退化。

第三,我们使用GPT-4进行对齐准确率和胜率评估,而GPT-4对提示的细微变化表现出较高的敏感性。考虑到GPT-4可能无法准确模拟人类偏好,实际的人类评估结果可能会有显著差异。

总结

DPO是一种简洁而高效的技术范式。相比传统的奖励模型+PPO的RLHF方法,它具有实现简单、计算高效的优势,同时能够产生良好的优化效果。其目标函数在强化学习和SFT目标之间实现了有效的平衡,使模型能够高度遵循DPO数据集中反映的偏好特征,这反映在接近100%的对齐准确率上。

DPOP在稳定性方面优于原始DPO。DPO存在已知的不稳定性问题,可能导致训练过程中被选择响应的对数概率下降。我们的实验结果验证了这一问题,并证实DPOP能够有效改善模型表现。从理论角度分析,这可能是因为我设定的对齐目标过于特定,而数据集中的被拒绝响应从语法和连贯性等客观标准来看仍具有较高质量。在更一般性的对齐任务中,如果被拒绝响应展现出明显的不连贯性或错误,DPO可能会表现得更好。

针对选定响应的SFT展现出comparable的性能。尽管SFT在保留测试集上的响应并不总是严格遵循对齐目标,但这些响应通常表现出更强的创造性和具体性。在GPT-4的评估中,SFT相比DPOP获得了71%的显著胜率。

RLHF评估面临的挑战。准确评估模型行为的变化程度、行为一致性以及是否保留了先前训练阶段获得的知识,是一个极具挑战性的任务。我们对使用前沿模型进行评估持谨慎态度,因为这些模型的评分表现出较大的方差,且对提示的细微变化特别敏感。现有RLHF文献多采用人类评估者与AI评估者相结合的方法,这似乎提供了更好的平衡,但其实仍然有一个未解决的问题就是顶级AI实验室是如何执行这一过程的。例如:如何向评估者传达模型的预期行为?如何在评估者之间建立一致的评判标准?是否应该使用创建偏好数据集的同一批人类评估者?这些都是有待深入研究的开放性问题。

本文代码

https://avoid.overfit.cn/post/d8468a92798745d298b1130c98adc934

作者:Aalok Patwa


http://www.kler.cn/a/447842.html

相关文章:

  • Mysql复习(二)
  • 内置函数.
  • PostgreSQL表达式的类型
  • 智能座舱进阶-应用框架层-Handler分析
  • Vue之版本演进
  • 防火墙技术与网络安全
  • 机器学习(四)-回归模型评估指标
  • 【LeetCode】906、超级回文数
  • vue入门教程:组件透传 Attributes
  • c++领域展开第四幕——类和对象(上篇收尾 this指针、c++和c语言的初步对比)超详细!!!!
  • 如何使用PSQL Tool还原pg数据库(sql格式)
  • Kubernetes网络管理
  • 示波器--UNI-T 优利德 UT4102C 使用介绍
  • 前端面试:项目细节重难点问题分享(19)
  • 一步一步写线程之十六线程的安全退出之二例程
  • 2024年12月的《数据资产管理实践指南(7.0版)》解析
  • 使用Python构建个性化学习管理系统
  • javaEE-线程的常用方法-4
  • GIT与github的链接(同步本地与远程仓库)
  • 深入理解 Java 中的 ArrayList 和 List:泛型与动态数组
  • (2024.12)Ubuntu20.04安装ZED-SDK
  • 图解HTTP-HTTP报文
  • 硬盘接口模式sata与ahci区别, U盘UEFI GPT与Legacy 启动项区别,硬盘格式MBR和gpt的区别
  • JavaEE 导读与环境配置
  • 【Windows版】opencv 和opencv_contrib配置
  • 大模型+安全实践之春天何时到来?