【复现DeepSeek-R1之Open R1实战】系列6:GRPO源码逐行深度解析(上)
目录
- 4 GRPO源码分析
- 4.1 数据类 `GRPOScriptArguments`
- 4.2 系统提示字符串 `SYSTEM_PROMPT`
- 4.3 奖励函数
- 4.3.1 accuracy_reward函数
- 4.3.2 verify函数
- 4.3.3 format_reward函数
- 4.4 将数据集格式化为对话形式
- 4.5 初始化GRPO Trainer
【复现DeepSeek-R1之Open R1实战】系列3:SFT和GRPO源码逐行深度解析(上)
【复现DeepSeek-R1之Open R1实战】系列5:SFT和GRPO源码逐行深度解析(中)
4 GRPO源码分析
前面两篇博文已经详细介绍了一些基础知识和SFT源码,本文继续解读GRPO源码。与SFT源码差不多的部分,我们就不展开细说了,这里只解析GRPO独特的部分。
4.1 数据类 GRPOScriptArguments
该类使用了 Python 的 dataclass
装饰器,这是一种简化类定义的方式,特别是对于那些主要用来存储数据的类。它继承自 ScriptArguments
类。
-
reward_funcs: 这是一个列表,包含了一系列可能的奖励函数名称,默认值为
["accuracy", "format"]
。这些奖励函数可能是用于评估模型性能的不同标准。reward_funcs: list[str] = field( default_factory=lambda: ["accuracy", "format"], metadata={ "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'" }, )
-
cosine_min_value_wrong 和 cosine_max_value_wrong: 分别表示错误答案在余弦相似度尺度上的最小和最大奖励值,默认分别为
0.0
和-0.5
。 -
cosine_min_value_correct 和 cosine_max_value_correct: 分别表示正确答案在余弦相似度尺度上的最小和最大奖励值,默认分别为
0.5
和1.0
。 -
cosine_max_len: 表示余弦相似度尺度的最大长度,默认值为
1000
。 -
repetition_n_grams: 表示用于重复惩罚奖励的n-gram数量,默认值为
3
。 -
repetition_max_penalty: 表示重复惩罚奖励的最大负值,默认值为
-1.0
。
每个字段都使用了 field()
函数来定义其默认值和元数据(如帮助信息)。这有助于工具和库更好地理解和处理这些字段,例如生成命令行解析器时。
4.2 系统提示字符串 SYSTEM_PROMPT
SYSTEM_PROMPT = (
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
"<think> reasoning process here </think><answer> answer here </answer>"
)
字符串描述了一个对话场景,用户先提问,助手首先思考推理过程,然后提供答案。推理过程和答案分别用 <think>
和 <answer>
标签包裹,这种格式化有助于区分和识别不同的部分,和DeepSeek-R1的思考过程格式一致。
4.3 奖励函数
奖励函数的定义如下,GRPO默认用到了accuracy_reward和format_reward这两个函数。
# Get reward functions
REWARD_FUNCS_REGISTRY = {
"accuracy": accuracy_reward,
"format": format_reward,
"reasoning_steps": reasoning_steps_reward,
"cosine": get_cosine_scaled_reward(
min_value_wrong=script_args.cosine_min_value_wrong,
max_value_wrong=script_args.cosine_max_value_wrong,
min_value_correct=script_args.cosine_min_value_correct,
max_value_correct=script_args.cosine_max_value_correct,
max_len=script_args.cosine_max_len,
),
"repetition_penalty": get_repetition_penalty_reward(
ngram_size=script_args.repetition_n_grams,
max_penalty=script_args.repetition_max_penalty,
),
"length": len_reward,
}
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
这段代码定义了一个奖励函数注册表 REWARD_FUNCS_REGISTRY
,并根据用户提供的配置动态生成一个奖励函数列表 reward_funcs
。每个奖励函数用于评估模型输出的不同方面,如准确性、格式、推理步骤等。
- 注册表定义
accuracy
: 使用accuracy_reward
函数评估模型输出的准确性。format
: 使用format_reward
函数评估模型输出的格式。reasoning_steps
: 使用reasoning_steps_reward
函数评估模型输出的推理步骤。cosine
: 使用get_cosine_scaled_reward
函数计算余弦相似度奖励,参数包括:min_value_wrong
: 错误情况下的最小值。max_value_wrong
: 错误情况下的最大值。min_value_correct
: 正确情况下的最小值。max_value_correct
: 正确情况下的最大值。max_len
: 最大长度。
repetition_penalty
: 使用get_repetition_penalty_reward
函数计算重复惩罚奖励,参数包括:ngram_size
: n-gram 的大小。max_penalty
: 最大惩罚值。
length
: 使用len_reward
函数评估模型输出的长度。
- 动态生成奖励函数列表
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
- 根据
script_args.reward_funcs
中指定的奖励函数名称,从REWARD_FUNCS_REGISTRY
中获取相应的奖励函数,并生成一个列表reward_funcs
。
4.3.1 accuracy_reward函数
该函数用于计算模型生成的补全与真实答案之间的准确性奖励。它通过解析和验证生成的内容与真实答案来确定奖励值。
def accuracy_reward(completions, solution, **kwargs):
"""Reward function that checks if the completion is the same as the ground truth."""
contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, sol in zip(contents, solution):
gold_parsed = parse(
sol,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed="all",
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
# Reward 1 if the content is the same as the ground truth, 0 otherwise
reward = float(verify(answer_parsed, gold_parsed))
else:
# If the gold solution is not parseable, we reward 1 to skip this example
reward = 1.0
print("Failed to parse gold solution: ", sol)
rewards.append(reward)
return rewards
- completions (
list
): 包含多个补全结果的列表,每个补全结果是一个包含内容的字典列表。 - solution (
list
): 真实答案的列表。 - kwargs: 其他可选参数(在本函数中未使用)。
-
提取补全内容
contents = [completion[0]["content"] for completion in completions]
- 从
completions
列表中提取每个补全的第一个内容(假设每个补全是单个元素的列表),形成一个新的contents
列表。
- 从
-
初始化奖励列表
rewards = []
-
遍历每个补全和对应的真实答案
for content, sol in zip(contents, solution): gold_parsed = parse( sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()], )
- 使用
zip
函数将contents
和solution
配对。 - 对于每一对补全内容和真实答案,首先解析真实答案
sol
,使用parse
函数提取其中的信息。
- 使用
-
处理解析结果
if len(gold_parsed) != 0: answer_parsed = parse( content, extraction_config=[ LatexExtractionConfig( normalization_config=NormalizationConfig( nits=False, malformed_operators=False, basic_latex=True, equations=True, boxed="all", units=True, ), # Ensures that boxed is tried first boxed_match_priority=0, try_extract_without_anchor=False, ) ], extraction_mode="first_match", )
- 如果解析得到的真实答案
gold_parsed
非空,则继续解析生成的补全内容content
。 - 使用
LatexExtractionConfig
和NormalizationConfig
进行详细配置,确保解析过程中考虑了各种格式要求(如方程、单位等)。
- 如果解析得到的真实答案
-
计算奖励
reward = float(verify(answer_parsed, gold_parsed))
- 使用
verify
函数比较生成的补全解析结果和真实答案的解析结果。 - 如果两者匹配,则返回
1.0
,否则返回0.0
。
- 使用
-
处理无法解析的情况
else: reward = 1.0 print("Failed to parse gold solution: ", sol)
- 如果真实答案无法解析,则默认给予奖励
1.0
并打印警告信息。
- 如果真实答案无法解析,则默认给予奖励
-
添加奖励到列表
rewards.append(reward)
-
返回所有奖励
return rewards
4.3.2 verify函数
该函数用于验证目标表达式是否与参考表达式匹配,它通过多种比较策略来处理不同的数学对象(如数字、表达式、集合、矩阵等),并提供灵活的配置选项以适应不同的需求。
def verify(
gold: list[Basic | MatrixBase | str] | Basic | MatrixBase | str,
target: list[Basic | MatrixBase | str] | Basic | MatrixBase | str,
float_rounding: int=6,
numeric_precision: int=15,
strict: bool=True,
timeout_seconds: int=3
) -> bool:
- gold: 参考或正确的表达式,可以是单个 SymPy 表达式(
Basic
或MatrixBase
)、字符串或这些类型的列表。 - target: 需要验证的表达式,类型同
gold
。 - float_rounding: 浮点数舍入的小数位数,默认为 6。
- numeric_precision: 数值比较时考虑的小数位数,默认为 15。
- strict: 是否启用严格比较模式,默认为
True
。- 在严格模式下:变量很重要,集合不可与元组比较。
- 在非严格模式下:变量按位置匹配,集合可与元组比较。
- timeout_seconds: 单次比较操作的最大超时时间(秒),默认为 3 秒。
-
定义内部比较函数
compare_single_extraction
@timeout(timeout_seconds=timeout_seconds) def compare_single_extraction(gold: Basic | MatrixBase | str, target: Basic | MatrixBase | str) -> bool: ...
- 使用装饰器
@timeout
设置超时保护,默认超时时间为timeout_seconds
。 - 比较两个表达式:
- 如果两者都是 SymPy 表达式(
Basic
或MatrixBase
),则调用sympy_expr_eq
进行比较。 - 如果两者都是字符串,则进行简单的字符串比较。
- 如果两者都是 SymPy 表达式(
- 使用装饰器
-
定义包装函数
compare_single_extraction_wrapper
def compare_single_extraction_wrapper(g, t): try: return compare_single_extraction(g, t) except Exception as e: logger.exception(f"Error comparing {g} and {t}") return False
- 包装
compare_single_extraction
,捕获并记录任何异常,返回False
以避免程序中断。
- 包装
-
处理输入列表
if not isinstance(gold, list): gold = [gold] if not isinstance(target, list): target = [target]
- 如果
gold
或target
不是列表,则将其转换为单元素列表,以便统一处理。
- 如果
-
组合所有可能的比较
return any(compare_single_extraction_wrapper(g, t) for g, t in product(gold, target))
- 使用
itertools.product
生成所有可能的gold
和target
组合。 - 对每个组合调用
compare_single_extraction_wrapper
,如果任意一对匹配成功,则返回True
。
- 使用
4.3.3 format_reward函数
函数用于检查给定的完成文本是否符合特定的格式,它验证完成文本是否包含 <think>
和 <answer>
标签,并且这两个标签的内容是非空的。
def format_reward(completions, **kwargs):
"""Reward function that checks if the completion has a specific format."""
pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
return [1.0 if match else 0.0 for match in matches]
- completions: 这是一个列表,其中每个元素都是一个包含完成内容的对象(通常是字典)。假设每个完成对象的第一个元素包含一个键
"content"
,其值是需要检查的文本。 - kwargs: 其他关键字参数,这里没有使用,但可以为未来的扩展提供灵活性。
-
正则表达式模式定义:
pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
- 这个正则表达式用于匹配字符串是否以
<think>
开始,紧接着是任意字符(非贪婪匹配),然后是</think>
,接着可能有任意数量的空白字符(包括换行符),最后是以<answer>
开始并以</answer>
结束。 .*?
是非贪婪匹配,确保尽可能少地匹配字符。\s*
匹配零个或多个空白字符(包括换行符)。re.DOTALL | re.MULTILINE
标志允许点号.
匹配所有字符(包括换行符),并且使多行文本中的每一行都可以独立匹配。
- 这个正则表达式用于匹配字符串是否以
-
提取完成内容:
completion_contents = [completion[0]["content"] for completion in completions]
- 这里通过列表推导式从
completions
列表中提取每个完成对象的第一个元素的"content"
字段,形成一个新的列表completion_contents
。
- 这里通过列表推导式从
-
匹配正则表达式:
matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
- 使用
re.match
函数对completion_contents
中的每个内容应用正则表达式模式。 matches
列表将包含re.Match
对象(如果匹配成功)或None
(如果匹配失败)。
- 使用
-
生成奖励分数:
return [1.0 if match else 0.0 for match in matches]
- 最后一步是根据匹配结果生成奖励分数。如果匹配成功(即
match
不是None
),则返回1.0
;否则返回0.0
。
- 最后一步是根据匹配结果生成奖励分数。如果匹配成功(即
示例代码:
completions = [
[{"content": "<think>This is reasoning.</think><answer>This is answer.</answer>"}],
[{"content": "<think>This is reasoning.</think>"}],
[{"content": "<answer>This is answer.</answer>"}],
[{"content": "This does not match."}]
]
reward_scores = format_reward(completions)
print(reward_scores) # 输出: [1.0, 0.0, 0.0, 0.0]
在这个例子中:
- 第一个完成内容完全匹配正则表达式,因此得分为
1.0
。 - 后三个完成内容不符合要求,因此得分均为
0.0
。
4.4 将数据集格式化为对话形式
# Format into conversation
def make_conversation(example):
return {
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": example["problem"]},
],
}
dataset = dataset.map(make_conversation)
for split in dataset:
if "messages" in dataset[split].column_names:
dataset[split] = dataset[split].remove_columns("messages")
将一个数据集中的每个示例转换为对话格式,并确保数据集中没有多余的列(如 messages
)。
- 输入:
example
是一个字典,包含单个数据样本的信息,其中"problem"
键对应的值是用户的问题或任务描述。 - 输出:返回一个新的字典,包含一个
"prompt"
键,其值是一个对话列表:- 第一条消息是系统消息,内容由
SYSTEM_PROMPT
定义。 - 第二条消息是用户消息,内容是
example["problem"]
。
- 第一条消息是系统消息,内容由
dataset.map(make_conversation)
:使用map
方法将make_conversation
函数应用到数据集的每个示例上,生成新的对话格式。- 移除多余列:遍历数据集的每个拆分(split),如果存在
"messages"
列,则将其移除。
4.5 初始化GRPO Trainer
trainer = GRPOTrainer(
model=model_args.model_name_or_path,
reward_funcs=reward_funcs,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
peft_config=get_peft_config(model_args),
callbacks=get_callbacks(training_args, model_args),
)
篇幅有限,训练部分的代码我们放到下一篇博文详细解读!