当前位置: 首页 > article >正文

LoRA(Low-Rank Adaptation)的工作机制 - 使用 LoRA 库来微调深度学习模型的基本步骤

LoRA(Low-Rank Adaptation)的工作机制 - 使用 LoRA 库来微调深度学习模型的基本步骤

flyfish

LoRA: Low-Rank Adaptation of Large Language Models

https://arxiv.org/abs/2106.09685

快速入门

  1. 安装 loralib 是非常简单的:

    pip install loralib
    # 或者
    # pip install git+https://github.com/microsoft/LoRA
    
  2. 选择性地适应某些层,可以通过将它们替换为 loralib 中实现的对应层。目前我们仅支持 nn.Linearnn.Embeddingnn.Conv2d。我们也支持一个 MergedLinear,用于处理单个 nn.Linear 表示多个层的情况,例如在某些注意力机制的 qkv 投影实现中。

    # ===== 之前 =====
    # layer = nn.Linear(in_features, out_features)
    
    # ===== 之后 =====
    import loralib as lora
    # 添加一对秩为 r=16 的低秩适应矩阵
    layer = lora.Linear(in_features, out_features, r=16)
    
  3. 在训练循环开始前,标记只有 LoRA 参数为可训练。

    import loralib as lora
    model = BigModel()
    # 这会将所有名称中不包含 "lora_" 字符串的参数的 `requires_grad` 设置为 `False`
    lora.mark_only_lora_as_trainable(model)
    # 训练循环
    for batch in dataloader:
       ...
    
  4. 保存检查点时,生成一个只包含 LoRA 参数的 state_dict

    # ===== 之前 =====
    # torch.save(model.state_dict(), checkpoint_path)
    # ===== 之后 =====
    torch.save(lora.lora_state_dict(model), checkpoint_path)
    
  5. 使用 load_state_dict 加载检查点时,确保设置 strict=False

    # 先加载预训练的检查点
    model.load_state_dict(torch.load('ckpt_pretrained.pt'), strict=False)
    # 然后加载 LoRA 检查点
    model.load_state_dict(torch.load('ckpt_lora.pt'), strict=False)
    
现在可以像往常一样进行训练了。

附加说明

  1. 尽管这是一个简单而有效的设置,即仅适应 Transformer 中的 qv 投影,在例子中,LoRA 可以应用于任何预训练权重的子集。鼓励您探索不同的配置,比如通过将 nn.Embedding 替换为 lora.Embedding 来适应嵌入层,和/或适应 MLP 层。很可能会发现对于不同的模型架构和任务,最佳配置各不相同。

  2. 一些 Transformer 实现使用单个 nn.Linear 来表示查询、键和值的投影矩阵。如果希望限制对各个矩阵更新的秩,那么要么将其分解成三个独立的矩阵,要么使用 lora.MergedLinear。如果您选择分解该层,请确保相应地修改检查点。

    # ===== 之前 =====
    # qkv_proj = nn.Linear(d_model, 3*d_model)
    # ===== 之后 =====
    # 分解它(记得相应地修改预训练的检查点)
    q_proj = lora.Linear(d_model, d_model, r=8)
    k_proj = nn.Linear(d_model, d_model)
    v_proj = lora.Linear(d_model, d_model, r=8)
    # 或者,使用 lora.MergedLinear(推荐)
    qkv_proj = lora.MergedLinear(d_model, 3*d_model, r=8, enable_lora=[True, False, True])
    
  3. 与 LoRA 一起训练偏置向量可能是提高任务性能的一种成本效益高的方法(如果仔细调整学习率的话)。虽然论文中没有彻底研究其效果,但尝试变得很容易。可以通过在调用 mark_only_lora_as_trainable 时传递 “all” 或 “lora_only” 给 bias= 来标记一些偏置为可训练。记住在保存检查点时使用相同的 bias= (alllora_only) 传递给 lora_state_dict

    # ===== 之前 =====
    # lora.mark_only_lora_as_trainable(model) # 不训练任何偏置向量
    # ===== 之后 =====
    # 训练我们应用 LoRA 的模块关联的所有偏置向量
    lora.mark_only_lora_as_trainable(model, bias='lora_only')
    # 或者,我们可以训练模型中的所有偏置向量,包括 LayerNorm 偏置
    lora.mark_only_lora_as_trainable(model, bias='all')
    # 当保存检查点时,使用相同的 bias= ('all' 或 'lora_only')
    torch.save(lora.lora_state_dict(model, bias='all'), checkpoint_path)
    
  4. 调用 model.eval() 会触发 LoRA 参数与相应的预训练参数合并,这消除了后续前向传递的额外延迟。再次调用 model.train() 将撤销合并。可以通过向 LoRA 层传递 merge_weights=False 来禁用此功能。


http://www.kler.cn/a/380438.html

相关文章:

  • apisix的hmac-auth认证
  • 产品初探Devops!以及AI如何赋能Devops?
  • nvidia docker, nvidia docker2, nvidia container toolkits区别
  • 【论文阅读笔记】IC-Light
  • linux命令中cp命令-rf与-a的差别
  • vue3入门教程:计算属性
  • 学习笔记:黑马程序员JavaWeb开发教程(2024.11.4)
  • 虚拟机 Ubuntu 扩容
  • Qt第三课 ----------输入类的控件属性
  • 深度学习之Dropout
  • K8S flannel网络模式对比
  • 恒创科技:如何知道一台服务器能承载多少用户?
  • 【Elasticsearch系列】更改 Elasticsearch 用户密码的详细指南
  • 【RAG多模态】多模态RAG-ColPali:使用视觉语言模型实现高效的文档检索
  • Python pyautogui库:自动化操作的强大工具
  • Redis-06 Redis面试高频问题、Redis日常开发规避问题
  • 【LLM-多模态】MM1:多模态大模型预训练的方法、分析与见解
  • mybatis 参数判断报错的问题
  • ML2001-2 机器学习/深度学习 过拟合(overfit)
  • Qt中的Model与View5: QStyledItemDelegate
  • 【含文档+源码】基于SpringBoot+Vue的新型吃住玩一体化旅游管理系统的设计与实现
  • 【格式化查看JSON文件】coco的json文件内容都在一行如何按照json格式查看
  • Hadoop生态系统主要包括哪些组件以及它们的作用
  • 探索 MarsCode:代码练习-AI助你提升编码/算法能力
  • OpenCV图像基础
  • 红队知识学习入门(3)Shodan使用6