当前位置: 首页 > article >正文

LLM - 使用 LLaMA-Factory 微调 Qwen2-VL DPO(LoRA) 图像数据集 教程 (3)

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/143725947

免责声明:本文来源于个人知识与公开资料,仅用于学术交流,欢迎讨论,不支持转载。


RLHF

DPO(Direct Preference Optimization, 直接偏好优化) 是在 RLHF 阶段中使用的优化算法,通过直接利用人类的偏好数据来优化策略模型,无需定义明确的奖励函数或进行复杂的强化学习过程。DPO的优化目标是,增加偏好样本的对数概率与减小非偏好样本响应的对数概率,结合动态加权机制,以避免仅使用概率比目标时遇到的模型退化问题。

DPO 公式参考:大模型训练 RLHF 阶段的 PPO/DPO 策略公式与源码

相关文章:

  1. 使用 LLaMA-Factory 微调大模型 环境配置与训练推理
  2. 使用 LLaMA-Factory 微调 Qwen2-VL SFT(LoRA) 图像数据集
  3. 使用 LLaMA-Factory 微调 Qwen2-VL DPO(LoRA) 图像数据集

在 LlamaFactory 中的 DPO 训练脚本,参考 examples/train_lora/qwen2vl_lora_dpo.yaml

需要准备数据集:rlhf_v,参考 data/dataset_info.json 的数据集信息,即:

"rlhf_v": {
  "hf_hub_url": "llamafactory/RLHF-V",
  "ranking": true,
  "formatting": "sharegpt",
  "columns": {
    "messages": "conversations",
    "chosen": "chosen",
    "rejected": "rejected",
    "images": "images"
  }
},

使用 HuggingFace 下载数据集 llamafactory/RLHF-V,即:

cd [your folder]/llm/datasets/
huggingface-cli download --token [your token] --repo-type dataset llamafactory/RLHF-V --local-dir llamafactory/RLHF-V

注意:huggingface-cli 的数据集参数 --repo-type dataset

配置 llama_factory 的 jupyter 环境,即:

pip install ipykernel
python -m ipykernel install --user --name llama_factory --display-name "llama_factory"

待训练的数据集是 rlhf-v.parquet,大约 348 M,包括 5733 条样本,显示样本信息,包括:

  • conversations,对话,也就是 LlamaFactory 的 messages
  • chosen,正样本
  • rejected,负样本
  • images,图像列表,注意是 bytes 格式
import pandas as pd

# 读取 Parquet 文件
file_path = '[your folder]/llm/datasets/llamafactory/RLHF-V/rlhf-v.parquet'
df = pd.read_parquet(file_path)

first_row = df.iloc[0] 
print(f"[Info] first_row: \n{first_row}")

显示信息:

print(f"[Info] first_row conversations: \n{first_row['conversations']}")
print(f"[Info] first_row chosen: \n{first_row['chosen']}")
print(f"[Info] first_row chosen: \n{first_row['rejected']}")
from IPython.display import display, Image
# 假设你有一个 bytes 形式的图像数据
image_bytes = first_row['images'][0]['bytes']
# 使用 IPython.display 显示图像
display(Image(data=image_bytes))

输出:

[{'from': 'human', 'value': '<image>What are the key features you observe in the image?'}]
{'from': 'gpt', 'value': 'A young man standing on stage wearing a white shirt and black pants.'}
{'from': 'gpt', 'value': 'A young man standing on stage wearing white pants and shoes.'}

问题是描述图像的关键特征,正确是穿着白色T恤和黑色裤子,错误是穿着白色裤子和鞋,让模型辨识主体颜色,图像如下:

Img

训练模型:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 llamafactory-cli train [your folder]/llm/LLaMA-Factory/examples/train_lora/qwen2vl_lora_dpo_my20241112.yaml

A100 的显存是 81920M,第0卡显存占用最大,大约 58.122G (81920 * 1024 * 1024 * 67.66%),8 卡 A100 消耗 388G (58 + 48 *6 + 42)

Image

在训练和验证指标中,除了 Loss 之外,包括 3 个部分,即 logits、rewards、logps,其中每个里面包括 chosen 和 rejected 两个部分。

  • rewards/chosen,chosen 的 policy_chosen_logps - reference_chosen_logps,值 0 ~ -12
  • rewards/rejected,rejected 的 policy_rejected_logps - reference_rejected_logps,值是 0 ~ -17
  • rewards/accuracies,正确率,0~1之间。
  • rewards/margins,差距的均值,逐渐增大,0~6之间。
  • logps/rejected,logps均值,-120~-250
  • logps/chosen,logps均值,-100~-280
  • logits/rejected,logits,-2.4 上下震荡
  • logits/chosen,logits,-2.4 上下震荡

源码:

metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().item()
metrics[f"{prefix}rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item()
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item()
metrics[f"{prefix}logps/rejected"] = policy_chosen_logps.mean().item()
metrics[f"{prefix}logps/chosen"] = policy_rejected_logps.mean().item()
metrics[f"{prefix}logits/rejected"] = policy_chosen_logits.mean().item()
metrics[f"{prefix}logits/chosen"] = policy_rejected_logits.mean().item()

logits 定义: 在生成任务(如语言生成)中,模型在每一步的输出是一个 logits 向量,每个 logits 值对应于某个词汇(token)的预激活值。logits 经过 softmax 变换后,得到的是当前时刻每个词汇的概率分布。logps,即 Log Probabilities。

rewards 曲线:

rewards

logpslogits 曲线:

logps

logitslogps 的计算公式,参考源码 src/llamafactory/train/trainer_utils.py

def get_batch_logps(
    logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX
) -> Tuple["torch.Tensor", "torch.Tensor"]:
    r"""
    Computes the log probabilities of the given labels under the given logits.

    Returns:
        logps: A tensor of shape (batch_size,) containing the sum of log probabilities.
        valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens.
    """
    if logits.shape[:-1] != labels.shape:
        raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.")

    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = labels != label_pad_token_id
    labels[labels == label_pad_token_id] = 0  # dummy token
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)

Loss 是下降的:

Loss

参考 TorchTune - DPO 源码

参考文章:

  • DPO: Direct Preference Optimization 直接偏好优化 笔记
  • 知乎 - ChatGPT训练,你需要知道的奖励模型(Reward Model)
  • 知乎 - 拆解大语言模型RLHF中的PPO
  • 知乎 - 精调VLM(视觉语言模型)的经验
  • 知乎 - 一文读懂信息熵、交叉熵和相对熵(KL散度)
  • 知乎 - 史上最详细的强化学习DPO算法详解
  • 知乎 - dpo 的局限性
  • 知乎 - DPO 是如何简化 RLHF 的
  • Reinforcement Learning from Human Feedback (RLHF) - a simplified explanation
  • GitHub - OpenRLHF/OpenRLHF

http://www.kler.cn/a/416555.html

相关文章:

  • 关于函数式接口和编程的解析和案例实战
  • 进程的知识
  • 行驶证 OCR 识别API接口的影响因素
  • Avalonia11中读取外部配置文件
  • uniapp的video组件截图(抓拍)功能,解决截后为黑图bug
  • 工程企业如何做好成本控制?该如何入手?
  • 力扣 最长回文字串-5
  • EXCEL截取某一列从第一个字符开始到特定字符结束的字符串到新的一列
  • Websocket——化神篇
  • 解决 PyTorch Upsample 属性错误:方法与最佳实践
  • 在并发情况下,Elasticsearch如果保证读写一致?
  • redis中的哨兵
  • vue3.0 根据富文本html页面生成压缩包(含视频在线地址、图片在线地址、前端截图、前端文档)
  • NeurIPS 2024 有效投稿达 15,671 篇,数据集版块内容丰富
  • MySQL 性能:基准测试工具包(BMK-kit)
  • Java开发工程师最新面试题库系列——Java基础部分(附答案)
  • 深入浅出:开发者如何快速上手Web3生态系统
  • C++调用QML函数的两种方法
  • 计算机毕业设计Python+LSTM天气预测系统 AI大模型问答 vue.js 可视化大屏 机器学习 深度学习 Hadoop Spark
  • C++基础:muduo库学习记录
  • 格网法计算平面点云面积(matlab版本)
  • 考试排名(一)(结构体专题)
  • 2024年11月一区SCI-Alpha evolution-附Matlab免费代码
  • javax.net.ssl.SSLHandshakeException: Received fatal alert: protocol_version
  • DM-VIO(ROS)+t265配置运行记录(ubuntu18.04+ros melodic)
  • Maven - 优雅的管理多模块应用的统一版本号