当前位置: 首页 > article >正文

【LLM系列6】DPO 训练

DPO(Direct Preference Optimization,直接偏好优化)提供了一种简化的训练方法,通过直接从人类偏好数据优化语言模型,而无需显式训练奖励模型(Reward Model)。其核心优势在于提升了训练效率和稳定性,避免了传统强化学习中间步骤的复杂性。

与RLHF的区别

  • RLHF:通过一个两步流程来进行训练,首先是训练奖励模型,然后用强化学习算法(如PPO)优化策略。这要求设计一个奖励模型来对策略的输出进行评分,进而指导模型优化。
  • DPO:不需要奖励模型,直接通过人类反馈的偏好数据来优化策略。策略本身就被看作是隐式奖励函数,因此训练过程是一步到位的,减少了流程的复杂性。

理论基础

DPO利用Bradley-Terry模型,通过偏好数据推导出优选((y_w))和次优((y_l))回答之间的对数概率差异。然后,通过对数似然损失函数来优化策略。

具体来说,损失函数的形式为:

[
L_{\text{DPO}} = -E_{(x, y_w, y_l)} \left[ \log \sigma \left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)} \right) \right]
]

其中:

  • (\pi_\theta(y|x)):当前模型的策略。
  • (\pi_{\text{ref}}(y|x)):参考模型的策略(如监督微调后的模型)。
  • (\beta):温度系数,控制模型偏离参考模型的程度。

损失函数

DPO损失通过最大化优选与次优输出之间的偏好数据对数似然来优化模型,利用Sigmoid函数将奖励差值转换为概率,从而引导模型的优化方向。具体实现上,模型会基于优选和次优回答的对比来更新参数。

训练流程

  1. 初始化策略:基于一个已经经过监督微调的参考模型(如SFT模型)初始化策略。
  2. 偏好数据采样:使用人类提供的偏好数据集(每个输入(x)对应多个输出(y_1, y_2),并带有优选标签)。
  3. 单步优化:直接通过偏好数据的对数概率进行梯度下降,优化策略。

技术细节

  • 温度系数((\beta)):控制策略偏离参考模型的幅度。通常设置较小的值,如0.1到0.5,以避免过度偏离。
  • 学习率:学习率一般较低,建议在1e-6到5e-6之间,以避免过度训练和模型不稳定。

DPO代码实现

以下是一个基于PyTorch的简化版DPO实现代码:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import Dataset, DataLoader

# 定义DPO数据集(输入x,优选y_w,次优y_l)
class DPODataset(Dataset):
    def __init__(self, data):
        self.data = data  # 示例数据格式:[{"x": "How are you?", "y_w": "I'm fine.", "y_l": "I'm bad."}, ...]
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

# 加载参考模型(SFT模型)和待优化策略模型
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model_ref = AutoModelForCausalLM.from_pretrained(model_name)  # 固定参考模型

# 冻结参考模型参数
for param in model_ref.parameters():
    param.requires_grad = False

# 定义DPO损失函数
def dpo_loss(model, batch, beta=0.1):
    # 编码输入和回答
    x = [item["x"] for item in batch]
    y_w = [item["y_w"] for item in batch]
    y_l = [item["y_l"] for item in batch]
    
    # 计算优选和次优回答的对数概率
    def get_logps(texts, answers):
        logps = []
        for text, answer in zip(texts, answers):
            input_text = text + " " + answer
            inputs = tokenizer(input_text, return_tensors="pt")
            with torch.no_grad():
                logits = model(**inputs).logits
            # 计算对数概率(简化版,需适配具体模型结构)
            answer_ids = tokenizer.encode(answer, add_special_tokens=False)
            logp = torch.log_softmax(logits[0, -len(answer_ids)-1:-1], dim=-1)
            logp = logp.gather(-1, torch.tensor(answer_ids).unsqueeze(-1)).sum()
            logps.append(logp)
        return torch.stack(logps)
    
    logp_w = get_logps(x, y_w)
    logp_l = get_logps(x, y_l)
    logp_ref_w = get_logps(x, y_w)  # 参考模型的对数概率
    logp_ref_l = get_logps(x, y_l)
    
    # 计算损失
    log_ratio_w = logp_w - logp_ref_w
    log_ratio_l = logp_l - logp_ref_l
    loss = -torch.log(torch.sigmoid(beta * (log_ratio_w - log_ratio_l))).mean()
    return loss

# 训练循环
dataset = DPODataset([{"x": "How are you?", "y_w": "I'm fine.", "y_l": "I'm bad."}])  # 示例数据
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6)

model.train()
for epoch in range(3):
    for batch in dataloader:
        loss = dpo_loss(model, batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

重点文献

  1. DPO 原始论文(2023):Direct Preference Optimization: Your Language Model is Secretly a Reward Model
  2. RLHF 理论分析(2022):Training Language Models with Language Feedback
  3. 对比学习与偏好优化(2023):Contrastive Preference Learning: Learning from Human Feedback without RL

总结

DPO通过简化的优化流程和直接从偏好数据中学习的方式,大大提高了训练效率和稳定性。它避免了传统RLHF的奖励模型步骤,并能够更高效地利用偏好数据,尤其在无需复杂环境交互的离线训练场景中表现出色。


http://www.kler.cn/a/559961.html

相关文章:

  • 算法15--BFS
  • 【数据结构】 最大最小堆实现优先队列 python
  • 新民主主义革命理论的形成依据
  • Maven最新版安装教程
  • IO进程 day05
  • Nmap渗透测试指南:信息收集与安全扫描全攻略(2025最新版本)
  • 独立开发者Product Hunt打榜教程
  • WPF布局控件
  • 【C++】 stack和queue以及模拟实现
  • deepseek 导出导入模型(docker)
  • 计算机毕业设计SpringBoot+Vue.js医院管理系统(源码+文档+PPT+讲解)
  • Git与GitHub:深入理解与高效使用
  • 企业终端遭遇勒索病毒威胁?火绒企业版V2.0--企业用户勒索攻击防护建议!
  • wpf 页面切换的实现方式
  • HarmonyOS Next 计时器组件详解
  • 微信小程序 - 条件渲染(wx:if、hidden)与列表渲染(wx:for)
  • 跨境宠物摄像头是一种专为宠物主人设计的智能设备
  • Python的那些事第三十一篇:快速数据帧处理与可视化的高效工具Vaex
  • AWS Bedrock平台引入DeepSeek-R1 模型,推动深度学习
  • 量子计算的数学基础:复数、矩阵和线性代数