unsloth-llama3-8b.py 中文备注版
# %% [markdown]
# 在免费的 Tesla T4 Google Colab 实例上运行此代码,点击 "Runtime" 然后点击 "Run all"
# <div class="align-center">
# <a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
# <a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a></a> 加入 Discord 获取帮助 + ⭐ <i>在 <a href="https://github.com/unslothai/unsloth">Github</a> 上给我们点赞 </i> ⭐
# </div>
#
# 要在自己的电脑上安装 Unsloth,请参考我们 Github 页面上的[安装说明](https://docs.unsloth.ai/get-started/installing-+-updating)。
#
# 你将学习如何进行[数据准备](#Data)、[训练](#Train)、[运行模型](#Inference)和[保存模型](#Save)。
# %% [markdown]
# ### 新闻
# %% [markdown]
# **阅读我们的[博客文章](https://unsloth.ai/blog/r1-reasoning)了解如何训练推理模型。**
#
# 访问我们的文档查看所有[模型上传](https://docs.unsloth.ai/get-started/all-our-models)和[笔记本](https://docs.unsloth.ai/get-started/unsloth-notebooks)。
# %% [markdown]
# ### 安装
# # %%
# %%capture
# 跳过 Colab 中的重启消息
# 知识点:sys.modules 是一个字典,包含所有已导入的模块
import sys; modules = list(sys.modules.keys())
# 移除所有包含 "PIL" 或 "google" 的模块
for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
# 安装必要的包
# unsloth: 用于加速 LLM 训练的库
# vllm: 高性能的 LLM 推理引擎
# 参考: https://github.com/vllm-project/vllm
!pip install unsloth vllm
# 升级 pillow 以避免兼容性问题
!pip install --upgrade pillow
# %% [markdown]
# ### Unsloth
# %% [markdown]
# 在所有函数之前使用 `PatchFastRL` 来为 GRPO 和其他 RL 算法打补丁!
# %%
# 导入必要的类
from unsloth import FastLanguageModel, PatchFastRL
# 为 GRPO 和 FastLanguageModel 应用补丁
# 知识点:GRPO(Generative Reward Policy Optimization)是一种强化学习算法
# 参考:https://arxiv.org/abs/2307.10729
PatchFastRL("GRPO", FastLanguageModel)
# %% [markdown]
# 加载 `Llama 3.1 8B Instruct` 并设置参数
# %%
# 导入必要的库
from unsloth import is_bfloat16_supported
import torch
# 设置序列最大长度,可以增加以支持更长的推理过程
max_seq_length = 512
# 设置 LoRA 秩,更大的秩意味着更智能但训练更慢
# 知识点:LoRA(Low-Rank Adaptation)是一种参数高效的微调方法
# 参考:https://arxiv.org/abs/2106.09685
lora_rank = 32
# 加载预训练模型
model, tokenizer = FastLanguageModel.from_pretrained(
model_name = "meta-llama/meta-Llama-3.1-8B-Instruct",
max_seq_length = max_seq_length,
load_in_4bit = True, # 使用 4bit 量化,设为 False 则使用 16bit LoRA
fast_inference = True, # 启用 vLLM 快速推理
max_lora_rank = lora_rank,
gpu_memory_utilization = 0.6, # 如果内存不足则降低此值
)
# 获取 PEFT 模型
# 知识点:PEFT(Parameter-Efficient Fine-Tuning)是一系列参数高效的微调方法的统称
# 参考:https://github.com/huggingface/peft
model = FastLanguageModel.get_peft_model(
model,
r = lora_rank, # LoRA 秩,建议值:8,16,32,64,128
target_modules = [
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj",
], # 如果内存不足可以移除 QKVO
lora_alpha = lora_rank,
use_gradient_checkpointing = "unsloth", # 启用长上下文微调
random_state = 3407, # 设置随机种子以保证可重复性
)
# %% [markdown]
# ### 数据准备
# <a name="Data"></a>
#
# 我们直接使用 [@willccbb](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) 的数据准备和奖励函数代码。你也可以创建自己的函数!
# %%
import re
from datasets import load_dataset, Dataset
# 加载和准备数据集
# 定义系统提示,指定输出格式
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
# 定义 XML 格式的思维链模板
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
# 从文本中提取 XML 格式的答案
def extract_xml_answer(text: str) -> str:
answer = text.split("<answer>")[-1]
answer = answer.split("</answer>")[0]
return answer.strip()
# 从文本中提取带有 #### 标记的答案
def extract_hash_answer(text: str) -> str | None:
if "####" not in text:
return None
return text.split("####")[1].strip()
# 获取 GSM8K 数据集的问题
# 知识点:GSM8K 是一个数学问题数据集
# 参考:https://github.com/openai/grade-school-math
def get_gsm8k_questions(split = "train") -> Dataset:
data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
data = data.map(lambda x: { # type: ignore
'prompt': [
{'role': 'system', 'content': SYSTEM_PROMPT},
{'role': 'user', 'content': x['question']}
],
'answer': extract_hash_answer(x['answer'])
}) # type: ignore
return data # type: ignore
dataset = get_gsm8k_questions()
# 奖励函数
# 正确性奖励函数:检查答案是否正确
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
responses = [completion[0]['content'] for completion in completions]
q = prompts[0][-1]['content']
extracted_responses = [extract_xml_answer(r) for r in responses]
print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
# 整数奖励函数:检查答案是否为整数
def int_reward_func(completions, **kwargs) -> list[float]:
responses = [completion[0]['content'] for completion in completions]
extracted_responses = [extract_xml_answer(r) for r in responses]
return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
# 严格格式奖励函数:检查输出是否完全符合指定格式
def strict_format_reward_func(completions, **kwargs) -> list[float]:
pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in matches]
# 宽松格式奖励函数:检查输出是否基本符合格式要求
def soft_format_reward_func(completions, **kwargs) -> list[float]:
pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
responses = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, r) for r in responses]
return [0.5 if match else 0.0 for match in matches]
# 计算 XML 标签的正确使用情况
def count_xml(text) -> float:
count = 0.0
if text.count("<reasoning>\n") == 1:
count += 0.125
if text.count("\n</reasoning>\n") == 1:
count += 0.125
if text.count("\n<answer>\n") == 1:
count += 0.125
count -= len(text.split("\n</answer>\n")[-1])*0.001
if text.count("\n</answer>") == 1:
count += 0.125
count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
return count
# XML 计数奖励函数:根据 XML 标签的使用情况给出奖励
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
contents = [completion[0]["content"] for completion in completions]
return [count_xml(c) for c in contents]
# %% [markdown]
# <a name="Train"></a>
# ### 训练模型
#
# 现在设置 GRPO Trainer 和所有配置!
# %%
# 导入必要的类
from trl import GRPOConfig, GRPOTrainer
# 设置训练参数
# 知识点:这些参数对模型训练的效果有重要影响
# 参考:https://huggingface.co/docs/transformers/main_classes/trainer
training_args = GRPOConfig(
use_vllm = True, # 使用 vLLM 进行快速推理
learning_rate = 5e-6, # 学习率
adam_beta1 = 0.9, # Adam 优化器参数
adam_beta2 = 0.99, # Adam 优化器参数
weight_decay = 0.1, # 权重衰减,用于防止过拟合
warmup_ratio = 0.1, # 预热比例
lr_scheduler_type = "cosine", # 学习率调度器类型
optim = "paged_adamw_8bit", # 优化器类型
logging_steps = 1, # 日志记录步数
bf16 = is_bfloat16_supported(), # 是否使用 bfloat16
fp16 = not is_bfloat16_supported(), # 是否使用 fp16
per_device_train_batch_size = 1, # 每个设备的训练批次大小
gradient_accumulation_steps = 1, # 梯度累积步数,增加到 4 可使训练更平滑
num_generations = 6, # 生成数量,如果内存不足则减少
max_prompt_length = 256, # 最大提示长度
max_completion_length = 200, # 最大完成长度
# num_train_epochs = 1, # 训练轮数,设为 1 进行完整训练
max_steps = 250, # 最大训练步数
save_steps = 250, # 保存检查点的步数
max_grad_norm = 0.1, # 最大梯度范数,用于梯度裁剪
report_to = "none", # 可以使用 Weights & Biases
output_dir = "outputs", # 输出目录
)
# %% [markdown]
# 让我们运行训练器!向上滚动,你会看到奖励表。目标是让 `reward` 列增加!
#
# 你可能需要等待 150 到 200 步才能看到效果。前 100 步的奖励可能为 0。请耐心等待!
#
# | Step | Training Loss | reward | reward_std | completion_length | kl |
# |------|---------------|-----------|------------|-------------------|----------|
# | 1 | 0.000000 | 0.125000 | 0.000000 | 200.000000 | 0.000000 |
# | 2 | 0.000000 | 0.072375 | 0.248112 | 200.000000 | 0.000000 |
# | 3 | 0.000000 | -0.079000 | 0.163776 | 182.500000 | 0.000005 |
#
# %%
# 初始化并运行训练器
trainer = GRPOTrainer(
model = model,
processing_class = tokenizer,
reward_funcs = [
xmlcount_reward_func,
soft_format_reward_func,
strict_format_reward_func,
int_reward_func,
correctness_reward_func,
],
args = training_args,
train_dataset = dataset,
)
trainer.train()
# %% [markdown]
# <a name="Inference"></a>
# ### 推理
# 现在让我们测试刚刚训练的模型!首先,让我们尝试未经 GRPO 训练的模型:
# %%
# 准备输入文本
text = tokenizer.apply_chat_template([
{"role" : "user", "content" : "Calculate pi."},
], tokenize = False, add_generation_prompt = True)
# 设置采样参数
# 知识点:这些参数影响文本生成的多样性和质量
# 参考:https://arxiv.org/abs/1904.09751
from vllm import SamplingParams
sampling_params = SamplingParams(
temperature = 0.8, # 温度参数,控制生成的随机性
top_p = 0.95, # 累积概率阈值
max_tokens = 1024, # 最大生成标记数
)
# 生成文本
output = model.fast_generate(
[text],
sampling_params = sampling_params,
lora_request = None,
)[0].outputs[0].text
output
# %% [markdown]
# 现在使用我们刚刚用 GRPO 训练的 LoRA - 我们先保存 LoRA!
# %%
# 保存 LoRA 权重
model.save_lora("grpo_saved_lora")
# %% [markdown]
# 现在我们加载 LoRA 并测试:
# %%
# 准备带有系统提示的输入文本
text = tokenizer.apply_chat_template([
{"role" : "system", "content" : SYSTEM_PROMPT},
{"role" : "user", "content" : "Calculate pi."},
], tokenize = False, add_generation_prompt = True)
# 设置采样参数
from vllm import SamplingParams
sampling_params = SamplingParams(
temperature = 0.8,
top_p = 0.95,
max_tokens = 1024,
)
# 使用加载的 LoRA 生成文本
output = model.fast_generate(
text,
sampling_params = sampling_params,
lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text
output
# %% [markdown]
# 我们的推理模型变得更好了 - 由于我们只训练了大约一个小时,它并不总是正确的 - 如果我们延长序列长度并训练更长时间,它会变得更好!
# %% [markdown]
# <a name="Save"></a>
# ### 保存为 float16 用于 VLLM
#
# 我们也支持直接保存为 `float16`。选择 `merged_16bit` 用于 float16 或 `merged_4bit` 用于 int4。我们也允许使用 `lora` 适配器作为备选。使用 `push_to_hub_merged` 上传到你的 Hugging Face 账户!你可以在 https://huggingface.co/settings/tokens 获取个人令牌。
# %%
# 合并为 16bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")
# 合并为 4bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")
# 仅保存 LoRA 适配器
if False: model.save_pretrained_merged("model", tokenizer, save_method = "lora",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "lora", token = "")
# %% [markdown]
# ### GGUF / llama.cpp 转换
# 我们现在原生支持保存为 `GGUF` / `llama.cpp`!我们克隆 `llama.cpp` 并默认保存为 `q8_0`。我们支持所有方法如 `q4_k_m`。使用 `save_pretrained_gguf` 进行本地保存,使用 `push_to_hub_gguf` 上传到 HF。
#
# 一些支持的量化方法(完整列表见我们的 [Wiki 页面](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):
# * `q8_0` - 快速转换。资源使用高,但通常可接受。
# * `q4_k_m` - 推荐。对 attention.wv 和 feed_forward.w2 张量的一半使用 Q6_K,其余使用 Q4_K。
# * `q5_k_m` - 推荐。对 attention.wv 和 feed_forward.w2 张量的一半使用 Q6_K,其余使用 Q5_K。
#
# [**新功能**] 要微调并自动导出到 Ollama,试试我们的 [Ollama 笔记本](https://colab.research.google.com/drive/1WZDi7APtQ9VsvOrQSSC5DDtxq159j8iZ?usp=sharing)
# %%
# 保存为 8bit Q8_0
if False: model.save_pretrained_gguf("model", tokenizer,)
# 记得在 https://huggingface.co/settings/tokens 获取令牌!
# 并将 hf 改为你的用户名!
if False: model.push_to_hub_gguf("hf/model", tokenizer, token = "")
# 保存为 16bit GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "f16")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "f16", token = "")
# 保存为 q4_k_m GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "q4_k_m", token = "")
# 保存为多个 GGUF 选项 - 如果你想要多个选项,这样更快
if False:
model.push_to_hub_gguf(
"hf/model", # 将 hf 改为你的用户名!
tokenizer,
quantization_method = ["q4_k_m", "q8_0", "q5_k_m",],
token = "",
)
# %% [markdown]
# 现在,在 llama.cpp 或基于 UI 的系统(如 Jan 或 Open WebUI)中使用 `model-unsloth.gguf` 文件或 `model-unsloth-Q4_K_M.gguf` 文件。你可以在[这里](https://github.com/janhq/jan)安装 Jan,在[这里](https://github.com/open-webui/open-webui)安装 Open WebUI。
#
# 我们完成了!如果你对 Unsloth 有任何问题,我们有一个 [Discord](https://discord.gg/unsloth) 频道!如果你发现任何 bug 或想了解最新的 LLM 动态,或需要帮助,加入项目等,欢迎加入我们的 Discord!
#
# 一些其他链接:
# 1. Llama 3.2 对话笔记本。[免费 Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(1B_and_3B)-Conversational.ipynb)
# 2. 保存微调到 Ollama。[免费笔记本](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
# 3. Llama 3.2 Vision 微调 - 放射学用例。[免费 Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
# 6. 在我们的[文档](https://docs.unsloth.ai/get-started/unsloth-notebooks)中查看 DPO、ORPO、持续预训练、对话微调等笔记本!
#
# <div class="align-center">
# <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
# <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
# <a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>
#
# 加入 Discord 获取帮助 + ⭐️ <i>在 <a href="https://github.com/unslothai/unsloth">Github</a> 上给我们点赞 </i> ⭐️
# </div>
#