【LLM强化学习】LLM 强化学习中 Critic 模型训练详解
在 基于人类反馈的强化学习 (Reinforcement Learning from Human Feedback, RLHF) 中,Critic 模型 (Critic Model) 扮演着至关重要的角色。它就像一位严苛的评分老师,负责评估 策略模型 (Policy Model) 生成的 responses 的质量,并提供 价值估计 (Value Estimation),指导策略模型向着更符合人类偏好的方向优化。
简单来说,Critic 模型的目标是学习一个价值函数 (Value Function),这个价值函数能够 预测在给定状态 (State) 下,采取某个行动 (Action) 所能获得的未来累积奖励 (Cumulative Reward) 的期望值。在 LLM 强化学习的语境下:
- 状态 (State): 通常是输入给 LLM 的 prompt。
- 行动 (Action): LLM 生成的 response。
- 价值 (Value): response 的 质量评分 或 人类偏好程度 的预测值。
本文将深入剖析 Critic 模型的训练细节,包括 输入输出设计、数据准备、模型结构选择、损失函数以及训练技巧,帮助你全面理解 Critic 模型的训练过程,并为构建高效的 RLHF 系统打下坚实基础。
1. Critic 模型的输入与输出:Prompt + Response = Value
Critic 模型的核心任务是 评估 (Evaluate)。它需要接收 输入 (Input),并输出一个 标量值 (Scalar Value),作为对输入质量的评估。
1.1 输入 (Input):Prompt 与 Response 的结合
为了准确评估 response 的价值,Critic 模型需要同时考虑 输入 Prompt 和 生成的 Response。仅仅知道 response 内容是不够的,因为 同一个 response,在不同的 prompt 下,价值可能大相径庭。
例如,对于 prompt “你好”,response “你好” 是一个合理的回复,价值较高;但对于 prompt “请写一篇关于量子力学的科普文章”,response “你好” 就显得非常离谱,价值极低。
因此,Critic 模型的 输入 (Input) 通常是 Prompt 和 Response 的拼接 (Concatenation)。常见的拼接方式包括:
- [Prompt, Response]: 直接将 Prompt 和 Response 拼接在一起。
- [Prompt, Separator Token, Response]: 在 Prompt 和 Response 之间插入一个特殊的分隔符 (Separator Token),例如
[SEP]
,以帮助模型区分 Prompt 和 Response 的边界。
def prepare_critic_input(prompt, response, tokenizer, separator_token="[SEP]"):
"""
准备 Critic 模型的输入
Args:
prompt (str): 输入 prompt
response (str): 模型生成的 response
tokenizer (transformers.Tokenizer): 分词器
separator_token (str): 分隔符 token
Returns:
torch.Tensor: Critic 模型输入 (token ids)
"""
input_text = prompt + separator_token + response # 拼接 prompt 和 response
input_ids = tokenizer.encode(input_text, truncation=True, max_length=512) # 分词和截断
input_tensor = torch.tensor(input_ids).unsqueeze(0) # 转换为 tensor 并添加 batch 维度
return input_tensor
# 示例
prompt = "请写一首关于春天的诗歌"
response = "春风拂面,杨柳依依。"
tokenizer = load_tokenizer() # 加载分词器
critic_input = prepare_critic_input(prompt, response, tokenizer)
print(critic_input)
1.2 输出 (Output):标量价值评估值
Critic 模型的 输出 (Output) 是一个 标量值 (Scalar Value),代表模型对输入 Prompt 和 Response 组合的 价值评估。这个标量值通常被解释为 未来累积奖励的期望值,或者 response 的质量评分,或者 人类偏好程度 的预测值。
输出值的范围和含义取决于奖励函数的定义和训练数据的标注方式。 常见的输出值范围和含义包括:
- [0, 1] 范围的概率值: 表示 response 符合人类偏好的概率,例如 1 表示完全符合,0 表示完全不符合。
- [-1, 1] 或 [-inf, inf] 范围的奖励值: 表示 response 的绝对质量评分,数值越大表示质量越高。
- 相对排序得分: 在偏好数据中,Critic 模型可能输出一个相对排序得分,用于比较不同 responses 之间的优劣。
def critic_model(input_tensor):
"""
Critic 模型前向传播 (简化示例)
Args:
input_tensor (torch.Tensor): Critic 模型输入 (token ids)
Returns:
torch.Tensor: 标量价值评估值
"""
# ... 模型内部结构 ...
value = linear_layer(last_hidden_state) # 线性层输出标量值
return value
# 示例
critic = load_critic_model() # 加载 Critic 模型
value_prediction = critic_model(critic_input)
print(value_prediction) # 输出标量价值评估值
2. Critic 模型的数据准备:偏好数据,胜负分明
Critic 模型的训练数据与策略模型不同,它需要 偏好数据 (Preference Data),即 人类对不同 responses 优劣的排序标注。偏好数据通常以 pairwise (配对) 的形式出现,例如:
- Prompt: “请写一首关于夏天的诗歌”
- Response A: “夏日炎炎,汗滴禾下土。” (更好)
- Response B: “夏天真热啊。” (更差)
人类标注者需要 比较 Response A 和 Response B,并指出哪个 response 更符合人类偏好。
2.1 数据收集方式
-
人工标注 (Human Annotation): 雇佣人工标注者,对模型生成的 responses 进行 pairwise 比较和排序标注。这是最直接、最可靠的偏好数据来源,但成本较高,效率较低。
-
模型辅助标注 (Model-Assisted Annotation): 使用一个 初始奖励模型 (Initial Reward Model) 或 规则系统 (Rule-Based System) 对 responses 进行初步筛选和排序,然后由人工标注者对筛选后的数据进行精细标注。这种方式可以降低人工标注成本,提高效率。
-
合成数据 (Synthetic Data): 在某些情况下,可以使用 规则 或 模拟器 生成合成的偏好数据。例如,在代码生成任务中,可以使用 代码编译器 和 测试用例 来判断代码的正确性和效率,从而生成偏好数据。
2.2 数据格式
偏好数据通常以 JSON 或 CSV 等格式存储,每个数据样本包含以下信息:
- Prompt: 输入 prompt。
- Response A (更好): 人类偏好较好的 response。
- Response B (更差): 人类偏好较差的 response。
[
{
"prompt": "请写一首关于夏天的诗歌",
"response_A": "夏日炎炎,汗滴禾下土。",
"response_B": "夏天真热啊。"
},
{
"prompt": "...",
"response_A": "...",
"response_B": "..."
},
...
]
3. Critic 模型的结构选择:Transformer 的变体
Critic 模型本质上是一个 回归模型 (Regression Model),其输入是文本序列 (Prompt + Response),输出是标量价值评估值。因此,Transformer 架构及其变体 是 Critic 模型的常用选择。
常见的 Critic 模型结构包括:
-
Transformer Encoder: 使用 Transformer Encoder 编码输入文本序列,然后接一个 线性层 (Linear Layer) 输出标量值。这种结构适用于处理 单向文本序列,例如评估单个 response 的质量。
-
Transformer Decoder (因果语言模型): 可以使用预训练的 因果语言模型 (Causal Language Model),例如 GPT 系列模型,作为 Critic 模型的基础结构。将输入文本序列输入模型,取 最后一个 token 的 hidden state,然后接一个 线性层 输出标量值。这种结构可以 复用预训练模型的知识,提高训练效率和性能。DeepSeek-V3 的 Reward Model 采用了这种结构。
-
双塔模型 (Dual-Tower Model): 使用 两个独立的 Transformer Encoder 分别编码 Prompt 和 Response,然后将两个 Encoder 的输出 融合 (Fusion),例如拼接或求和,再接一个 线性层 输出标量值。这种结构可以 更细致地建模 Prompt 和 Response 之间的交互关系。
class CriticModel(nn.Module): # 使用 Transformer Decoder (因果语言模型) 作为 Critic 模型
def __init__(self, pretrained_model_name):
super().__init__()
self.transformer = AutoModelForCausalLM.from_pretrained(pretrained_model_name) # 加载预训练因果语言模型
self.value_head = nn.Linear(self.transformer.config.hidden_size, 1) # 线性层输出标量值
def forward(self, input_tensor):
outputs = self.transformer(input_tensor, output_hidden_states=True) # 获取 hidden states
last_hidden_state = outputs.hidden_states[-1][:, -1, :] # 取最后一个 token 的 hidden state
value = self.value_head(last_hidden_state) # 线性层输出标量值
return value
# 示例
critic = CriticModel("gpt2") # 使用 GPT-2 作为 Critic 模型
value = critic(critic_input)
4. Critic 模型的损失函数:Pairwise Ranking Loss
由于 Critic 模型训练数据是 偏好数据 (Pairwise Preference Data),因此需要使用 Pairwise Ranking Loss (配对排序损失) 来训练 Critic 模型,使其能够学习到 responses 之间的相对排序关系。
4.1 Pairwise Ranking Loss 的原理
Pairwise Ranking Loss 的核心思想是:对于每个偏好数据样本 (Prompt, Response A, Response B),我们希望 Critic 模型给 Response A 的评分高于 Response B 的评分。
常见的 Pairwise Ranking Loss 函数包括:
-
Margin Ranking Loss: 最大化更好 response (Response A) 和更差 response (Response B) 评分之间的 margin (间隔)。
L_margin = max(0, 1 - critic(prompt, response_A) + critic(prompt, response_B))
-
Binary Cross-Entropy Loss (BCE Loss): 将排序问题转化为 二分类问题。对于每个偏好样本,将 Response A 视为正例,Response B 视为负例,训练 Critic 模型 预测 Response A 为正例的概率高于 Response B 为正例的概率。
L_bce = -log(sigmoid(critic(prompt, response_A) - critic(prompt, response_B)))
DeepSeek-V3 的 Reward Model 采用了 BCE Loss。
4.2 BCE Loss 的实现
def calculate_critic_loss_bce(critic_output_A, critic_output_B):
"""
计算 Critic 模型的 BCE Loss
Args:
critic_output_A (torch.Tensor): Critic 模型对 Response A 的输出 (标量值)
critic_output_B (torch.Tensor): Critic 模型对 Response B 的输出 (标量值)
Returns:
torch.Tensor: BCE Loss
"""
preference_logits = critic_output_A - critic_output_B # 计算偏好 logits
loss = -torch.log(torch.sigmoid(preference_logits)) # BCE Loss
return loss
# 示例
critic_output_A = critic_model(critic_input_A) # Critic 模型对 Response A 的输出
critic_output_B = critic_model(critic_input_B) # Critic 模型对 Response B 的输出
critic_loss = calculate_critic_loss_bce(critic_output_A, critic_output_B)
print(critic_loss)
5. Critic 模型的训练技巧
- 预训练模型初始化: 使用预训练的语言模型 (例如 GPT-2, BERT) 初始化 Critic 模型,可以加速训练,提高性能。
- 学习率调度: 采用合适的学习率调度策略 (例如 Cosine Annealing, Linear Decay),可以提高训练稳定性和收敛速度。
- 正则化: 添加正则化项 (例如 Weight Decay, Dropout),可以防止过拟合,提高模型的泛化能力。
- 数据增强: 对训练数据进行增强 (例如 Back Translation, Synonym Replacement),可以增加数据的多样性,提高模型的鲁棒性。
- Early Stopping: 监控验证集上的性能指标 (例如 Pairwise Accuracy, Ranking Loss),当性能不再提升时,提前停止训练,防止过拟合。