当前位置: 首页 > article >正文

【Finetune】(五)、transformers之LORA微调

文章目录

  • 0、LORA基本原理
  • 1、LORA微调实战
    • 1.1、导包
    • 1.2、加载数据集
    • 1.3、数据预处理
    • 1.4、创建模型
    • 1.5、LORA微调
      • 1.5.1、配置文件
      • 1.5.2、创建模型
    • 1.6、配置训练参数
    • 1.7、创建训练器
    • 1.8、模型训练
    • 1.9、模型推理
  • 2、模型合并
    • 2.1、导包
    • 2.2、加载基础模型
    • 2.3、加载LORA模型
    • 2.4、模型合并
    • 2.5、模型保存

0、LORA基本原理

  • 预训练模型中存在一个极小的内在维度,这个内在维度是发挥核心作用的地方

  • 再继续训练过程中,权重的更新依然也有如此的特点即存在一个内在维度(内在秩)

  • 权重更新,w’ = w + Δ \Delta Δw

  • 因此,可以通过矩阵分解的方式,将原本更新的大矩阵转变为两个小的矩阵相乘

  • 权重更新:w’ = w + Δ \Delta Δw = w + BA

  • 具体做法,即再矩阵计算中增加一个旁系分支,旁系分支有两个低秩矩阵组成。

 训练时,输入分别与原始权重和两个低秩矩阵进行计算,共同得到最终的结果,优化则优化A和B
 训练完成后,可以将两个低秩矩阵与原始模型中的权重进行合并,合并后的模型与原始模型无异,避免推理期间prompt系列方法带来的额外计算量
在这里插入图片描述

1、LORA微调实战

1.1、导包

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer

1.2、加载数据集

ds = Dataset.load_from_disk("../Data/alpaca_data_zh/")
ds

1.3、数据预处理

tokenizer = AutoTokenizer.from_pretrained("../Model/bloom-389m-zh")
tokenizer
def process_func(example):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    response = tokenizer(example["output"] + tokenizer.eos_token)
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)
tokenized_ds

1.4、创建模型

model = AutoModelForCausalLM.from_pretrained("../Model/bloom-389m-zh", low_cpu_mem_usage=True)

1.5、LORA微调

1.5.1、配置文件

from peft import LoraConfig, get_peft_model, TaskType
from peft import LoraConfig, TaskType, get_peft_model
'''
config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, 
    target_modules=".*\.1.*query_key_value", 
    modules_to_save=["word_embeddings"])#除了value以外还要训练哪部分
'''
config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, 
    r= 8,
    target_modules=['query_key_value','dense_4h_to_h'],
    lora_alpha=8,
    lora_dropout=0,
   # modules_to_save=["word_embeddings"],
    )
config

1.5.2、创建模型

model = get_peft_model(model, config)
model.print_trainable_parameters()

1.6、配置训练参数

args = TrainingArguments(
    output_dir="./lora",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    logging_steps=10,
    num_train_epochs=1
)

1.7、创建训练器

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)

1.8、模型训练

trainer.train()

1.9、模型推理

model = model.cuda()
ipt = tokenizer("Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(model.device)
tokenizer.decode(model.generate(**ipt, max_length=128, do_sample=True)[0], skip_special_tokens=True)

2、模型合并

2.1、导包

from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import PeftModel

2.2、加载基础模型

model = AutoModelForCausalLM.from_pretrained("../Model/bloom-389m-zh", low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained("../Model/bloom-389m-zh")

2.3、加载LORA模型

p_model = PeftModel.from_pretrained(model, model_id="./chatbot/checkpoint-500/")
p_model
ipt = tokenizer("Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: ", return_tensors="pt")
tokenizer.decode(p_model.generate(**ipt, do_sample=False)[0], skip_special_tokens=True)

2.4、模型合并

merge_model = p_model.merge_and_unload()
merge_model
ipt = tokenizer("Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: ", return_tensors="pt")
tokenizer.decode(merge_model.generate(**ipt, do_sample=False)[0], skip_special_tokens=True)

2.5、模型保存

merge_model.save_pretrained("./chatbot/merge_model")

http://www.kler.cn/a/318116.html

相关文章:

  • 力扣104 : 二叉树最大深度
  • 24.11.13 Javascript3
  • POI实现根据PPTX模板渲染PPT
  • 【mySql 语句使用】
  • 闯关leetcode——3174. Clear Digits
  • gpu-V100显卡相关知识
  • JetLinks物联网学习(前后端项目启动)
  • 学习编程利器《西蒙学习法》
  • 边学英语边学 Java|Synchronization in java
  • vite配置将es6打包成es5
  • Java-ArrayList和LinkedList区别
  • 速通LLaMA3:《The Llama 3 Herd of Models》全文解读
  • Ubuntu中常用的操作指令
  • vsomeip客户端/服务端大致运行流程
  • STL之vector篇(上)还在为学习vector而感到烦恼吗?每次做算法题都要回忆很久,不如来看看我的文章,精简又易懂,帮你快速掌握vector的相关用法
  • kafka 生产者拦截器
  • yum 安装gcc 时,提示glibc错误依赖
  • LeetCode题练习与总结:二叉树的最近公共祖先--236
  • 读书笔记——DDIA-v2 设计数据密集型应用(第二版)
  • 卷积神经网络——手写数字识别
  • PX4固定翼控制器详解(五)——L1、NPFG控制器
  • 347. 前 K 个高频元素
  • 【2024W36】肖恩技术周刊(第 14 期):什么是完美副业?
  • 大模型培训讲师叶梓:Llama Factory 微调模型实战分享提纲
  • 用Swift实现验证回文字符串
  • 空栈压数 - 华为OD统一考试(E卷)