LoRA(Low-Rank Adaptation)的工作机制 - 使用 LoRA 库来微调深度学习模型的基本步骤
LoRA(Low-Rank Adaptation)的工作机制 - 使用 LoRA 库来微调深度学习模型的基本步骤
flyfish
LoRA: Low-Rank Adaptation of Large Language Models
https://arxiv.org/abs/2106.09685
快速入门
-
安装
loralib
是非常简单的:pip install loralib # 或者 # pip install git+https://github.com/microsoft/LoRA
-
选择性地适应某些层,可以通过将它们替换为
loralib
中实现的对应层。目前我们仅支持nn.Linear
、nn.Embedding
和nn.Conv2d
。我们也支持一个MergedLinear
,用于处理单个nn.Linear
表示多个层的情况,例如在某些注意力机制的qkv
投影实现中。# ===== 之前 ===== # layer = nn.Linear(in_features, out_features) # ===== 之后 ===== import loralib as lora # 添加一对秩为 r=16 的低秩适应矩阵 layer = lora.Linear(in_features, out_features, r=16)
-
在训练循环开始前,标记只有 LoRA 参数为可训练。
import loralib as lora model = BigModel() # 这会将所有名称中不包含 "lora_" 字符串的参数的 `requires_grad` 设置为 `False` lora.mark_only_lora_as_trainable(model) # 训练循环 for batch in dataloader: ...
-
保存检查点时,生成一个只包含 LoRA 参数的
state_dict
。# ===== 之前 ===== # torch.save(model.state_dict(), checkpoint_path) # ===== 之后 ===== torch.save(lora.lora_state_dict(model), checkpoint_path)
-
使用
load_state_dict
加载检查点时,确保设置strict=False
。# 先加载预训练的检查点 model.load_state_dict(torch.load('ckpt_pretrained.pt'), strict=False) # 然后加载 LoRA 检查点 model.load_state_dict(torch.load('ckpt_lora.pt'), strict=False)
现在可以像往常一样进行训练了。
附加说明
-
尽管这是一个简单而有效的设置,即仅适应 Transformer 中的
q
和v
投影,在例子中,LoRA 可以应用于任何预训练权重的子集。鼓励您探索不同的配置,比如通过将nn.Embedding
替换为lora.Embedding
来适应嵌入层,和/或适应 MLP 层。很可能会发现对于不同的模型架构和任务,最佳配置各不相同。 -
一些 Transformer 实现使用单个
nn.Linear
来表示查询、键和值的投影矩阵。如果希望限制对各个矩阵更新的秩,那么要么将其分解成三个独立的矩阵,要么使用lora.MergedLinear
。如果您选择分解该层,请确保相应地修改检查点。# ===== 之前 ===== # qkv_proj = nn.Linear(d_model, 3*d_model) # ===== 之后 ===== # 分解它(记得相应地修改预训练的检查点) q_proj = lora.Linear(d_model, d_model, r=8) k_proj = nn.Linear(d_model, d_model) v_proj = lora.Linear(d_model, d_model, r=8) # 或者,使用 lora.MergedLinear(推荐) qkv_proj = lora.MergedLinear(d_model, 3*d_model, r=8, enable_lora=[True, False, True])
-
与 LoRA 一起训练偏置向量可能是提高任务性能的一种成本效益高的方法(如果仔细调整学习率的话)。虽然论文中没有彻底研究其效果,但尝试变得很容易。可以通过在调用
mark_only_lora_as_trainable
时传递 “all” 或 “lora_only” 给bias=
来标记一些偏置为可训练。记住在保存检查点时使用相同的bias=
(all
或lora_only
) 传递给lora_state_dict
。# ===== 之前 ===== # lora.mark_only_lora_as_trainable(model) # 不训练任何偏置向量 # ===== 之后 ===== # 训练我们应用 LoRA 的模块关联的所有偏置向量 lora.mark_only_lora_as_trainable(model, bias='lora_only') # 或者,我们可以训练模型中的所有偏置向量,包括 LayerNorm 偏置 lora.mark_only_lora_as_trainable(model, bias='all') # 当保存检查点时,使用相同的 bias= ('all' 或 'lora_only') torch.save(lora.lora_state_dict(model, bias='all'), checkpoint_path)
-
调用
model.eval()
会触发 LoRA 参数与相应的预训练参数合并,这消除了后续前向传递的额外延迟。再次调用model.train()
将撤销合并。可以通过向 LoRA 层传递merge_weights=False
来禁用此功能。