当前位置: 首页 > article >正文

unsloth-llama3-8b.py 中文备注版

# %% [markdown]
# 在免费的 Tesla T4 Google Colab 实例上运行此代码,点击 "Runtime" 然后点击 "Run all"
# <div class="align-center">
# <a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
# <a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a></a> 加入 Discord 获取帮助 + ⭐ <i>在 <a href="https://github.com/unslothai/unsloth">Github</a> 上给我们点赞 </i> ⭐
# </div>
# 
# 要在自己的电脑上安装 Unsloth,请参考我们 Github 页面上的[安装说明](https://docs.unsloth.ai/get-started/installing-+-updating)。
#
# 你将学习如何进行[数据准备](#Data)、[训练](#Train)、[运行模型](#Inference)和[保存模型](#Save)。

# %% [markdown] 
# ### 新闻

# %% [markdown]
# **阅读我们的[博客文章](https://unsloth.ai/blog/r1-reasoning)了解如何训练推理模型。**
#
# 访问我们的文档查看所有[模型上传](https://docs.unsloth.ai/get-started/all-our-models)和[笔记本](https://docs.unsloth.ai/get-started/unsloth-notebooks)。

# %% [markdown]
# ### 安装

# # %%
# %%capture
# 跳过 Colab 中的重启消息
# 知识点:sys.modules 是一个字典,包含所有已导入的模块
import sys; modules = list(sys.modules.keys())
# 移除所有包含 "PIL" 或 "google" 的模块
for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None

# 安装必要的包
# unsloth: 用于加速 LLM 训练的库
# vllm: 高性能的 LLM 推理引擎
# 参考: https://github.com/vllm-project/vllm
!pip install unsloth vllm
# 升级 pillow 以避免兼容性问题
!pip install --upgrade pillow

# %% [markdown]
# ### Unsloth

# %% [markdown]
# 在所有函数之前使用 `PatchFastRL` 来为 GRPO 和其他 RL 算法打补丁!

# %%
# 导入必要的类
from unsloth import FastLanguageModel, PatchFastRL
# 为 GRPO 和 FastLanguageModel 应用补丁
# 知识点:GRPO(Generative Reward Policy Optimization)是一种强化学习算法
# 参考:https://arxiv.org/abs/2307.10729
PatchFastRL("GRPO", FastLanguageModel)

# %% [markdown]
# 加载 `Llama 3.1 8B Instruct` 并设置参数

# %%
# 导入必要的库
from unsloth import is_bfloat16_supported
import torch

# 设置序列最大长度,可以增加以支持更长的推理过程
max_seq_length = 512 
# 设置 LoRA 秩,更大的秩意味着更智能但训练更慢
# 知识点:LoRA(Low-Rank Adaptation)是一种参数高效的微调方法
# 参考:https://arxiv.org/abs/2106.09685
lora_rank = 32 

# 加载预训练模型
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # 使用 4bit 量化,设为 False 则使用 16bit LoRA
    fast_inference = True, # 启用 vLLM 快速推理
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6, # 如果内存不足则降低此值
)

# 获取 PEFT 模型
# 知识点:PEFT(Parameter-Efficient Fine-Tuning)是一系列参数高效的微调方法的统称
# 参考:https://github.com/huggingface/peft
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # LoRA 秩,建议值:8,16,32,64,128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # 如果内存不足可以移除 QKVO
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # 启用长上下文微调
    random_state = 3407, # 设置随机种子以保证可重复性
)

# %% [markdown]
# ### 数据准备
# <a name="Data"></a>
# 
# 我们直接使用 [@willccbb](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) 的数据准备和奖励函数代码。你也可以创建自己的函数!

# %%
import re
from datasets import load_dataset, Dataset

# 加载和准备数据集
# 定义系统提示,指定输出格式
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# 定义 XML 格式的思维链模板
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

# 从文本中提取 XML 格式的答案
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

# 从文本中提取带有 #### 标记的答案
def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# 获取 GSM8K 数据集的问题
# 知识点:GSM8K 是一个数学问题数据集
# 参考:https://github.com/openai/grade-school-math
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

# 奖励函数
# 正确性奖励函数:检查答案是否正确
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

# 整数奖励函数:检查答案是否为整数
def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

# 严格格式奖励函数:检查输出是否完全符合指定格式
def strict_format_reward_func(completions, **kwargs) -> list[float]:
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

# 宽松格式奖励函数:检查输出是否基本符合格式要求
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

# 计算 XML 标签的正确使用情况
def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

# XML 计数奖励函数:根据 XML 标签的使用情况给出奖励
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

# %% [markdown]
# <a name="Train"></a>
# ### 训练模型
# 
# 现在设置 GRPO Trainer 和所有配置!

# %%
# 导入必要的类
from trl import GRPOConfig, GRPOTrainer

# 设置训练参数
# 知识点:这些参数对模型训练的效果有重要影响
# 参考:https://huggingface.co/docs/transformers/main_classes/trainer
training_args = GRPOConfig(
    use_vllm = True, # 使用 vLLM 进行快速推理
    learning_rate = 5e-6, # 学习率
    adam_beta1 = 0.9, # Adam 优化器参数
    adam_beta2 = 0.99, # Adam 优化器参数
    weight_decay = 0.1, # 权重衰减,用于防止过拟合
    warmup_ratio = 0.1, # 预热比例
    lr_scheduler_type = "cosine", # 学习率调度器类型
    optim = "paged_adamw_8bit", # 优化器类型
    logging_steps = 1, # 日志记录步数
    bf16 = is_bfloat16_supported(), # 是否使用 bfloat16
    fp16 = not is_bfloat16_supported(), # 是否使用 fp16
    per_device_train_batch_size = 1, # 每个设备的训练批次大小
    gradient_accumulation_steps = 1, # 梯度累积步数,增加到 4 可使训练更平滑
    num_generations = 6, # 生成数量,如果内存不足则减少
    max_prompt_length = 256, # 最大提示长度
    max_completion_length = 200, # 最大完成长度
    # num_train_epochs = 1, # 训练轮数,设为 1 进行完整训练
    max_steps = 250, # 最大训练步数
    save_steps = 250, # 保存检查点的步数
    max_grad_norm = 0.1, # 最大梯度范数,用于梯度裁剪
    report_to = "none", # 可以使用 Weights & Biases
    output_dir = "outputs", # 输出目录
)

# %% [markdown]
# 让我们运行训练器!向上滚动,你会看到奖励表。目标是让 `reward` 列增加!
# 
# 你可能需要等待 150 到 200 步才能看到效果。前 100 步的奖励可能为 0。请耐心等待!
# 
# | Step | Training Loss | reward    | reward_std | completion_length | kl       |
# |------|---------------|-----------|------------|-------------------|----------|
# | 1    | 0.000000      | 0.125000  | 0.000000   | 200.000000        | 0.000000 |
# | 2    | 0.000000      | 0.072375  | 0.248112   | 200.000000        | 0.000000 |
# | 3    | 0.000000      | -0.079000 | 0.163776   | 182.500000        | 0.000005 |
# 

# %%
# 初始化并运行训练器
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

# %% [markdown]
# <a name="Inference"></a>
# ### 推理
# 现在让我们测试刚刚训练的模型!首先,让我们尝试未经 GRPO 训练的模型:

# %%
# 准备输入文本
text = tokenizer.apply_chat_template([
    {"role" : "user", "content" : "Calculate pi."},
], tokenize = False, add_generation_prompt = True)

# 设置采样参数
# 知识点:这些参数影响文本生成的多样性和质量
# 参考:https://arxiv.org/abs/1904.09751
from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8, # 温度参数,控制生成的随机性
    top_p = 0.95, # 累积概率阈值
    max_tokens = 1024, # 最大生成标记数
)
# 生成文本
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

output

# %% [markdown]
# 现在使用我们刚刚用 GRPO 训练的 LoRA - 我们先保存 LoRA!

# %%
# 保存 LoRA 权重
model.save_lora("grpo_saved_lora")

# %% [markdown]
# 现在我们加载 LoRA 并测试:

# %%
# 准备带有系统提示的输入文本
text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role" : "user", "content" : "Calculate pi."},
], tokenize = False, add_generation_prompt = True)

# 设置采样参数
from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
# 使用加载的 LoRA 生成文本
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text

output

# %% [markdown]
# 我们的推理模型变得更好了 - 由于我们只训练了大约一个小时,它并不总是正确的 - 如果我们延长序列长度并训练更长时间,它会变得更好!

# %% [markdown]
# <a name="Save"></a>
# ### 保存为 float16 用于 VLLM
# 
# 我们也支持直接保存为 `float16`。选择 `merged_16bit` 用于 float16 或 `merged_4bit` 用于 int4。我们也允许使用 `lora` 适配器作为备选。使用 `push_to_hub_merged` 上传到你的 Hugging Face 账户!你可以在 https://huggingface.co/settings/tokens 获取个人令牌。

# %%
# 合并为 16bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")

# 合并为 4bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")

# 仅保存 LoRA 适配器
if False: model.save_pretrained_merged("model", tokenizer, save_method = "lora",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "lora", token = "")

# %% [markdown]
# ### GGUF / llama.cpp 转换
# 我们现在原生支持保存为 `GGUF` / `llama.cpp`!我们克隆 `llama.cpp` 并默认保存为 `q8_0`。我们支持所有方法如 `q4_k_m`。使用 `save_pretrained_gguf` 进行本地保存,使用 `push_to_hub_gguf` 上传到 HF。
# 
# 一些支持的量化方法(完整列表见我们的 [Wiki 页面](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):
# * `q8_0` - 快速转换。资源使用高,但通常可接受。
# * `q4_k_m` - 推荐。对 attention.wv 和 feed_forward.w2 张量的一半使用 Q6_K,其余使用 Q4_K。
# * `q5_k_m` - 推荐。对 attention.wv 和 feed_forward.w2 张量的一半使用 Q6_K,其余使用 Q5_K。
# 
# [**新功能**] 要微调并自动导出到 Ollama,试试我们的 [Ollama 笔记本](https://colab.research.google.com/drive/1WZDi7APtQ9VsvOrQSSC5DDtxq159j8iZ?usp=sharing)

# %%
# 保存为 8bit Q8_0
if False: model.save_pretrained_gguf("model", tokenizer,)
# 记得在 https://huggingface.co/settings/tokens 获取令牌!
# 并将 hf 改为你的用户名!
if False: model.push_to_hub_gguf("hf/model", tokenizer, token = "")

# 保存为 16bit GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "f16")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "f16", token = "")

# 保存为 q4_k_m GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "q4_k_m", token = "")

# 保存为多个 GGUF 选项 - 如果你想要多个选项,这样更快
if False:
    model.push_to_hub_gguf(
        "hf/model", # 将 hf 改为你的用户名!
        tokenizer,
        quantization_method = ["q4_k_m", "q8_0", "q5_k_m",],
        token = "",
    )

# %% [markdown]
# 现在,在 llama.cpp 或基于 UI 的系统(如 Jan 或 Open WebUI)中使用 `model-unsloth.gguf` 文件或 `model-unsloth-Q4_K_M.gguf` 文件。你可以在[这里](https://github.com/janhq/jan)安装 Jan,在[这里](https://github.com/open-webui/open-webui)安装 Open WebUI。
# 
# 我们完成了!如果你对 Unsloth 有任何问题,我们有一个 [Discord](https://discord.gg/unsloth) 频道!如果你发现任何 bug 或想了解最新的 LLM 动态,或需要帮助,加入项目等,欢迎加入我们的 Discord!
# 
# 一些其他链接:
# 1. Llama 3.2 对话笔记本。[免费 Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(1B_and_3B)-Conversational.ipynb)
# 2. 保存微调到 Ollama。[免费笔记本](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
# 3. Llama 3.2 Vision 微调 - 放射学用例。[免费 Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
# 6. 在我们的[文档](https://docs.unsloth.ai/get-started/unsloth-notebooks)中查看 DPO、ORPO、持续预训练、对话微调等笔记本!
# 
# <div class="align-center">
#   <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
#   <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
#   <a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>
# 
#   加入 Discord 获取帮助 + ⭐️ <i>在 <a href="https://github.com/unslothai/unsloth">Github</a> 上给我们点赞 </i> ⭐️
# </div>
# 


http://www.kler.cn/a/574461.html

相关文章:

  • 使用 Arduino 的 WiFi 控制机器人
  • 二、双指针——6. 三数之和
  • Python函数定义详细教程:参数类型详解,报错UnboundLocalError怎么解决。
  • 贪心算法一
  • aws(学习笔记第三十一课) aws cdk深入学习(batch-arm64-instance-type)
  • Java多线程与高并发专题——为什么 Map 桶中超过 8 个才转为红黑树?
  • PPT 小黑第20套
  • java8中young gc的垃圾回收器选型,您了解嘛
  • AI面板识别 - 华为OD统一考试(java)
  • 风控模型算法面试题集结
  • 面试基础--Spring Boot启动流程及源码实现
  • IDEA 2024.1.7 Java EE 无框架配置servlet
  • osg官方例子
  • React基础之插值
  • 蓝桥杯 Excel地址
  • 深入剖析 Kubernetes 弹性伸缩:HPA 与 Metrics Server
  • FPGA-按键消抖
  • 青训营:简易分布式爬虫
  • 171. Excel 表列序号
  • 【消费主义与性别角色重构】