当前位置: 首页 > article >正文

论文阅读笔记——LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS

LoRA 论文
传统全面微调,对每个任务学习的参数与原始模型相同:
m a x Φ ∑ ( x , y ) ∈ Z ∑ t = 1 ∣ y ∣ l o g ( P Φ ( y t ∣ x , y < t ) ) 式(1) max_{\Phi}\sum_{(x,y)\in Z}\sum^{|y|}_{t=1}log(P_{\Phi}(y_t|x,y<t)) \qquad \text{式(1)} maxΦ(x,y)Zt=1ylog(PΦ(ytx,y<t))(1)
LoRA 提出对模型中权重更新部分低秩分解,编码任务特定的参数,大幅减少所需参数规模,同时优化 Θ \Theta Θ 来寻找 Δ Θ \Delta \Theta ΔΘ 。对于 175B 的 GPT-3 参数量只有原来的 0.01%。
m a x Θ ∑ ( x , y ) ∈ Z ∑ t = 1 ∣ y ∣ l o g ( p Φ 0 + Δ Φ ( Θ ) ( y t ∣ x , y < t ) ) max_{\Theta}\sum_{(x,y)\in Z}\sum^{|y|}_{t=1}log(p_{\Phi_0+\Delta \Phi(\Theta})(y_t|x,y<t)) maxΘ(x,y)Zt=1ylog(pΦ0+ΔΦ(Θ)(ytx,y<t))

传统方法不足

  • 添加 adapters 的策略虽然参数少,但会在推理阶段引入延迟——增加了模型深度。并且有额外参数和计算,在模型中这些会被放大。
  • 直接优化输入层(prefix)在训练参数方面并非单调变化且保留一部分长度进行调整降低了下游任务的序列长度——占用了一部分序列长度,减少了可用的输入序列长度。并且他的优化难度也大。

LoRA

在这里插入图片描述
核心思想:对于一个预训练的 W 0 ∈ R d × k W_0 \in R^{d×k} W0Rd×k ,训练低秩矩阵 B A BA BA 来替代权重更新部分 Δ W \Delta W ΔW d i m ( A ) = r × k , d i m ( B ) = d × r r < < m i n ( d , k ) dim(A) = r×k, \quad dim(B)=d×r \quad r << min(d,k) dim(A)=r×k,dim(B)=d×rr<<min(d,k)
h = W 0 x + Δ W x = W 0 x + B A x h=W_0x+\Delta Wx=W_0x+BAx h=W0x+ΔWx=W0x+BAx
其中,A 采取随机高斯初始化,B 为 0。

LoRA 在适应期间不需要满足满秩的条件,只需要将 r 设置为预训练权重矩阵的秩,大致可恢复完全微调的能力,可以维持原来架构。

LoRA 优势:

  1. 参数高效:训练参数减少了数千倍,例如在GPT-3中,训练参数从1750亿减少到数百万甚至更少。
  2. 计算资源节省:由于需要计算梯度的参数大大减少,显存占用降低,训练速度加快。
  3. 无额外推理延迟:训练完成后,可以将低秩更新融合到预训练权重中,推理时无需额外计算。
  4. 任务切换灵活:不同任务只需存储和加载小的低秩矩阵,实现快速切换,减少存储需求。
    将 LoRA 应用于 transformer 架构中,只需要对自注意力模块 ( W q , W k , W v , W 0 ) (W_q,W_k,W_v,W_0) (Wq,Wk,Wv,W0) 中的 W q , W v W_q,W_v Wq,Wv 进行适应,保持 MLP 不变。
    在下游部署时,只需要减去 B A BA BA 即可恢复 W 0 W_0 W0,再根据任务需求加上对应 B ′ A ′ B^{'}A^{'} BA 。最明显的好处在于内存和存储使用量减少。

实验结果

在这里插入图片描述

代码实现

class LoRALayer():
    def __init__(
        self, 
        r: int, 
        lora_alpha: int, 
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights
class Embedding(nn.Embedding, LoRALayer):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,
                           merge_weights=merge_weights)

        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
            self.scaling = self.lora_alpha / self.r
            # 冻结预训练权重
            self.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        nn.Embedding.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            nn.init.zeros_(self.lora_A)
            nn.init.normal_(self.lora_B)

    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                if self.r > 0:
                    self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scaling
                self.merged = True
                
	    def forward(self, x: torch.Tensor):
	        if self.r > 0 and not self.merged:
	            result = nn.Embedding.forward(self, x)
	            after_A = F.embedding(
	                x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,
	                self.norm_type, self.scale_grad_by_freq, self.sparse
	            )
	            result += (after_A @ self.lora_B.transpose(0, 1)) * self.scaling
	            return result
	        else:
	            return nn.Embedding.forward(self, x)

http://www.kler.cn/a/585293.html

相关文章:

  • 【RISCV LAB】0x01-安装实验仿真辅助工具
  • AI建模智能生成:从2D到3D,AI只需一步!
  • 结构型模式之适配器模式:让不兼容的接口兼容
  • 工业数采适配99%协议EG8200Mini 边缘计算网关
  • 【零基础入门unity游戏开发——unity3D篇】3D物理系统之 —— 碰撞检测和触发器检测的特殊生命周期函数
  • 【QT】认识 QT 安装 QT 相关软件
  • YOLOv12优化之区域注意力机制(A2)和残差高效层聚合网络(R-ELAN)
  • 【第七节】windows sdk编程:Windows 中的对话框
  • 计算机安全 第四节:访问控制(中)
  • 学习threejs,使用MeshFaceMaterial面材质容器
  • pom.xml中配置的repository,在编译器下载依赖包没生效,怎么解决
  • CNN 稠密任务经典结构
  • 市场趋势分析与策略优化
  • DeepSeek 加持!IvorySQL 文档智能助手正式上线!
  • 疗养院管理系统设计与实现(代码+数据库+LW)
  • 【Rust基础】Rust后端开发常用库
  • linux 命令 case
  • CNN的激活函数
  • 印刷店常用的PDF批量页码统计软件
  • 网络测试工具:涵盖网络测速、密码查看、故障判断与网络监测