当前位置: 首页 > article >正文

【LLM学习笔记】第三篇:模型微调及LoRA介绍(附PyTorch实例)

文章目录

      • 1. 模型微调
        • 1.1 为什么要进行模型微调?
        • 1.2 模型微调的分类
      • 2. 高效模型微调LoRA
        • 2.1 LoRA概述
        • 2.2 LoRA的机理
      • 3. LoRA的变体
        • 3.1 分布式LoRA (Distributed LoRA)
        • 3.2 动态LoRA (Dynamic LoRA)
        • 3.3 多任务LoRA (Multi-Task LoRA)
        • 3.4 层级LoRA (Hierarchical LoRA)
        • 3.5 自适应LoRA (Adaptive LoRA)
        • 3.6 集成LoRA (Ensemble LoRA)
      • 4. `peft`库
      • 5.结论

1. 模型微调

1.1 为什么要进行模型微调?

模型微调(Fine-Tuning)是一种在深度学习中广泛应用的技术,旨在通过在特定任务上对预训练模型进行进一步训练,以提高模型在该任务上的性能。这一过程不仅仅是简单的训练调整,而是通过利用预训练模型已有的知识和特征表示,使其更好地适应新的任务需求。
在这里插入图片描述

预训练模型通常是在大规模数据集上训练的,这些模型已经学到了丰富的、通用的特征表示。这些特征对于许多任务都是非常有价值的。例如,在自然语言处理任务中,预训练的语言模型(如BERT、GPT等)已经学会了词义、语法结构和语境信息;在计算机视觉任务中,预训练的卷积神经网络(如ResNet、VGG等)已经学会了边缘、纹理、形状等低级特征,以及物体、场景等高级特征。通过微调,可以将这些学到的特征迁移到新的任务中,从而提高模型的初始性能。

尽管预训练模型具有强大的通用特征提取能力,但它们可能不完全适合特定任务的数据分布和特点。通过微调,可以针对特定任务的数据进行进一步训练,使模型更好地适应任务需求。例如,在医学影像分类任务中,预训练模型可能已经学会了识别一般的图像特征,但通过微调,可以使模型更专注于识别特定的医学影像特征,从而提高分类准确率。

1.2 模型微调的分类

模型微调分为全局微调(全量微调)和局部微调:

  • 全局微调

    • 定义:全局微调是指对整个预训练模型的所有参数进行微调。这种方法适用于数据量较大且计算资源充足的情况。
    • 优点:可以充分利用所有参数的调整空间,使模型更好地适应特定任务。
    • 缺点:计算成本较高,容易导致过拟合,尤其是在数据量较少的情况下。
  • 局部微调

    • 定义:局部微调是指只对预训练模型的部分参数进行微调,通常是最后几层或特定的模块。这种方法适用于数据量较小或计算资源有限的情况。
    • 优点:计算成本较低,可以有效避免过拟合。
    • 缺点:可能无法充分利用预训练模型的全部潜力。

本文要介绍的LoRA即属于一种高效的局部微调方法。

2. 高效模型微调LoRA

ji

2.1 LoRA概述

LoRA(Low-Rank Adaptation)是一种高效的模型微调技术,旨在通过在预训练模型中插入低秩矩阵来减少微调所需的参数量,从而提高训练效率并避免过拟合。LoRA 的核心思想是在保持预训练模型大部分参数不变的情况下,通过添加低秩矩阵来模拟参数的变化量,从而实现对特定任务的适应。

关于“低秩近似”的概念,我在此前的文章有介绍过深度学习中的常用线性代数知识汇总——第一篇:基础概念、秩、奇异值

研究表明:语言模型针对特定任务微调之后,权重矩阵通常具有很低的本征秩(Intrinsic Rank),参数更新量即便投影到较小的子空间中,也不会影响学习的有效性。因此,提出固定预训练模型参数不变,在原权重矩阵旁路添加低秩矩阵的乘积作为可训练参数,用以模拟参数的变化量。

2.2 LoRA的机理

具体来说,假设预训练权重为 W 0 ∈ R d × k W_0 \in \mathbb{R}^{d \times k} W0Rd×k,可训练参数为 Δ W = B A \Delta W = BA ΔW=BA,其中 A ∈ R d × r A \in \mathbb{R}^{d \times r} ARd×r B ∈ R r × k B \in \mathbb{R}^{r \times k} BRr×k。初始化时,矩阵A通过高斯函数初始化,矩阵B为零初始化,使得训练开始之前旁路对原模型不造成影响,即参数变化量为0。对于该权重的输入 x ∈ R d x \in \mathbb{R}^d xRd 来说,输出如下:

h = W 0 x + Δ W x = W 0 x + B A x h = W_0 x + \Delta W x = W_0 x + BA x h=W0x+ΔWx=W0x+BAx

在这里插入图片描述

3. LoRA的变体

3.1 分布式LoRA (Distributed LoRA)

分布式LoRA通过将低秩矩阵分布在多个层或模块中,进一步提高参数效率和模型性能。

  • 多层低秩矩阵:在多个层中插入低秩矩阵,每个层的低秩矩阵可以有不同的秩 r i r_i ri
  • 参数更新:每个层的权重矩阵 W i W_i Wi 更新为 W i + B i A i W_i + B_i A_i Wi+BiAi
  • 前向传播:每个层的输出 h i h_i hi 计算为 h i = W i x + B i A i x h_i = W_i x + B_i A_i x hi=Wix+BiAix
3.2 动态LoRA (Dynamic LoRA)

动态LoRA通过在训练过程中动态调整低秩矩阵的秩 r r r,以适应不同的训练阶段和任务需求。

  • 动态秩调整:在训练初期使用较小的秩 r r r,随着训练的进行逐渐增加秩 r r r
  • 参数更新:根据当前的秩 r r r动态计算$ \Delta W$。
  • 前向传播 h = W 0 x + B ( r ) A ( r ) x h = W_0 x + B(r) A(r) x h=W0x+B(r)A(r)x,其中 B ( r ) B(r) B(r) A ( r ) A(r) A(r)是根据当前秩 r r r 动态调整的矩阵。
3.3 多任务LoRA (Multi-Task LoRA)

多任务LoRA通过在多任务学习中共享低秩矩阵,提高模型的泛化能力和资源利用率。

  • 共享低秩矩阵:多个任务共享同一个低秩矩阵 A A A B B B,但每个任务可以有自己的权重矩阵 W 0 W_0 W0
  • 参数更新:每个任务的权重矩阵 W 0 , i W_{0,i} W0,i 更新为 W 0 , i + B A W_{0,i} + BA W0,i+BA
  • 前向传播:每个任务的输出 h i h_i hi 计算为 h i = W 0 , i x + B A x h_i = W_{0,i} x + BA x hi=W0,ix+BAx
3.4 层级LoRA (Hierarchical LoRA)

层级LoRA通过在不同层级的模块中插入低秩矩阵,实现更细粒度的参数控制和优化。

  • 多层级低秩矩阵:在不同层级的模块中插入低秩矩阵,每个模块的低秩矩阵可以有不同的秩 r i r_i ri
  • 参数更新:每个模块的权重矩阵 W i W_i Wi 更新为 W i + B i A i W_i + B_i A_i Wi+BiAi
  • 前向传播:每个模块的输出 h i h_i hi 计算为 h i = W i x + B i A i x h_i = W_i x + B_i A_i x hi=Wix+BiAix
3.5 自适应LoRA (Adaptive LoRA)

自适应LoRA通过引入自适应机制,使低秩矩阵根据输入数据动态调整,从而提高模型的适应性和性能。

  • 自适应低秩矩阵:低秩矩阵 A A A B B B根据输入数据 x x x动态调整。
  • 参数更新 Δ W = B ( x ) A ( x ) \Delta W = B(x) A(x) ΔW=B(x)A(x)
  • 前向传播 h = W 0 x + B ( x ) A ( x ) x h = W_0 x + B(x) A(x) x h=W0x+B(x)A(x)x
3.6 集成LoRA (Ensemble LoRA)

集成LoRA通过组合多个低秩矩阵,提高模型的表达能力和鲁棒性。

  • 多个低秩矩阵:使用多个低秩矩阵 A i A_i Ai B i B_i Bi,每个矩阵可以有不同的秩 r i r_i ri
  • 参数更新:权重矩阵 W W W更新为 W = W 0 + ∑ i B i A i W = W_0 + \sum_i B_i A_i W=W0+iBiAi
  • 前向传播 h = W 0 x + ∑ i B i A i x h = W_0 x + \sum_i B_i A_i x h=W0x+iBiAix

LoRA的这些变体在不同的应用场景中各有优势,可以根据具体任务的需求选择合适的变体。基本LoRA适用于大多数微调任务,而分布式LoRA、动态LoRA、多任务LoRA、层级LoRA、自适应LoRA和集成LoRA则在特定场景下提供了更高的灵活性和性能。通过这些变体,LoRA技术可以更好地适应各种复杂任务,提高模型的性能和效率。

4. peft

peft(Parameter-Efficient Fine-Tuning)是一个用于参数高效微调的库,支持多种方法,包括LoRA。下面是一个基于PyTorch和peft库的LoRA微调示例代码。这个示例将展示如何使用Hugging Face的transformers库和peft库来微调一个预训练的BERT模型。

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import LoraConfig, get_peft_model

# 加载预训练模型和分词器
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# 定义LoRA配置
lora_config = LoraConfig(
    r=8,  # 低秩矩阵的秩
    lora_alpha=16,  # LoRA的缩放因子
    target_modules=["query", "value"]  # 需要应用LoRA的模块
)

# 将LoRA配置应用到模型
model = get_peft_model(model, lora_config)

# 准备一些简单的数据
train_texts = ["I love this movie.", "This movie is terrible."]
train_labels = [1, 0]  # 1 表示正类,0 表示负类

# 对数据进行编码
train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=128, return_tensors='pt')

# 定义损失函数和优化器
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# 训练模型
model.train()
for epoch in range(3):  # 训练3个epoch
    optimizer.zero_grad()
    outputs = model(input_ids=train_encodings['input_ids'], attention_mask=train_encodings['attention_mask'])
    loss = loss_fn(outputs.logits, torch.tensor(train_labels))
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

# 保存模型
model.save_pretrained('./simple_lora_model')
tokenizer.save_pretrained('./simple_lora_model')

print("Training complete!")

5.结论

本文介绍了模型微调的目的及方法,详细说明了局部微调中的LoRA方法。


http://www.kler.cn/a/393505.html

相关文章:

  • 代码随想录第二十一天| 669. 修剪二叉搜索树 108.将有序数组转换为二叉搜索树 538.把二叉搜索树转换为累加树
  • 贪心算法day03(最长递增序列问题)
  • 远离生成式AI大乱斗,SAS公司揭示亚太区千亿AI市场蓝图
  • Node.js笔记
  • Autosar CP 基于CAN的时间同步规范导读
  • 【算法】——二分查找合集
  • 有什么好用的 WebSocket 调试工具吗?
  • Nuxt 版本 2 和 版本 3 的区别
  • LeetCode 216-组合总数Ⅲ
  • 【Qualcomm】Ubuntu20.04安装QualcommPackageManager3
  • HTML 基础架构:理解网页的骨架
  • 【Git】Git Clone 指定自定义文件夹名称:详尽指南
  • 多态之魂:C++中的优雅与力量
  • Leetcode 最后一个单词的长度
  • Clickhouse集群新建用户、授权以及remote权限问题
  • 怎么用家用电脑做服务器(web服务器、ftp服务器、小程序服务器,云电脑)
  • sql专题 之 三大范式
  • 标准C++ 字符串
  • 操作系统lab3-进程同步实验
  • uniapp使用scroll-view下拉刷新与上滑加载
  • idea正则表达式-正则替换示例-2024.11笔记
  • 2024中国游戏出海情况
  • attention 注意力机制 学习笔记-GPT2
  • python---基础语法
  • 【HarmonyOS】Install Failed: error: failed to install bundle.code:9568289
  • CCF认证-202403-04 | 十滴水