当前位置: 首页 > article >正文

微调Qwen2.5-0.5B记录

本文基座模型为Qwen2.5-1.5B

0. 加载模型和数据集

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")

# 也可以用hf的接口
with open('your_path', 'r') as f:
    data = json.load(f)

dataset = datasets.Dataset.from_list(data)

1. 数据

在这里插入图片描述
数据是西安交通大学教务处公开的信息(需要登录访问的通知和文件不能爬哈)。通过scrapy递归爬取\info下的内容,提取出标题和文档。

这里尝试过Newspaper3k等自动提取内容的工具库,但是发现效果都不如自己去识别HTML里的内容标签实在。

数据已上传到huggingface,xjtu-info

数据处理

在预训练时,数据长度是不一样的,比较重要的一点就是,我们需要把文本拼接起来,然后阶段成一样的长度,才能通过batch输入进行训练。

具体的流程就是:tokenize->group

首先把所有文本经过tokenizer,得到对应的嵌入,这时候就得到了N个长度不一的inputs_idsattention_mas。然后我们把每一个seq拼接起来,根据block_size拆分,得到新的数据集中每一条输出长度都相同。

attention_mask在这里其实没什么用,因为在预训练时,所有的文本输入都是有内容的,一般在指令微调时会使用attention_mask,因为在填充到max_seq_len的时候,有些数据上的长度不足,在末尾会填充[PAD],因此需要attention_mask

# step 1: merge title and content
def merge_data(data):
    merged = f"{tokenizer.cls_token}标题:{data['title']}\n\n内容:{data['content']}{tokenizer.pad_token}"
    return {'text': merged}

dataset = dataset.map(merge_data)

# step 2: tokenize
def tokenizer_function(examples):
    return tokenizer(examples['text'], add_special_tokens=True, truncation=False)

tokenized_datasets = dataset.map(tokenizer_function, batched=True, remove_columns=['text', 'title', 'content'])

# step 3: group texts
block_size = 512
def group_texts(examples):
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples["input_ids"])
    total_length = (total_length // block_size) * block_size
    result = {k: [t[i:i+block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items()}
    result["labels"] = result["input_ids"].copy()
    return result

final_dataset = tokenized_datasets.map(group_texts, batched=True)

2. 训练

这里使用Lora进行训练,通过LoraConfig,确认微调的任务,以及需要加入Lora的模块,一般都是q,k,v,o,gate,down,uplora的秩设置为8,如果模型大一点可以设置成16,32。alpha是lora模块的权重。

然后就是训练时的设置,训练步数,每个设备的batch数,学习率等等,这里不一一介绍。

最后只需要调用Trainer,输入我们的模型和配置以及数据集即可开始训练。

# step 4: define training arguments
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj",
                    "v_proj",
                    "k_proj",
                    "o_proj",
                    "gate_proj",
                    "down_proj",
                    "up_proj"
                    ],
    r=8,
    lora_alpha=16,
    lora_dropout=0.05
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()


args=TrainingArguments(
    output_dir="./results",
    num_train_epochs=2,
    per_device_train_batch_size=2,
    warmup_steps=100,
    weight_decay=0.01,
    learning_rate=5e-5,
    logging_steps=10,
    save_strategy='steps',
    save_steps=10,
    save_total_limit=2,
    )

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=final_dataset,
)

trainer.train()

Hugging facetransformers还提供了很多方法,例如BitandBytes量化加载,或者混合精度训练,这里没有一一尝试,有兴趣的可以自己去了解,使用也很简单,transformers都帮我们封装好了。

3. 分布式训练

我们可以使用Accelerate来分布式以及加速训练。使用Trainer这种方法训练是完美兼容Accelerate的,你甚至不用在训练代码里导入Accelerate

只需要在命令行里输入accelerate config,然后配置对应的选项,这里只有单机多卡,于是选择了多卡,2卡,卡1和卡2。你还可以选择是否使用DeepSpeedMegatron。然后选择混合精度训练,fp16还是bf16。
在这里插入图片描述

4. 训练结果

预训练前:模型不知道学校具体有哪些教室,也不知道考场在哪。
在这里插入图片描述

训练后:模型知道了考试时间以及地点,虽然不是很准确,但是掌握了大致的知识。
在这里插入图片描述

模型照猫画虎给出了回答,并且模仿了教务的口吻,但是回答的内容其实很不准,下一步通过RAG,让模型可以找出准确的内容,给出正确的答案。


http://www.kler.cn/a/509675.html

相关文章:

  • 深度学习项目--基于LSTM的火灾预测研究(pytorch实现)
  • 如何在 Google Cloud Shell 中使用 Visual Studio Code (VS Code)?
  • vue 学习笔记 - 创建第一个项目 idea
  • OpenCV相机标定与3D重建(60)用于立体校正的函数stereoRectify()的使用
  • mac配置 iTerm2 使用lrzsz与服务器传输文件
  • 【Idea启动项目报错NegativeArraySizeException】
  • 西门子PLC读取梅安森烟雾传感器数据
  • 5. 使用springboot做一个音乐播放器软件项目【业务逻辑开发】
  • 分布式理解
  • SiamCAR(2019CVPR):用于视觉跟踪的Siamese全卷积分类和回归网络
  • app版本控制java后端接口版本管理
  • Spring Boot 中使用 ShardingSphere-Proxy
  • SpringBoot 项目中配置日志系统文件 logback-spring.xml 原理和用法介绍
  • 数字化的三大战场与开源AI智能名片2+1链动模式S2B2C商城小程序源码的应用探索
  • javaEE安全开发 SQL预编译 Filter过滤器 Listener 监听器 访问控制
  • Delete `␍`eslintprettier/prettier
  • 【Linux】14.Linux进程概念(3)
  • 一个好用的vue+node后台管理系统
  • -bash: /java: cannot execute binary file
  • JS宏进阶:正则表达式介绍
  • One Prompt is not Enough: Automated Construction of a Mixture-of-Expert Prompts
  • Vue 动态生成响应式表格:优化桌面与移动端展示效果
  • MySQL程序之:使用DNS SRV记录连接到服务器
  • SAP租赁资产解决方案【物业、地产、酒店、汽车租赁行业】
  • GCC支持Objective C的故事?Objective-C?GCC只能编译C语言吗?Objective-C 1.0和2.0有什么区别?
  • PenGymy论文阅读