【llm post-training】从Loss Function设计上看LLM SFT和RL的区别和联系
大型语言模型(LLM)的崛起深刻地改变了自然语言处理领域。为了让这些模型更好地服务于人类,两个关键的微调技术应运而生:监督微调 (Supervised Fine-tuning, SFT) 和 强化学习 (Reinforcement Learning, RL)。 虽然两者都旨在提升LLM的性能,但它们背后的原理和实现方式却大相径庭,而这些差异最核心的体现就在于它们所使用的 loss function (损失函数) 的设计。
本文将以 loss function 为中心,深入探讨 SFT 和 RL 在 LLM 微调中的区别与联系。我们将从最基础的原理出发,结合丰富的例子和代码片段,力求用最简单易懂的方式,帮助读者理解这两种关键技术的精髓。
1. 理解Loss Function:LLM训练的“指挥棒”
在深入 SFT 和 RL 之前,我们首先要理解 loss function 在LLM训练中的作用。 简单来说,loss function就像是训练过程中的“指挥棒”或“评分员”。 它衡量模型预测结果与真实目标之间的差距,并指导模型参数的调整方向。
- Loss function 的目标: 最小化模型预测与真实值之间的误差。误差越小,模型表现越好。
- 梯度下降 (Gradient Descent): 优化器 (如Adam, SGD) 利用 loss function 计算出的梯度,来迭代更新模型参数,朝着 loss function 最小化的方向前进。
不同的任务和训练方法需要设计不同的 loss function。 对于LLM的 SFT 和 RL,loss function 的设计更是至关重要,它直接决定了模型最终的学习目标和行为模式。
2. 监督微调 (SFT):模仿学习,以“标准答案”为目标
2.1 SFT 的核心思想:模仿人类示范
SFT 的核心思想非常直观:模仿学习 (Imitation Learning)。 我们提供高质量的 人工标注数据,让模型学习如何像人类一样完成特定的任务。 这些数据通常包含输入 (input) 和 对应的理想输出 (target output),例如:
- 指令遵循 (Instruction Following):
- Input: “请写一首关于秋天的诗歌。”
- Target Output: “秋风瑟瑟落叶飞,\n枫林尽染夕阳辉。\n稻浪金黄田野阔,\n雁字南归声声催。”
- 问答 (Question Answering):
- Input: “巴黎是哪个国家的首都?”
- Target Output: “法国。”
- 文本生成 (Text Generation):
- Input (prompt): “故事开头:在一个遥远的星系…”
- Target Output (续写): “…,一颗孤独的星球缓缓转动,上面居住着一群拥有奇特能力的生物。”
SFT 的目标就是让 LLM 学习生成 尽可能接近 target output 的文本。
2.2 SFT 的 Loss Function:交叉熵损失 (Cross-Entropy Loss)
SFT 最常用的 loss function 是 交叉熵损失 (Cross-Entropy Loss)。 对于文本生成任务,我们通常使用 token-level 的交叉熵损失。
原理:
- 假设我们的模型要预测序列中的下一个 token。
- 模型的输出是一个概率分布,表示模型预测下一个 token 为词汇表中每个词的概率。
- 交叉熵损失衡量的是 模型预测的概率分布 与 真实的 one-hot 分布 之间的差距。
- 真实的 one-hot 分布中,只有 正确的下一个 token 的概率为 1,其余 token 的概率为 0。
公式 (简化版):
假设词汇表大小为 V,模型预测下一个 token 的概率分布为 p = [ p 1 , p 2 , . . . , p V ] p = [p_1, p_2, ..., p_V] p=[p1,p2,...,pV],真实的下一个 token 的 one-hot 分布为 y = [ y 1 , y 2 , . . . , y V ] y = [y_1, y_2, ..., y_V] y=[y1,y2,...,yV] (其中 y i = 1 y_i = 1 yi=1 如果第 i 个 token 是正确的,否则 y i = 0 y_i = 0 yi=0)。 则交叉熵损失为:
L C E = − ∑ i = 1 V y i log ( p i ) L_{CE} = - \sum_{i=1}^{V} y_i \log(p_i) LCE=−∑i=1Vyilog(pi)
由于 y y y 是 one-hot 分布,只有正确的 token 对应的 y i y_i yi 为 1,其余为 0。因此,公式可以简化为:
L C E = − log ( p c o r r e c t _ t o k e n ) L_{CE} = - \log(p_{correct\_token}) LCE=−log(pcorrect_token)
直观理解:
交叉熵损失的目标是 最大化正确 token 的预测概率。 概率越高,损失越小;概率越低,损失越大。
代码示例 (PyTorch):
import torch
import torch.nn.functional as F
# 假设词汇表大小为 5
vocab_size = 5
# 模型预测的概率分布 (batch_size=1, sequence_length=1, vocab_size)
logits = torch.randn(1, 1, vocab_size) # 模拟模型输出的 logits,未经softmax
probs = F.softmax(logits, dim=-1) # 将 logits 转换为概率分布
# 真实的下一个 token 的索引 (batch_size=1, sequence_length=1)
target_token_index = torch.tensor([[2]]) # 假设正确的 token 索引为 2
# 计算交叉熵损失
loss = F.cross_entropy(logits.view(-1, vocab_size), target_token_index.view(-1))
print("模型预测概率分布:", probs)
print("目标token索引:", target_token_index)
print("交叉熵损失:", loss)
解释:
logits
代表模型输出的未经过 softmax 的值。F.softmax(logits, dim=-1)
将 logits 转换为概率分布。target_token_index
指定了正确的下一个 token 的索引。F.cross_entropy()
计算交叉熵损失。 PyTorch 的cross_entropy
函数会自动进行 softmax 计算,所以输入是 logits。
序列级别的交叉熵损失:
在实际 SFT 中,我们通常处理的是整个文本序列,而不是单个 token。 对于序列生成任务,我们会计算 每个 token 位置的交叉熵损失,然后 求平均 作为整个序列的损失。
假设我们有一个长度为 L 的目标序列 Y = [ y 1 , y 2 , . . . , y L ] Y = [y_1, y_2, ..., y_L] Y=[y1,y2,...,yL],模型预测的概率分布序列为 P = [ p 1 , p 2 , . . . , p L ] P = [p_1, p_2, ..., p_L] P=[p1,p2,...,pL]。 则序列级别的交叉熵损失为:
L S F T = 1 L ∑ t = 1 L L C E ( p t , y t ) L_{SFT} = \frac{1}{L} \sum_{t=1}^{L} L_{CE}(p_t, y_t) LSFT=L1∑t=1LLCE(pt,yt)
2.3 SFT 的优点和局限性
优点:
- 简单有效: SFT 的原理和实现都相对简单,易于理解和操作。
- 快速收敛: 由于有明确的“标准答案”作为监督信号,SFT 通常能够快速收敛,提升模型性能。
- 适用于多种任务: SFT 可以应用于各种文本生成任务,如指令遵循、问答、对话生成等。
局限性:
- 依赖高质量标注数据: SFT 的性能高度依赖于标注数据的质量和数量。 标注成本高昂,且容易引入人为偏见。
- 无法超越标注数据: SFT 模型主要学习模仿标注数据,很难超越人类专家的水平,也难以发现新的知识或创造性地解决问题。
- 可能过度拟合标注数据: 如果标注数据不够多样化,SFT 模型可能会过度拟合训练数据,导致泛化能力下降。
- 与人类价值观对齐困难: SFT 仅仅学习模仿人类的文本风格和知识,但难以学习人类的价值观、偏好和伦理道德,容易生成有害或不符合人类期望的内容。 例如,SFT 模型可能无法理解 “helpful, honest, and harmless” 的真正含义。
3. 强化学习 (RL):奖励驱动,以“最优行为”为目标
3.1 RL 的核心思想:通过试错学习最优策略
与 SFT 的模仿学习不同,强化学习 (Reinforcement Learning, RL) 更侧重于 通过与环境交互,学习能够最大化累积奖励的最优策略。 在 LLM 微调中,“环境” 可以理解为用户或者评判模型,“奖励” 则是由奖励模型 (Reward Model) 给出的,用来衡量模型生成文本的质量和对齐程度。
RL 的核心概念:
- Agent (智能体): LLM 模型本身,负责生成文本。
- Environment (环境): 用户、评判模型、或者模拟环境,负责与 agent 交互并给出反馈。
- Action (动作): Agent 生成的文本 (token 序列)。
- State (状态): 当前的环境状态,例如用户的输入、对话历史等 (在文本生成任务中,状态通常不显式定义,更多关注动作和奖励)。
- Reward (奖励): 环境对 agent 生成文本的反馈,用于衡量文本的质量和对齐程度。
- Policy (策略): Agent 的行为准则,即在给定状态下,如何选择 action (生成文本)。 LLM 模型的参数就代表了策略。
RL 的目标: 学习一个 policy (策略),使得 agent 在与环境交互的过程中,能够获得 最大的累积奖励。
3.2 RL 的 Loss Function:奖励最大化,策略梯度
RL 的 loss function 与 SFT 的交叉熵损失完全不同,它不再是直接模仿“标准答案”,而是 以最大化奖励为目标。 RL 中常用的 loss function 基于 策略梯度 (Policy Gradient) 方法。
策略梯度方法的核心思想:
- 直接优化策略 (模型参数),使其能够产生 更高奖励的 actions (文本)。
- 通过采样 (sample) 的方式,从当前策略中获取 actions 和对应的 rewards。
- 利用 rewards 来更新策略,增加高奖励 actions 的概率,降低低奖励 actions 的概率。
几种常见的 RL loss function 在 LLM 微调中的应用:
3.2.1 Policy Gradient (策略梯度) - REINFORCE
最基础的策略梯度算法是 REINFORCE。 其 loss function 可以表示为:
L P G = − E π θ [ R ( τ ) log π θ ( τ ) ] L_{PG} = - \mathbb{E}_{\pi_{\theta}} [R(\tau) \log \pi_{\theta}(\tau)] LPG=−Eπθ[R(τ)logπθ(τ)]
解释:
- π θ ( τ ) \pi_{\theta}(\tau) πθ(τ) 表示策略 π θ \pi_{\theta} πθ 生成轨迹 τ \tau τ (token 序列) 的概率, θ \theta θ 是模型参数。
- R ( τ ) R(\tau) R(τ) 表示轨迹 τ \tau τ 的累积奖励。
- E π θ [ ⋅ ] \mathbb{E}_{\pi_{\theta}} [\cdot] Eπθ[⋅] 表示在策略 π θ \pi_{\theta} πθ 下的期望。
直观理解:
- 如果一个轨迹 τ \tau τ 获得了较高的奖励 R ( τ ) R(\tau) R(τ),那么 − R ( τ ) log π θ ( τ ) -R(\tau) \log \pi_{\theta}(\tau) −R(τ)logπθ(τ) 就为负值,loss 减小,模型参数更新会 增加生成该轨迹 τ \tau τ 的概率。
- 反之,如果奖励较低,loss 增大,模型参数更新会 降低生成该轨迹 τ \tau τ 的概率。
代码示例 (简化版 - 概念演示):
import torch
import torch.nn.functional as F
import torch.optim as optim
# 假设模型是简单的线性模型
model = torch.nn.Linear(10, 5) # 输入维度 10, 词汇表大小 5
optimizer = optim.Adam(model.parameters(), lr=0.01)
def generate_text(model, input_tensor):
"""模拟文本生成过程 (简化版,实际LLM更复杂)"""
logits = model(input_tensor)
probs = F.softmax(logits, dim=-1)
action = torch.multinomial(probs, num_samples=1) # 从概率分布中采样 token
return action, probs
def calculate_reward(action):
"""模拟奖励函数 (根据 action 索引简单定义奖励)"""
if action.item() == 2: # 假设 token 索引 2 是 "好" 的 token
return 1.0
else:
return -0.1
# 模拟一个 episode (一次文本生成过程)
input_data = torch.randn(1, 10) # 模拟输入
actions = []
log_probs = []
rewards = []
for _ in range(10): # 生成 10 个 token
action, probs = generate_text(model, input_data)
reward = calculate_reward(action)
actions.append(action)
log_probs.append(torch.log(probs.squeeze()[action])) # 记录生成 action 的 log 概率
rewards.append(reward)
# 计算累积奖励 (这里简化为直接求和,实际中可能需要 discount factor)
cumulative_reward = sum(rewards)
# 计算 REINFORCE loss
policy_loss = 0
for log_prob, reward in zip(log_probs, rewards):
policy_loss += -log_prob * cumulative_reward # 注意这里使用 cumulative_reward
# 反向传播和参数更新
optimizer.zero_grad()
policy_loss.backward()
optimizer.step()
print("生成的 actions:", actions)
print("累积奖励:", cumulative_reward)
print("REINFORCE Loss:", policy_loss)
解释:
- 代码模拟了一个简单的文本生成过程和奖励计算。
generate_text
函数模拟模型生成 token 的过程。calculate_reward
函数模拟奖励函数。policy_loss
计算 REINFORCE loss,并进行反向传播和参数更新。- 关键点: loss 的计算依赖于 整个 episode 的累积奖励 和 生成 actions 的 log 概率。
REINFORCE 的局限性:
- 高方差: REINFORCE 算法的梯度估计方差很高,导致训练不稳定,收敛速度慢。
- 不适用于 online learning: 需要完成一个 episode 才能更新参数,不适合在线学习场景。
3.2.2 Proximal Policy Optimization (PPO) - 近端策略优化
为了解决 REINFORCE 的高方差和训练不稳定性问题,研究者提出了 Proximal Policy Optimization (PPO) 算法。 PPO 是目前 LLM RL 微调中最常用的算法之一。
PPO 的核心思想是 Trust Region Policy Optimization (TRPO) 的简化版本,旨在 限制策略更新的幅度,避免策略更新过大导致训练不稳定。
PPO 的 loss function 主要包含两部分: Clip Loss (裁剪损失) 和 Value Loss (值函数损失) (Value Loss 用于加速学习,本文主要关注 Policy Loss)。
PPO Clip Loss 公式 (简化版):
L P P O − C l i p = E τ ∼ π o l d [ min ( r a t i o ( θ ) ⋅ A π o l d ( τ ) , c l i p ( r a t i o ( θ ) , 1 − ϵ , 1 + ϵ ) ⋅ A π o l d ( τ ) ) ] L_{PPO-Clip} = \mathbb{E}_{\tau \sim \pi_{old}} [\min(ratio(\theta) \cdot A^{\pi_{old}}(\tau), clip(ratio(\theta), 1-\epsilon, 1+\epsilon) \cdot A^{\pi_{old}}(\tau))] LPPO−Clip=Eτ∼πold[min(ratio(θ)⋅Aπold(τ),clip(ratio(θ),1−ϵ,1+ϵ)⋅Aπold(τ))]
其中:
- π o l d \pi_{old} πold 是 旧策略 (old policy),即更新前的策略。
- π θ \pi_{\theta} πθ 是 新策略 (new policy),即当前要更新的策略。
- r a t i o ( θ ) = π θ ( τ ) π o l d ( τ ) ratio(\theta) = \frac{\pi_{\theta}(\tau)}{\pi_{old}(\tau)} ratio(θ)=πold(τ)πθ(τ) 是 新旧策略生成轨迹 τ \tau τ 概率的比值。
- A π o l d ( τ ) A^{\pi_{old}}(\tau) Aπold(τ) 是 优势函数 (Advantage Function),估计在状态 s 下采取动作 a 相对于平均水平的优势。 它可以简化理解为 当前轨迹 τ \tau τ 相对于平均轨迹的 “好坏程度” (通常使用 GAE 等方法计算)。
- ϵ \epsilon ϵ 是 裁剪参数 (clip parameter),通常设置为 0.1 或 0.2,用于限制 r a t i o ( θ ) ratio(\theta) ratio(θ) 的变化范围。
- c l i p ( r a t i o ( θ ) , 1 − ϵ , 1 + ϵ ) clip(ratio(\theta), 1-\epsilon, 1+\epsilon) clip(ratio(θ),1−ϵ,1+ϵ) 将 r a t i o ( θ ) ratio(\theta) ratio(θ) 裁剪到 [ 1 − ϵ , 1 + ϵ ] [1-\epsilon, 1+\epsilon] [1−ϵ,1+ϵ] 范围内。
直观理解:
- PPO Clip Loss 的目标仍然是 最大化优势函数 A π o l d ( τ ) A^{\pi_{old}}(\tau) Aπold(τ),即鼓励生成 “好” 的轨迹。
- 关键创新: 通过 clip 函数 限制了策略更新的幅度。
- 如果 r a t i o ( θ ) > 1 + ϵ ratio(\theta) > 1 + \epsilon ratio(θ)>1+ϵ 且 A π o l d ( τ ) > 0 A^{\pi_{old}}(\tau) > 0 Aπold(τ)>0 (即新策略比旧策略好,且优势为正),则 loss 会被 clip 到 ( 1 + ϵ ) ⋅ A π o l d ( τ ) (1+\epsilon) \cdot A^{\pi_{old}}(\tau) (1+ϵ)⋅Aπold(τ), 限制了策略提升幅度过大。
- 如果 r a t i o ( θ ) < 1 − ϵ ratio(\theta) < 1 - \epsilon ratio(θ)<1−ϵ 且 A π o l d ( τ ) < 0 A^{\pi_{old}}(\tau) < 0 Aπold(τ)<0 (即新策略比旧策略差,且优势为负),则 loss 会被 clip 到 ( 1 − ϵ ) ⋅ A π o l d ( τ ) (1-\epsilon) \cdot A^{\pi_{old}}(\tau) (1−ϵ)⋅Aπold(τ), 限制了策略下降幅度过大。
- 在 [ 1 − ϵ , 1 + ϵ ] [1-\epsilon, 1+\epsilon] [1−ϵ,1+ϵ] 范围内的策略更新不受裁剪限制。
代码示例 (简化版 - 概念演示):
import torch
import torch.nn.functional as F
import torch.optim as optim
# ... (模型和 optimizer 定义与 REINFORCE 示例相同) ...
def calculate_advantage(rewards):
"""简化优势函数计算 (实际中更复杂)"""
return sum(rewards) # 这里简化为直接求和
# 模拟一个 episode (一次文本生成过程)
# ... (生成 actions, log_probs, rewards 与 REINFORCE 示例相同) ...
# 计算优势函数
advantage = calculate_advantage(rewards)
# 获取旧策略生成 actions 的 log 概率 (假设已经提前保存)
old_log_probs = [torch.tensor([-0.5]), torch.tensor([-0.3]), ...] # 示例数据,实际需要从旧策略模型获取
# 计算 ratio
ratio = torch.exp(torch.tensor(log_probs) - torch.tensor(old_log_probs)) # 注意这里使用 log 概率的差值
# 计算 PPO Clip Loss
epsilon = 0.2
clip_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
ppo_clip_loss = -torch.min(ratio * advantage, clip_ratio * advantage).mean() # 取 min 并求均值
# 反向传播和参数更新
optimizer.zero_grad()
ppo_clip_loss.backward()
optimizer.step()
print("PPO Clip Loss:", ppo_clip_loss)
解释:
- 代码在 REINFORCE 示例的基础上,增加了 旧策略的 log 概率
old_log_probs
和 裁剪损失的计算ppo_clip_loss
。 calculate_advantage
函数简化了优势函数的计算。ratio
计算新旧策略概率比值。torch.clamp
实现裁剪函数。torch.min
和mean()
计算 PPO Clip Loss。
PPO 的优点:
- 训练稳定: 通过 clip 函数限制策略更新幅度,提高了训练稳定性。
- 采样效率高: 可以使用 off-policy 数据进行训练,提高了采样效率。
- 广泛应用: PPO 是目前最流行的 RL 算法之一,在各种 RL 任务中都取得了成功,包括 LLM 微调。
3.2.3 Direct Preference Optimization (DPO) - 直接偏好优化
PPO 虽然有效,但训练过程相对复杂,需要 reward model、value function 等组件,且超参数较多,调参困难。 Direct Preference Optimization (DPO) 算法旨在 简化 RL 训练流程,直接从人类偏好数据中学习策略,而 无需显式训练 reward model。
DPO 的核心思想是 将 reward model 隐式地融入到 loss function 中。 它利用 Bradley-Terry 模型 来建模人类偏好,并推导出可以直接优化的 loss function。
DPO Loss Function 公式 (简化版):
L D P O = − E ( x , y w , y l ) ∼ D [ log σ ( β ( log π θ ( y w ∣ x ) − log π θ ( y l ∣ x ) ) − log π r e f ( y w ∣ x ) π r e f ( y l ∣ x ) ) ] L_{DPO} = - \mathbb{E}_{(x, y_w, y_l) \sim D} [\log \sigma (\beta (\log \pi_{\theta}(y_w|x) - \log \pi_{\theta}(y_l|x)) - \log \frac{\pi_{ref}(y_w|x)}{\pi_{ref}(y_l|x)})] LDPO=−E(x,yw,yl)∼D[logσ(β(logπθ(yw∣x)−logπθ(yl∣x))−logπref(yl∣x)πref(yw∣x))]
其中:
- ( x , y w , y l ) (x, y_w, y_l) (x,yw,yl) 是 人类偏好数据,表示对于输入 x x x,人类更偏好 y w y_w yw (win) 而不是 y l y_l yl (lose)。
- D D D 是偏好数据集。
- π θ \pi_{\theta} πθ 是当前要训练的策略 (LLM 模型)。
- π r e f \pi_{ref} πref 是 参考策略 (reference policy),通常是 SFT 模型,作为训练的基准。
- β \beta β 是 温度参数,控制偏好的强度。
- σ ( z ) = 1 1 + e − z \sigma(z) = \frac{1}{1 + e^{-z}} σ(z)=1+e−z1 是 sigmoid 函数。
直观理解:
- DPO Loss 的目标是 最大化人类偏好数据集中 “win” 样本的概率,同时最小化 “lose” 样本的概率。
- 关键创新: 直接比较 win 和 lose 样本的 log 概率差值,并利用 sigmoid 函数将其转换为概率值。 隐式地学习了 reward model 的效果。
- log π r e f ( y w ∣ x ) π r e f ( y l ∣ x ) \log \frac{\pi_{ref}(y_w|x)}{\pi_{ref}(y_l|x)} logπref(yl∣x)πref(yw∣x) 是 参考策略的 log 概率比值,作为 正则化项,防止策略偏离参考策略太远,保持模型的生成质量。
代码示例 (简化版 - 概念演示):
import torch
import torch.nn.functional as F
import torch.optim as optim
# ... (模型和 optimizer 定义与 REINFORCE 示例相同) ...
# 模拟偏好数据 (输入 x, win 样本 y_w, lose 样本 y_l)
input_x = torch.randn(1, 10)
win_sample = torch.randint(0, 5, (1, 5)) # 假设 win 样本是长度为 5 的 token 序列
lose_sample = torch.randint(0, 5, (1, 5)) # 假设 lose 样本是长度为 5 的 token 序列
# 模拟参考策略 (这里简化为另一个模型,实际中可能是 SFT 模型)
ref_model = torch.nn.Linear(10, 5)
def calculate_log_prob(model, input_tensor, sample):
"""计算模型生成 sample 的 log 概率 (简化版)"""
logits = model(input_tensor)
log_probs = F.log_softmax(logits, dim=-1)
sample_log_prob = 0
for i in range(sample.size(1)): # 遍历 sample 中的每个 token
sample_log_prob += log_probs[0, i, sample[0, i]] # 累加每个 token 的 log 概率
return sample_log_prob
# 计算当前策略和参考策略生成 win 和 lose 样本的 log 概率
pi_theta_log_prob_win = calculate_log_prob(model, input_x, win_sample)
pi_theta_log_prob_lose = calculate_log_prob(model, input_x, lose_sample)
pi_ref_log_prob_win = calculate_log_prob(ref_model, input_x, win_sample)
pi_ref_log_prob_lose = calculate_log_prob(ref_model, input_x, lose_sample)
# 计算 DPO Loss
beta = 0.1 # 温度参数
dpo_loss = -torch.log(torch.sigmoid(beta * (pi_theta_log_prob_win - pi_theta_log_prob_lose) - (pi_ref_log_prob_win - pi_ref_log_prob_lose)))
# 反向传播和参数更新
optimizer.zero_grad()
dpo_loss.backward()
optimizer.step()
print("DPO Loss:", dpo_loss)
解释:
- 代码模拟了偏好数据和参考策略,并计算 DPO Loss。
calculate_log_prob
函数简化了计算模型生成样本 log 概率的过程。dpo_loss
计算 DPO Loss。
DPO 的优点:
- 训练流程更简洁: 无需显式训练 reward model,简化了训练流程。
- 更稳定高效: 直接优化策略,避免了 reward model 训练带来的不稳定性和误差累积。
- 与人类偏好更对齐: 直接从人类偏好数据中学习,更容易与人类价值观和偏好对齐。
3.3 RL 的优点和挑战
优点:
- 更好地与人类价值观对齐: 通过 reward model 或偏好数据引导模型学习人类期望的行为,更容易实现与人类价值观的对齐。
- 超越 SFT 的能力: RL 可以探索更广阔的策略空间,发现 SFT 无法学习到的更优策略,例如更具创造性、更符合长期目标的行为。
- 更强的泛化能力: RL 模型不仅学习模仿标注数据,更学习如何根据奖励信号进行决策,可能具有更强的泛化能力。
挑战:
- 训练复杂且不稳定: RL 训练过程比 SFT 更复杂,需要 reward model、环境交互、策略梯度等组件,且容易出现训练不稳定、收敛困难等问题。
- Reward hacking (奖励作弊): 模型可能会学习到一些 “作弊” 策略,即最大化奖励,但却不符合人类的真实意图。 例如,为了获得高奖励,模型可能会生成冗长、重复、但表面上看起来 “helpful” 的文本。
- Reward model 偏差: reward model 的质量直接影响 RL 训练效果。 如果 reward model 存在偏差或噪声,会导致模型学习到错误的行为。
- 计算成本高昂: RL 训练通常需要大量的环境交互和样本采样,计算成本比 SFT 更高。
4. SFT 和 RL 的联系与区别
4.1 联系:相辅相成,共同提升LLM性能
SFT 和 RL 并非完全对立的技术,它们之间存在着密切的联系,可以相辅相成,共同提升 LLM 的性能。
- SFT 作为 RL 的预训练阶段: 通常情况下,我们会先使用 SFT 对 LLM 进行预训练,使其具备基本的文本生成能力和指令遵循能力。 SFT 模型可以作为 RL 的 初始化模型 (initial policy) 和 参考策略 (reference policy)。
- SFT 数据用于 reward model 训练: SFT 数据中高质量的人工标注输出可以作为 reward model 训练的 正样本 (positive examples),用来训练 reward model 识别 “好” 的文本。
- RL 提升 SFT 模型的对齐能力: SFT 模型虽然能够模仿人类文本风格,但难以学习人类价值观和偏好。 RL 可以基于 reward model 或偏好数据,进一步 微调 SFT 模型,使其更好地与人类价值观对齐,生成更 “helpful, honest, and harmless” 的文本。
总结: SFT 为 LLM 提供了 基础能力,RL 则在 SFT 的基础上 提升了模型的对齐性和高级能力。 两者通常结合使用,共同构建强大的 LLM。
4.2 区别:Loss Function 反映本质差异
SFT 和 RL 最核心的区别在于它们的 loss function 设计,这反映了它们本质上的学习目标和方法差异。
特征 | 监督微调 (SFT) | 强化学习 (RL) |
---|---|---|
学习目标 | 模仿人类示范,学习生成与标注数据相似的文本 | 最大化累积奖励,学习生成符合人类期望的最优行为 |
数据来源 | 人工标注数据 (输入-输出对) | 奖励信号 (reward model 或人类偏好数据) |
Loss Function | 交叉熵损失 (Cross-Entropy Loss) | 策略梯度损失 (Policy Gradient Loss, PPO Clip Loss, DPO Loss 等) |
监督信号 | 明确的“标准答案” (target output) | 间接的“奖励信号” (reward) |