当前位置: 首页 > article >正文

跟李沐学AI:InstructGPT论文精读(SFT、RLHF)

原论文:[2203.02155] Training language models to follow instructions with human feedback

原视频:InstructGPT 论文精读【论文精读·48】_哔哩哔哩_bilibili

简介

1. RLHF 的基本概念

RLHF 是一种结合强化学习和人类反馈的训练方法,旨在让模型生成更符合人类期望的输出。它通过引入人类的偏好数据来调整模型的行为,使其不仅能够生成语法正确的文本,还能生成语义上更贴合用户意图的内容

2. RLHF 的主要步骤

RLHF 的流程通常分为以下三个阶段:

(1) 预训练语言模型
  • 在 RLHF 流程开始之前,使用一个预训练的语言模型作为基础。
  • 预训练的目标是让模型掌握大规模语言知识,但此时的模型尚未经过特定任务的优化
(2) 监督微调(Supervised Fine-Tuning, SFT)
  • 在这一阶段,研究人员收集一组高质量的人类生成的指令-响应对(instruction-response pairs),并用这些数据对预训练模型进行监督微调。
  • 这些指令-响应对由人类标注者编写,确保模型学会如何根据明确的指令生成合适的回答
(3) 强化学习优化
  • 强化学习阶段是 RLHF 的核心部分,分为两个子步骤:
    • 奖励模型(Reward Model, RM)训练 :研究人员收集人类对模型输出的偏好数据(例如,标注者选择更优的回答),并用这些数据训练一个奖励模型。奖励模型的作用是为模型生成的回答打分,反映其质量或与人类期望的匹配程度。
    • 策略优化(Policy Optimization) :使用强化学习算法(如 PPO,Proximal Policy Optimization)对模型进行进一步优化。PPO 是一种高效的强化学习算法,能够在保持模型稳定性的同时提升性能。模型会根据奖励模型的反馈不断调整其生成策略,最终生成更符合人类偏好的回答。

3. RLHF 的优势

  • 更高的指令遵从性 :通过引入人类反馈,InstructGPT 能够更好地理解并执行复杂的指令,减少了无关或不准确的回答。
  • 减少有害内容 :RLHF 方法可以帮助模型避免生成有害、偏见或不适当的内容,从而提高安全性。
  • 自然流畅的对话能力 :虽然 InstructGPT 更注重指令执行,但 RLHF 的优化也让模型在对话场景中表现得更加自然。

摘要

扩大模型的规模并不能从本质上让模型遵守用户的意图。如大语言模型可能像用户生成不真实、有害的或者无用的输出。

这篇文章中,作者团队提出了一种将大语言模型在大规模基于人类反馈的微调数据上与用户意图对齐的方法。

从一组标注者编写的提示词和通过OpenAI API提交的提示词开始,作者团队收集了一个数据集,其中包含标注者预期的模型输出结果,作者用这些数据来通过监督学习微调GPT-3。

然后,作者收集了一个模型输出排名的数据集,进一步使用基于人类反馈的强化学习来微调这个监督模型,并将最终得到的模型称为InstructGPT。

导论

大语言模型可以根据提示来完成一系列自然语言处理工作,但是模型经常会出现捏造事实、生成有偏见或有害的文本内容。作者认为这是因为模型训练的目标函数存在错误。模型训练的目标函数是从网络文本中预测下一个词元,这一目标函数与遵循用户提示的目标有一定偏差。这就是语言建模目标没有“对齐”。

为此,作者提出RLHF(Reinforcement Learning form Human Feedback,基于人类反馈的强化学习),以保证模型输出具有帮助性(helpful)、真实性(honest)和无害性(harmless)。

1. 采集各类问题prompt;

2. 标注者对问题进行回答(以模型预期的输出方式进行回答)

3. prompt+回答的文本对讲用于微调GPT-3 (SFT,Self Supervised Learning,有监督微调)

1. 收集一段问题prompt,同时讲模型对这个问题的不同输出进行采样。(模型输出具有随机性,所以对同一个问题会有不同的回答,输出的灵活度取决于解码策略)

2. 标注者对采样的模型输出从好到坏进行排序。

3. 将这些有好坏排序标注的回答数据用于训练一个奖励模型(RM,Reward Model)

这一步可以减轻数据标注成本。

1. 从数据集中随机抽取一个新的提示(prompt),例如“Write a story about frogs”(写一个关于青蛙的故事)。这个提示将作为输入提供给策略模型。

2. 策略模型(Policy Model)根据接收到的提示生成一个输出。在这个例子中,策略模型可能生成一段故事的开头,如“Once upon a time...”(从前……)。 

3. 奖励模型(Reward Model, RM)对策略模型生成的输出进行评估,并计算出一个奖励值(reward)。这个奖励值反映了输出的质量或与人类期望的匹配程度。

4. 计算出的奖励值被用于通过PPO(Proximal Policy Optimization,近端策略优化算法)更新策略模型。PPO 是一种高效的强化学习算法,能够帮助策略模型根据奖励信号调整其生成策略,从而在后续迭代中生成更高质量的输出。

方法

数据集来源:

标注人员编写提示词,提示词涉及如下三类:

  • 一般性的任意提示词
  • 多步问答提示词
  • 用户反馈的提示词:每个用户最多会被采集200个提示词,数据将根据用户ID来划分为训练集、测试集、验证集。

基于这些提示词,作者创造了三种不同的数据集:

  • SFT数据集:标注人员基于提示词编写答案
  • RM数据集:标注人员对模型输出的结果进行排序
  • PPO数据集:没有任何人工标签,直接用于RLHF微调

数据标注:招聘了40人进行数据标注,与这些标注人员紧密联系。

模型

1. 在SFT数据集基础上,对GPT-3进行微调,一共训练了16轮。SFT模型仅仅在第一轮就出现了过拟合,但是对后续步骤没有影响,甚至有帮助。

2. RM模型移除了SFT模型中的Softmax层,转为一层线性层将模型输出投影为一个标量奖励。本文使用了6B的模型作为奖励模型。

RM模型的损失函数如下:

对于每个输入Prompt x,我们生成 K 个不同的答案,这里 K=9。我们将这些答案分为两组:一组是排序较高的答案 Yw​,另一组是排序较低的答案 Yl​。

我们的目标是让模型能够正确地识别出哪些答案更好。为此,我们将 Yw​ 和 Yl​ 输入到奖励模型中,分别得到它们的奖励分数 rθ​(x,Yw​) 和 rθ​(x,Yl​)。由于 Yw​ 的排序高于 Yl​,我们希望 Yw​ 的奖励分数也高于 Yl​。

为了实现这一目标,我们计算两个答案的奖励分数之差,并通过sigmoid函数进行转换,以确保结果在0到1之间。然后,我们对这个结果取对数,并添加一个负号,以确保当 Yw​ 的奖励高于 Yl​ 时,损失值为正。最后,我们将这个值乘以1 / (CK_2),即 1 / 36​,以平均化所有可能的配对组合。

3. 强化学习:使用PPO算法对策略进行优化。

结果 

经过SFT+RLHF后的1.3B模型效果好于原有的175B GPT-3模型。 

局限性:

  • 1. 模型的行为与标注者息息相关
  • 2. 模型不是完全的安全

http://www.kler.cn/a/554352.html

相关文章:

  • 如何在Java爬虫中设置动态延迟以避免API限制
  • 缓存-算法
  • 6. 【.NET 8 实战--孢子记账--从单体到微服务--转向微服务】--微服务基础工具与技术--Ocelot 网关--概念与简单入门
  • uniapp多端适配
  • 【数据分享】2000—2024年逐月归一化植被指数(NDVI)栅格数据(免费获取/全国/分省)
  • 【JavaEE进阶】数据库连接池
  • web网络安全:跨站脚本攻击(XSS)
  • 知识库-查看知识详情接口
  • 图论 之 DFS
  • C/C++面试知识点总结
  • 2.20学习
  • 《Operating System Concepts》阅读笔记:p50-p61
  • 基于 Flask 与 MySQL 构建简单的博客系统
  • 面试基础--分布式任务调度系统设计方案
  • 数据结构:广义表( Generalized List)及其实现
  • SpringMVC的基本使用
  • SpringBoot 项目配置日志输出
  • mysql中union all和WITH ROLLUP实现汇总的两种方式
  • Flink CDC详解
  • 在实时大数据处理中如何平衡延迟和吞吐量