当前位置: 首页 > article >正文

【自然语言处理】利用Memory Layer替换Transformer中的FFN

论文地址:https://arxiv.org/pdf/2412.09764

相关博客
【自然语言处理】利用Memory Layer替换Transformer中的FFN
【自然语言处理】【大模型】BitNet:用1-bit Transformer训练LLM
【自然语言处理】BitNet b1.58:1bit LLM时代
【自然语言处理】【长文本处理】RMT:能处理长度超过一百万token的Transformer
【自然语言处理】【大模型】MPT模型结构源码解析(单机版)
【自然语言处理】【大模型】ChatGLM-6B模型结构代码解析(单机版)
【自然语言处理】【大模型】BLOOM模型结构源码解析(单机版)

​ 本文提出了一种memory layer用于替换Transformer中的FFN,从而提升模型的知识容量。

一、Memory Layer

​ 这里定义的memory layer与注意力机制类似。即给定query q ∈ R n q\in\mathbb{R}^n qRn、一组key K ∈ R N × n K\in\mathbb{R}^{N\times n} KRN×n和一组value V ∈ R N × n V\in\mathbb{R}^{N\times n} VRN×n,最终输出value的软组合。但是,memory layer与标准注意力层有两个区别:

​ (1) 标准注意力中key和value是激活值,而memory layer中是可训练参数;

​ (2) memory layer中的key和value规模要比标准注意力大很多,需要稀疏查询和更新;

Memory Layer的正式描述
I = SelectTopkIndices ( K q ) , s = Softmax ( K I q ) , y = s V I (1) I=\text{SelectTopkIndices}(Kq),\quad s=\text{Softmax}(K_I q),\quad y=sV_I\tag{1} \\ I=SelectTopkIndices(Kq),s=Softmax(KIq),y=sVI(1)
其中:

  • I I I是选中的key-value对的索引集合;
  • s ∈ R k s\in\mathbb{R}^k sRk是权重向量;
  • K I K_I KI V I V_I VI是选中的key和value;
  • y ∈ R n y\in\mathbb{R}^n yRn是memory layer的输出;

二、计算Topk索引 I I I的优化

1. memory layer的瓶颈

​ 阻碍memory layer大规模应用的一个主要瓶颈就是query-key检索机制。一般来说,可以利用简单的最近邻搜索来比较每个query-key对,但是对于更大规模的记忆来说,这种方法并不可行。当然,也有一些快速近似向量相似度的技术,但是当key在持续训练中不断更新,则需要不断的重新索引。

2. product-key

K K K的分解。计算 I I I的主要挑战是 K K K太大,那么可以考虑用笛卡尔积的方式分解 K K K。具体来说,随机初始化 K 1 ∈ R N × n 2 K_1\in\mathbb{R}^{\sqrt{N}\times\frac{n}{2}} K1RN ×2n K 2 ∈ R N × n 2 K_2\in\mathbb{R}^{\sqrt{N}\times\frac{n}{2}} K2RN ×2n两个独立的key集合,通过 K 1 K_1 K1 K 2 K_2 K2的笛卡尔积就可以得到 K K K,即两两拼接 K 1 K_1 K1 K 2 K_2 K2中的向量:
K [ i , j ] = concat ( K 1 [ i ] , K 2 [ j ] ) K[i,j]=\text{concat}(K_1[i],K_2[j]) \\ K[i,j]=concat(K1[i],K2[j])
注意,在实际计算中索引 I I I的过程中并不需要计算出 K K K,直接利用 K 1 K_1 K1 K 2 K_2 K2即可。

查询。将query q q q也分解为两个部分 q 1 , q 2 ∈ R n 2 q_1,q_2\in\mathbb{R}^{\frac{n}{2}} q1,q2R2n,然后分别与 K 1 K_1 K1 K 2 K_2 K2进行相似度计算,得到 I 1 , I 2 I_1,I_2 I1,I2 s 1 , s 2 s_1,s_2 s1,s2。最终topk的索引以及分数为
arg ⁡ max ⁡ i 1 ∈ I 1 , i 2 ∈ I 2 s 1 [ i 1 ] + s 2 [ i 2 ] \mathop{\arg\max}_{i_1\in I_1,i_2\in I_2}\quad s_1[i_1]+s_2[i_2] \\ argmaxi1I1,i2I2s1[i1]+s2[i2]

三、并行优化

在这里插入图片描述

​ memory layer本质上是存储密集型的,其包含了大量可训练参数以及对应的优化器状态。为了能够实现包含数百万个key的memory layer,需要在多个GPU上并行化embedding的查找和聚合操作。

​ 具体来说,在embedding的维度上进行分片。每一步中,从进程组收集索引,然后每个进程在其所属的分片上进行查找和聚合操作。最后,每个进程收集与自身部分索引相对应的部分embedding。通过确保每个GPU只获取其自身那部分数据,从而无需实例化整个embedding输出,控制激活内存。

四、共享记忆

​ 在所有memory layerz中使用一个共享的记忆参数池,从而保持参数量不变并最大化参数共享。实验发现,在一定数量的层内,多个memory layer比具有相同总参数量的单个memory layer效果更好。当在超过这个数量的层内替换FFN会导致性能下降,这表明稀疏层和密集层都是必要的,而且很可能具有互补性

五、性能和稳定性提升

  • 通过自定义CUDA核,相较于Pytorch的EmbeddingBag有6倍的提升;
  • 引入silu激活函数来提高memory layer的性能。公式(1)的输出变为

output = ( y ⊙ silu ( x ⊤ W 1 ) ) ⊤ W 2 (2) \text{output} = (y\odot\text{silu}(x^\top W_1))^\top W_2 \tag{2}\\ output=(ysilu(xW1))W2(2)

其中 silu ( x ) = x ⋅ sigmoid ( x ) \text{silu}(x)=x\cdot\text{sigmoid}(x) silu(x)=xsigmoid(x) ⊙ \odot 表示逐元素相乘。

  • 大规模memory layer使得训练不稳定,特别是对小模型。使用qk-normalization来缓解这一问题;

六、实验设置

设置。遵循Llama系列的Transformer模型,然后利用共享记忆层替换一个或多个前馈层。scaling law实验中,参数规模分别是134m、373m、720m、1.3b。

baselines。除了密集型baselines外,还将比较MOE和PEER。MOE模型中,FFN层由多个“专家”组成,对于每个输入,只有一部分“专家”参与计算。PEER类似memory layer。

评估基准。NaturalQuestions、TriviaQA、HotpotQA、MMLU、HellaSwag、OBQA、PIQA、HumanEval、MBPP。

七、实验结果

1. 固定参数量

在这里插入图片描述

​ 上表展示了记忆增加型架构的效果。

​ (1) 配备记忆层的模型相较于密集型模型由显著提升,通常是同参数密集型模型的两倍;

​ (2) Memory+比Memory效果更好;

​ (3) 参数量相同下,PEER与Memory相近,但落后于Memory+;

​ (4) MOE大幅度落后于各种配备记忆层的模型。

2. 缩放memory layer

在这里插入图片描述

​ 如上图所示,在事实问答基准上,模型性能随memory的尺寸增加而增加。当拥有6400万个key后,1.3B模型的效果接近Llama2 7B的性能。

3. 8B模型结果

在这里插入图片描述


http://www.kler.cn/a/538875.html

相关文章:

  • 日志2025.2.9
  • 练习题(2025.2.9)
  • Sentinel——Spring Boot 应用接入 Sentinel 后内存开销增长计算方式
  • webpack系统学习
  • 微信小程序如何使用decimal计算金额
  • 【C语言标准库函数】指数与对数函数:exp(), log(), log10()
  • 缓存实战:Redis 与本地缓存
  • 黑马React保姆级(PPT+笔记)
  • 使用 Three.js 实现热力渐变效果
  • C++线程池
  • 如何设置爬虫的延时避免频繁请求?
  • 使用rustDesk搭建私有远程桌面
  • vue+element-ui简洁完美实现ju动漫网站
  • ASP.NET Core托管服务
  • Java 线程池内部任务出异常后,如何知道是哪个线程出了异常
  • 【Python】元组
  • Deepseek访问受限?换种方式轻松使用
  • 22.3、IIS安全分析与增强
  • 【React】实现TagInput输入框,可以输入多个邮箱并校验是否合法
  • Agent论文阅读:NormEnforcement with a Soft Touch: Faster Emergence, Happier Agents
  • 阿里云服务器XShell连接启动java -jar xxx.jar后退出ssh,后端也退出,使用screen 亲测管用!
  • 【Jetson Nano安装gpu版pytroch1.7torchvision0.8.1 in python3.8 跑 Ultralytics YOLO】
  • 关于预训练后训练、LLM和视频大模型相关学习记录
  • 周报1.0
  • 鸿蒙音视频播放器:libwlmedia
  • 如何解决 Linux 文件系统挂载失败的问题