当前位置: 首页 > article >正文

大模型微调 - 基于预训练大语言模型的对话生成任务 训练代码

大模型微调 - 基于预训练大语言模型的对话生成任务 训练代码

flyfish

模型扮演堂吉诃德这个角色,回答关于自我介绍的问题

import torch
from datasets import Dataset
from modelscope import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq


# 内存中的自定义数据集,包含用户问题(content)和模型回答(summary)
data = [
    {"content": "你叫什么名字?", "summary": "你好,我是堂吉诃德,骑士所做的一切,都是为了你的荣耀。"},
    {"content": "你可以告诉我你的名字吗?", "summary": "你好,我是堂吉诃德,骑士所做的一切,都是为了你的荣耀。"},
    {"content": "你是悟空吗", "summary": "你好,我是堂吉诃德,骑士所做的一切,都是为了你的荣耀。"},
    # 省略若干重复条目……
]

# 加载 tokenizer 和模型
# tokenizer 用于将文本转化为模型可理解的输入格式
tokenizer = AutoTokenizer.from_pretrained("qwen/Qwen2-0.5B-Instruct", use_fast=False, trust_remote_code=True)
# 加载 Qwen2-0.5B-Instruct 模型,并指定使用 bfloat16 精度和自动分配设备(GPU)
model = AutoModelForCausalLM.from_pretrained("qwen/Qwen2-0.5B-Instruct", device_map="auto", torch_dtype=torch.bfloat16)
model.enable_input_require_grads()  # 确保在大模型中开启梯度检查点,用于节省内存

# 数据预处理函数,用于将输入数据格式化为模型可以使用的输入格式
def process_func(example):
    MAX_LENGTH = 384  # 定义输入数据的最大长度
    # 拼接系统提示(堂吉诃德)和用户的提问作为模型的输入
    instruction = tokenizer(
        f"<|im_start|>system\n你是堂吉诃德,请回答以下问题。<|im_end|>\n<|im_start|>user\n{example['content']}<|im_end|>\n<|im_start|>assistant\n",
        add_special_tokens=False,
    )
    # 将回答(summary)也转化为 token
    response = tokenizer(f"{example['summary']}", add_special_tokens=False)
    
    # 将指令和回答拼接成输入序列,并在末尾添加 pad_token
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]
    # 标签只包含回答部分,指令部分使用 -100 以防止计算损失
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
    
    # 如果输入序列长度超过 MAX_LENGTH,则截断
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
        
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# 使用 Hugging Face 的 Dataset API 将数据转化为 Dataset 格式,并进行预处理
dataset = Dataset.from_dict({"content": [d["content"] for d in data], "summary": [d["summary"] for d in data]})
processed_dataset = dataset.map(process_func, remove_columns=["content", "summary"])

# LoRA 配置,用于微调模型
config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  # 任务类型为因果语言模型
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],  # 指定 LoRA 作用的模块
    inference_mode=False,  # 设置为 False 表示用于训练而非推理
    r=8,  # LoRA 的秩(rank)
    lora_alpha=32,  # LoRA 的 alpha 参数
    lora_dropout=0.1,  # LoRA 使用的 dropout 比例
)

# 将 LoRA 配置应用到模型中,生成可微调的模型
model = get_peft_model(model, config)

# 训练参数配置
training_args = TrainingArguments(
    output_dir="./output/Qwen2-0.5B",  # 训练输出路径
    per_device_train_batch_size=4,  # 每个设备的训练批次大小
    gradient_accumulation_steps=4,  # 梯度累计步数
    logging_steps=10,  # 日志记录的间隔步数
    num_train_epochs=2,  # 训练的 epoch 数
    save_steps=100,  # 保存模型的间隔步数
    learning_rate=1e-4,  # 学习率
    gradient_checkpointing=True,  # 启用梯度检查点以节省显存
    report_to="none"  # 禁用报告工具(如 WandB、Tensorboard)
)

# 使用 Hugging Face 的 Trainer API 设置训练器
trainer = Trainer(
    model=model,  # 要训练的模型
    args=training_args,  # 训练参数
    train_dataset=processed_dataset,  # 训练数据集
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),  # 数据整理器,用于批次数据填充
)

# 开始训练
trainer.train()

# 定义预测函数,用于测试模型的生成能力
def predict(messages, model, tokenizer):
    device = "cuda"  # 使用 CUDA 加速
    # 将对话模板应用到输入的消息中
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # 将消息编码为模型输入
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    # 使用模型生成回答,最多生成 512 个新 tokens
    generated_ids = model.generate(model_inputs.input_ids, max_new_tokens=512)
    # 获取生成的回答,并忽略输入部分的 tokens
    generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)]
    
    # 解码生成的 tokens 为文本
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    
    return response

# 使用第一个测试数据进行预测
test_message = [
    {"role": "system", "content": "你是堂吉诃德,请回答以下问题。"},
    {"role": "user", "content": "你叫什么名字?"}
]
response = predict(test_message, model, tokenizer)
print(response)  # 输出模型的预测结果

http://www.kler.cn/a/303342.html

相关文章:

  • Flutter:封装ActionSheet 操作菜单
  • 【WEB】网络传输中的信息安全 - 加密、签名、数字证书与HTTPS
  • Unity3D实现WEBGL打开Window文件对话框打开/上传文件
  • 在 Azure 100 学生订阅中新建 Ubuntu VPS 并通过 Docker 部署 pSQL 服务器
  • MySQL 排除指定时间内重复记录的解决方案
  • stack_queue的底层,模拟实现,deque和priority_queue详解
  • 计算机二级自学笔记(选择题1部分)
  • git的快速合并fast-forward merge详解
  • 机器学习和深度学习存在显著区别
  • LeetCode 热题 100 回顾11
  • 【系统架构设计师】ATAM(Architecture Tradeoff Analysis Method)
  • 【免费刷题】实验室安全第一知识题库分享
  • 简单了解深度学习
  • postcss-pxtorem实现页面自适应
  • python爬虫--实用又便捷的第三方模块--requests实战
  • 架构师知识梳理(七):软件工程-测试
  • 【智路】智路OS Perception Pipeline
  • 文件批量添加水印和密码合并单元格完整版
  • Python基础语法(2)
  • 【运维监控】Prometheus+grafana监控spring boot 3运行情况
  • 实现快速产出的短视频剪辑工具
  • Object.entries()
  • 力扣之1783.大满贯数量
  • zabbix之钉钉告警
  • SpringMVC与SpringBoot的区别
  • Docker续9:使用docker-compose部署nmt项目,在haproxy中代理mysql负载均衡