当前位置: 首页 > article >正文

(NIPS-2024)PISSA:大型语言模型的主奇异值和奇异向量适配

PISSA:大型语言模型的主奇异值和奇异向量适配

Paper是北京大学发表在NIPS 2024的工作

Paper Title:PISSA: PRINCIPAL SINGULAR VALUES AND SINGULAR VECTORS ADAPTATION OF LARGE LANGUAGE MODELS

Code地址

ABSTRACT

为了参数高效地微调(PEFT)大规模语言模型(LLMs),低秩适配(LoRA)方法通过两个矩阵的乘积来近似模型的变化 Δ W ∈ R m × n \Delta W \in \mathbb{R}^{m \times n} ΔWRm×n:即 A ∈ R m × r A \in \mathbb{R}^{m \times r} ARm×r B ∈ R r × n B \in \mathbb{R}^{r \times n} BRr×n,其中 r ≪ min ⁡ ( m , n ) r \ll \min (m, n) rmin(m,n) A A A 初始化为高斯噪声, B B B 初始化为零。LoRA 冻结原始模型权重 W W W,仅更新 “噪声和零” 适配器,这可能导致收敛速度较慢。为了解决这一限制,我们引入了主奇异值和奇异向量适配(PiSSA)。PiSSA 与 LoRA 具有相同的架构,但将适配器矩阵 A A A B B B 初始化为原始矩阵 W W W 的主成分,并将剩余部分置于残差矩阵 W res ∈ R m × n W^{\text{res}} \in \mathbb{R}^{m \times n} WresRm×n 中,该矩阵在微调过程中被冻结。与 LoRA 相比,PiSSA 更新了主成分而冻结了“残差”部分,从而实现了更快的收敛速度和更高的性能。PiSSA 和 LoRA 在 12 种不同模型(规模从 184M 到 70B)、涵盖 5 项 NLG 任务和 8 项 NLU 任务的对比实验中,表明 PiSSA 在相同实验设置下始终优于 LoRA。在 GSM8K 基准上,使用 PiSSA 微调的 Mistral-7B 模型的准确率达到 72.86 % 72.86\% 72.86%,比 LoRA 的 67.7 % 67.7\% 67.7% 提高了 5.16 % 5.16\% 5.16%。由于具有相同的架构,PiSSA 也兼容量化,从而进一步降低微调所需的内存。与 QLoRA 相比,QPiSSA(PiSSA 使用 4-bit 量化)在初始阶段表现出更小的量化误差。在 GSM8K 上微调 LLaMA-3-70B 时,QPiSSA 的准确率达到 86.05 % 86.05\% 86.05%,超过了 QLoRA 的 81.73 % 81.73\% 81.73%。借助快速 SVD 技术,PiSSA 的初始化仅需几秒钟,为从 LoRA 过渡到 PiSSA 提供了几乎可以忽略的成本。

1 Introduction

微调大型语言模型(LLMs)是一种非常有效的技术,可以增强模型在各种任务中的能力 [ 1 , 2 , 3 , 4 ] [1,2,3,4] [1,2,3,4],确保模型遵循指令 [ 5 , 6 , 7 ] [5,6,7] [5,6,7],并赋予模型理想的行为,同时消除不良行为 [ 8 , 9 ] [8,9] [8,9]。然而,对于非常大的模型,微调过程伴随着高昂的成本。例如,常规 16-bit 微调 LLaMA 的 65B 参数模型需要超过 780 GB 的 GPU 内存 [ 10 ] [10] [10],而训练 GPT-3 175B 的 VRAM 消耗高达 1.2 TB [ 11 ] [11] [11]。因此,各种参数高效微调(PEFT)方法 [ 12 , 13 ] [12,13] [12,13] 被提出,以减少微调所需的参数量和内存使用量。由于能够在不增加推理延迟的情况下保持完整微调的性能,低秩适配(Low-Rank Adaptation, LoRA) [ 11 ] [11] [11] 已成为一种流行的 PEFT 方法。

LoRA [ 11 ] [11] [11] 假设微调过程中参数矩阵的修改具有低秩特性。如图 1b 所示,对于一个预训练权重矩阵 W ∈ R m × n W \in \mathbb{R}^{m \times n} WRm×n,LoRA 将更新替换为低秩分解 Δ W = A B \Delta W=A B ΔW=AB,其中 A ∈ R m × r A \in \mathbb{R}^{m \times r} ARm×r B ∈ R r × n B \in \mathbb{R}^{r \times n} BRr×n,且秩 r ≪ min ⁡ ( m , n ) r \ll \min (m, n) rmin(m,n)。对于 Y = X W Y=X W Y=XW,修改后的前向传播如下:

Y = X ( W + Δ W ) = X ( W + A B ) Y=X(W+\Delta W)=X(W+A B) Y=X(W+ΔW)=X(W+AB)

图1

图 1:完全微调、使用 LoRA 进行训练和 PiSSA 之间的比较。在此可视化中,蓝色模块表示模型中在训练期间参数被冻结的部分,而橙色模块表示需要更新的组件。QLoRA 将 LoRA 中的预训练矩阵量化为 4 位,而 QPiSSA 将 PiSSA 中的残差矩阵量化。

图2

图 2:我们说明了 PiSSA 的两个关键优势:收敛速度更快更好,以及减少量化误差。 在左图中,我们使用一个示例来展示 PiSSA 更快的收敛速度,其中我们首先训练一个两层 MLP 对 MNIST 的奇数进行分类,然后在偶数上对模型进行微调。PiSSA 更快地找到了正确的方向,并且在相同步数下实现了更低的损失。在右图中,我们展示了与直接量化基础模型相比,PiSSA 的量化误差减少情况。与基线相比,PiSSA 可以减少更多的量化误差。

其中, X ∈ R b × m , Y ∈ R b × n X \in \mathbb{R}^{b \times m}, Y \in \mathbb{R}^{b \times n} XRb×m,YRb×n b b b 表示输入数据的批量大小。 A A A 使用随机高斯初始化, B B B 初始化为零,这使得在训练初期 A B = 0 A B=0 AB=0,从而插入的适配器不会影响模型的初始输出。LoRA 避免了对原始矩阵 W W W 计算梯度或维护优化器状态,而是优化注入的显著更小的低秩矩阵 A , B A, B A,B。因此,LoRA 可将可训练参数量减少 10,000 倍,GPU 内存需求减少 3 倍 [ 11 ] [11] [11]。此外,LoRA 通常能够达到与完整参数微调相当或更优的性能,这表明微调完整参数的“部分”已足以应对下游任务。通过对预训练矩阵 W W W 的量化,LoRA 还可将平均内存需求减少 16 倍 [ 10 ] [10] [10]。同时,适配器仍然可以使用更高精度的权重,因此量化通常不会显著降低 LoRA 的性能。

根据公式 (1), A A A B B B 的梯度分别为:

∂ L ∂ A = X T ( ∂ L ∂ Y ) B T , ∂ L ∂ B = A T X T ( ∂ L ∂ Y ) \frac{\partial L}{\partial A}=X^T\left(\frac{\partial L}{\partial Y}\right) B^T, \quad \frac{\partial L}{\partial B}=A^T X^T\left(\frac{\partial L}{\partial Y}\right) AL=XT(YL)BT,BL=ATXT(YL)

与完整微调相比,LoRA 初始阶段对相同的输入 X X X 不改变输出 Y Y Y,因此梯度的大小主要由 A A A B B B 的值决定。由于 LoRA 中 A A A B B B 初始化为高斯噪声和零,其梯度可能非常小,从而导致微调过程的收敛速度较慢。我们通过实验观察到这一现象,LoRA 通常在初始点附近浪费大量时间。


主奇异值与奇异向量适配器 (PiSSA)

我们的 PiSSA 方法区别于 LoRA 及其后继者,不是通过近似 Δ W \Delta W ΔW,而是直接针对 W W W 进行优化。我们对矩阵 W W W 进行奇异值分解(SVD),并根据奇异值的大小将 W W W 分为两部分:

  • 主低秩矩阵 W pri W^{\text{pri}} Wpri:包含最大的几个奇异值;
  • 残差矩阵 W res W^{\text{res}} Wres:包含其余较小的奇异值(数量较多,可能表现为长尾分布)。

主矩阵 W pri W^{\text{pri}} Wpri 可表示为 A ∈ R m × r A \in \mathbb{R}^{m \times r} ARm×r B ∈ R r × n B \in \mathbb{R}^{r \times n} BRr×n 的乘积,其中 r ≪ min ⁡ ( m , n ) r \ll \min (m, n) rmin(m,n)。如图 1c 所示, A A A B B B 基于主奇异值和奇异向量进行初始化,并在训练中可更新。相反,残差矩阵 W res W^{\text{res}} Wres 由残差奇异值和奇异向量的乘积初始化,并在微调过程中被冻结。

由于主奇异向量表示矩阵 W W W 具有最大拉伸或影响的方向,直接微调这些主成分使得 PiSSA 能够更快更好地拟合训练数据(如图 2a 所示)。此外,在我们的实验中,PiSSA 的损失和梯度范数曲线往往与完整参数微调表现出类似的趋势(见图 4),这表明微调主成分在某种程度上与完整矩阵微调的行为一致。


兼容量化

由于主成分 W pri W^{\text{pri}} Wpri 在适配器中被保留并具有完整精度,PiSSA 的另一个优势是当对冻结部分 W res W^{\text{res}} Wres 进行量化时,与 QLoRA(对整个 W W W 进行量化)相比,可以显著降低量化误差(如图 2b 所示)。因此,PiSSA 完美兼容量化,成为 LoRA 的即插即用替代方案。

2 Related Works

具有数十亿个参数的大型语言模型 (LLM) 具有巨大的复杂性和计算需求,这给将其用于特定的下游任务带来了巨大的障碍。参数高效微调 (PEFT) [12, 13] 是一种引人注目的解决方案,它通过最小化微调参数和内存需求,同时实现与完全微调相当的性能。PEFT 包含部分微调 [14, 15, 16, 17, 18, 19, 20, 21]、软提示微调 [22, 23, 24, 25, 26, 27, 28]、非线性适配器微调 [29, 30, 31, 32, 33, 34, 35, 36, 37, 38] 和基于低秩适配器的微调 [39, 40, 11, 41] 等策略。

LoRA [11] 将可训练的适配器注入线性层。经过微调后,这些调整可以重新参数化到标准模型结构中,因此获得了广泛采用,因为它们能够在保持模型原始架构的同时实现高效的微调。继 LoRA 之后,AdaLoRA [42、41、43] 动态学习 LoRA 在模型每一层所需的秩大小。DeltaLoRA [44、45] 使用适配器层的参数更新模型的原始权重,增强了 LoRA 的表示能力。LoSparse [46] 结合了 LoRA,以防止修剪消除过多的表达神经元。DoRA [47] 引入了一个幅度分量来学习 ∆W 的尺度,同时利用原始 AB 作为 ∆W 的方向分量。与专注于学习权重更新的低秩近似值的 LoRA 及其后续方法不同,我们的 PiSSA 方法直接调整模型中重要但低秩的部分,同时保持噪声较大、高秩和非重要部分不变。虽然我们的方法在理念上与 LoRA 不同,但它具有 LoRA 的大部分结构优势,并且可以通过这些方法进行扩展以提高其性能。

QLoRA [10] 将 LoRA 与 4 位 NormalFloat (NF4) 量化以及双量化和分页优化器集成在一起,从而能够在单个 48GB GPU 上对 65B 参数模型进行微调,同时保留完整 16 位微调任务的性能。QA-LoRA [48] 引入了分组运算符来增加低位量化的自由度。LoftQ [49] 通过分解 QLoRA 的量化误差矩阵并使用适配器保留主成分来减少量化误差。我们的 PiSSA 方法也可以与量化技术相结合,我们发现与 QLoRA 和 LoftQ 相比,PiSSA 显著降低了量化误差。

3 PiSSA: Principal Singular Values and Singular Vectors Adaptation

本节正式介绍了 主奇异值与奇异向量适配 (PiSSA) 方法。PiSSA 对自注意力和多层感知机 (MLP) 层中的矩阵 W W W 进行奇异值分解 (SVD)。矩阵 W ∈ R m × n W \in \mathbb{R}^{m \times n} WRm×n 的 (经济型) SVD 表示为 W = U S V T W = U S V^T W=USVT,其中:

  • U ∈ R m × min ⁡ ( m , n ) U \in \mathbb{R}^{m \times \min(m, n)} URm×min(m,n), V ∈ R n × min ⁡ ( m , n ) V \in \mathbb{R}^{n \times \min(m, n)} VRn×min(m,n) 是具有正交列的奇异向量, V T V^T VT V V V 的转置;
  • S = diag ⁡ ( s ) ∈ R min ⁡ ( m , n ) × min ⁡ ( m , n ) S=\operatorname{diag}(\mathbf{s}) \in \mathbb{R}^{\min(m, n) \times \min(m, n)} S=diag(s)Rmin(m,n)×min(m,n) 是对角矩阵, s ∈ R ≥ 0 min ⁡ ( m , n ) \mathbf{s} \in \mathbb{R}_{\geq 0}^{\min(m, n)} sR0min(m,n) 表示按降序排列的奇异值。

当前 r r r 个奇异值 s [ : r ] \mathbf{s}_{[:r]} s[:r] 显著大于剩余奇异值 s [ r : ] \mathbf{s}_{[r:]} s[r:] 时,称 W W W 的内在秩为 r r r。因此, S S S U U U V V V 可分为两部分:

  • 主奇异值与向量 { U [ : , : r ] , S [ : r , : r ] , V [ : , : r ] } \{U_{[:, :r]}, S_{[:r, :r]}, V_{[:, :r]}\} {U[:,:r],S[:r,:r],V[:,:r]}
  • 残差奇异值与向量 { U [ : , r : ] , S [ r : , r : ] , V [ : , r : ] } \{U_{[:, r:]}, S_{[r:, r:]}, V_{[:, r:]}\} {U[:,r:],S[r:,r:],V[:,r:]}

主奇异值与向量用于初始化适配器,适配器由 A ∈ R m × r A \in \mathbb{R}^{m \times r} ARm×r B ∈ R r × n B \in \mathbb{R}^{r \times n} BRr×n 组成:

A = U [ : , : r ] S [ : r , : r ] 1 / 2 ∈ R m × r , B = S [ : r , : r ] 1 / 2 V [ : , : r ] T ∈ R r × n . \begin{aligned} & A = U_{[:, :r]} S_{[:r, :r]}^{1/2} \in \mathbb{R}^{m \times r}, \\ & B = S_{[:r, :r]}^{1/2} V_{[:, :r]}^T \in \mathbb{R}^{r \times n}. \end{aligned} A=U[:,:r]S[:r,:r]1/2Rm×r,B=S[:r,:r]1/2V[:,:r]TRr×n.

残差奇异值与向量用于构建冻结的残差矩阵:

W res = U [ : , r : ] S [ r : , r : ] V [ : , r : ] T ∈ R m × n . W^{\text{res}} = U_{[:, r:]} S_{[r:, r:]} V_{[:, r:]}^T \in \mathbb{R}^{m \times n}. Wres=U[:,r:]S[r:,r:]V[:,r:]TRm×n.

根据公式 (5),主矩阵和残差矩阵的结合在微调开始时能够保留预训练模型的全部能力:

Y = X W = X ( W res + W pri ) = X ( W res + A B ) . Y = X W = X(W^{\text{res}} + W^{\text{pri}}) = X(W^{\text{res}} + A B). Y=XW=X(Wres+Wpri)=X(Wres+AB).


梯度更新与训练

与 LoRA 类似, A A A B B B 的梯度为:

∂ L ∂ A = X T ( ∂ L ∂ Y ) B T , ∂ L ∂ B = A T X T ( ∂ L ∂ Y ) . \frac{\partial L}{\partial A} = X^T\left(\frac{\partial L}{\partial Y}\right) B^T, \quad \frac{\partial L}{\partial B} = A^T X^T\left(\frac{\partial L}{\partial Y}\right). AL=XT(YL)BT,BL=ATXT(YL).

由于 s [ : r ] ≫ s [ r : ] \mathbf{s}_{[:r]} \gg \mathbf{s}_{[r:]} s[:r]s[r:],可训练适配器 W pri = A B W^{\text{pri}} = A B Wpri=AB 包含了矩阵 W W W 最重要的方向。在理想情况下,训练 A B A B AB 能够模拟微调整个模型的过程,同时使用更少的参数。直接微调模型的最重要部分使得 PiSSA 能够更快更好地收敛。

相比之下,LoRA 使用高斯噪声和零初始化适配器 A A A B B B,并冻结 W W W。因此,在微调的早期阶段,梯度可能很小或指向随机方向,浪费了许多梯度下降步骤。此外,不良的初始化可能导致找到次优的局部最小点,从而导致较差的泛化性能。


继承与扩展

由于 PiSSA 与 LoRA 具有相同的架构,它继承了 LoRA 的大部分优势,包括但不限于以下几点:

  1. 减少可训练参数数量:以低秩形式微调模型;
  2. 兼容量化:对残差模型进行量化以减少训练期间前向传播的内存消耗;
  3. 易于部署:适配器的线性结构使得在部署时能将可训练矩阵与预训练权重无缝集成,从而保持完全微调模型的原始推理速度。

通过快速 SVD 技术 [50],PiSSA 的初始化可以在几秒内完成(见附录 B),几乎不产生额外成本。


存储效率

为了提高存储效率,我们可以选择不存储密集参数矩阵 Δ W \Delta W ΔW,而只存储低秩矩阵 Δ A \Delta A ΔA Δ B \Delta B ΔB。如附录 C 所示,仅依赖 Δ A \Delta A ΔA Δ B \Delta B ΔB 即可实现与原始预训练模型的无缝集成。最后,一个预训练模型可以容纳多个由不同 PiSSA 或 LoRA 方法微调得到的 Δ A , Δ B \Delta A, \Delta B ΔA,ΔB,从而使模型能够快速适应不同的下游任务。

4 QPiSSA: PiSSA with Quantization

量化将矩阵的取值范围划分为多个连续区域,并将落入同一区域的所有值映射为相同的“量化”值。这是一种减少前向传播内存消耗的有效技术,但在反向传播中表现不佳。同时,LoRA 大大降低了反向传播的内存需求,使其非常适合与量化结合使用。其中,基础模型经过量化以实现内存高效的前向传播,而 LoRA 适配器保持全精度以便进行准确的反向参数更新。


QLoRA 的量化误差

一个典型的相关工作 QLoRA 将基础模型量化为 Normal Float 4-bit (NF4),并使用高斯-零初始化全精度的 A A A B B B。因此,总体误差表示为:

QLoRA 的量化误差 = ∥ W − ( n f 4 ( W ) + A B ) ∥ ∗ = ∥ W − n f 4 ( W ) ∥ ∗ \text{QLoRA 的量化误差} = \|W-(n f 4(W) + A B)\|_* = \|W - n f 4(W)\|_* QLoRA 的量化误差=W(nf4(W)+AB)=Wnf4(W)

其中, ∥ M ∥ ∗ \|M\|_* M 表示核范数(也称迹范数) [ 51 ] [51] [51],定义为:

∥ M ∥ ∗ = trace ⁡ ( M ∗ M ) = ∑ i = 1 min ⁡ { m , n } σ i ( M ) \|M\|_* = \operatorname{trace}\left(\sqrt{M^* M}\right) = \sum_{i=1}^{\min \{m, n\}} \sigma_i(M) M=trace(MM )=i=1min{m,n}σi(M)

σ i ( M ) \sigma_i(M) σi(M) M M M 的第 i i i 个奇异值。如上所示,QLoRA 的量化误差与直接对基础模型进行量化的误差相同。


QPiSSA 的量化误差

相比之下,我们的 QPiSSA 不对基础模型量化,而是对残差模型进行量化。因此,其误差表示为:

QPiSSA 的量化误差 = ∥ W − ( n f 4 ( W res ) + A B ) ∥ ∗ = ∥ W res − n f 4 ( W res ) ∥ ∗ \text{QPiSSA 的量化误差} = \left\|W - \left(n f 4\left(W^{\text{res}}\right) + A B\right)\right\|_* = \left\|W^{\text{res}} - n f 4\left(W^{\text{res}}\right)\right\|_* QPiSSA 的量化误差=W(nf4(Wres)+AB)=Wresnf4(Wres)

由于残差模型移除了大奇异值成分, W res W^{\text{res}} Wres 的分布比 W W W 更窄。如图 3a 和 3b(比较 W W W W res W^{\text{res}} Wres 的奇异值分布)以及图 3c 和 3f(比较 W W W W res W^{\text{res}} Wres 的取值分布)所示,这种分布特性有助于减少量化误差。


NF4 的优化

由于 NF4 针对正态分布的数据进行了优化 [ 10 ] [10] [10],我们分别对 W W W W res W^{\text{res}} Wres 的取值分布拟合高斯分布。如图 3c 和 3f 所示, W res W^{\text{res}} Wres 更接近高斯分布,并且具有更小的标准差,使其比 W W W 更适合应用 NF4。这些特性使得 QPiSSA 显著降低了量化误差(如图 3d 和 3e 所示)。


梯度方向与性能

除了减少量化误差的优势外,QPiSSA 的梯度方向与 PiSSA 相似,从而在微调性能上显著优于 QLoRA。


http://www.kler.cn/a/454317.html

相关文章:

  • 【Compose multiplatform教程12】【组件】Box组件
  • 金融租赁系统的发展与全球化战略实施探讨
  • scala基础学习(数据类型)-数组
  • 双指针——查找总价格为目标值的两个商品
  • 只谈C++11新特性 - 删除函数
  • 使用vue3搭建前端模拟增删改查
  • 社区二手物品交易小程序ssm+论文源码调试讲解
  • 如何通过HTTP API插入或更新Doc
  • Android Framework 目录下的 AV/Camera 定制常见问题及解决方法
  • Coding(Jenkinsfile)+ Docker 自动化部署 Springboot —— 图文细节和一些注意事项说明
  • 【NIFI】实现ORACLE->ORACLE数据同步
  • Springboot 整合 Duird
  • 【计算机网络安全】加密解密及其在ssh上的应用
  • 面试场景题系列:设计支付系统
  • UnoCSS 的作用与特点
  • idea配置gitee仓库
  • 讯飞语音听写WebApi(流式)【React Native版】
  • 报警推送消息升级的名厨亮灶开源了。
  • 【Django篇】--动手实践Django基础知识
  • 《Go 语言变量》
  • C语言学习笔记(1)
  • 游戏引擎学习第62天
  • Maven核心概念总结
  • Blender高效优化工作流程快捷小功能插件 Haggis Tools V1.1.5
  • jvm排查问题-实践追踪问题 与思路--堆内堆外内存泄漏排查方针
  • HarmonyOS NEXT 实战之元服务:静态案例效果---咖啡制作实况窗