当前位置: 首页 > article >正文

MiniMax-01中Lightning Attention的由来(线性注意力进化史)

目录

  • 引言
  • 原始注意力
  • 线性注意力
  • 因果模型存在的问题
  • 累加求和操作的限制
  • Lightning Attention
    • Lightning Attention-1
    • Lightning Attention-2
  • 备注

引言

MiniMax-01: Scaling Foundation Models with Lightning Attention表明自己是第一个将线性注意力应用到如此大规模的模型,他所使用的核心技术就是Lightning Attention。

那为什么线性注意力20年在文章Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention中就提出了,现在才出第一个线性注意力的大模型呢?

本文就从线性注意力机制入手,详细探讨其起源、存在的显著局限性,以及Lightning Attention的具体实现细节。

原始注意力

现在主流的有两类模型,一种是应用双向注意力的bert类模型,另一种是应用单向注意力的gpt类模型,他们所使用的注意力其实是有细微差别的。

  • 双向注意力(bert类),就是传统认知中标准的注意力

Attention ⁡ ( Q , K , V ) = softmax ⁡ ( Q K T d ) V \operatorname{Attention}(Q,K,V)=\operatorname{softmax}(\frac{QK^T}{\sqrt{d_\text{}}})V Attention(Q,K,V)=softmax(d QKT)V

  • 单向注意力(因果模型,gpt类),只能看到当前和前面的token,所有要在softmax之前乘上一个掩码矩阵,M为单向掩码矩阵

Attention ⁡ ( Q , K , V ) = softmax ⁡ ( Q K T ⊙ M d ) V \operatorname{Attention}(Q,K,V)=\operatorname{softmax}(\frac{QK^T\odot M}{\sqrt{d_\text{}}})V Attention(Q,K,V)=softmax(d QKTM)V

其中Q、K、V每个矩阵的维度都是[n, d],即[序列长度,隐层维度],此时 Q K T QK^T QKT的维度是[n, n],所以整体复杂度是 O ( n 2 d ) O(n^2d) O(n2d)。其中d是固定大小, n 2 n^2 n2随着序列长度平方增加,就主导了整体的复杂度。

线性注意力

原始出处:Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention

注意力计算可抽象成如下形式,sim表示可以使用任何的相似度函数,不一定非要softmax,可以是类似于多项式注意力或RBF核注意力等

V i ′ = ∑ j = 1 N sim ⁡ ( Q i , K j ) V j ∑ j = 1 N sim ⁡ ( Q i , K j ) . V_i^{^{\prime}}=\frac{\sum_{j=1}^{N}\operatorname{sim}\left(Q_i,K_j\right)V_j}{\sum_{j=1}^{N}\operatorname{sim}\left(Q_i,K_j\right)}. Vi=j=1Nsim(Qi,Kj)j=1Nsim(Qi,Kj)Vj.

但相似度函数也是有要求的,需要施加唯一约束sim(·)非负。基于这个条件,给定这样一个内核 ϕ ( x ) \phi(x) ϕ(x),可以将上式重写为

V i = ∑ j = 1 N ϕ ( Q i ) T ϕ ( K j ) V j ∑ j = 1 N ϕ ( Q i ) T ϕ ( K j ) V_i = \frac{\sum_{j=1}^N \phi(Q_i)^T \phi(K_j) V_j}{\sum_{j=1}^N \phi(Q_i)^T \phi(K_j)} Vi=j=1Nϕ(Qi)Tϕ(Kj)j=1Nϕ(Qi)Tϕ(Kj)Vj

利用矩阵乘法结合律得到

V i = ϕ ( Q i ) T ∑ j = 1 N ϕ ( K j ) V j T ϕ ( Q i ) T ∑ j = 1 N ϕ ( K j ) V_i = \frac{ \phi(Q_i)^T \sum_{j=1}^N\phi(K_j) V_j^T}{ \phi(Q_i)^T \sum_{j=1}^N\phi(K_j)} Vi=ϕ(Qi)Tj=1Nϕ(Kj)ϕ(Qi)Tj=1Nϕ(Kj)VjT

为简化理解可写成如下形式

Attention ⁡ ( Q , K , V ) = ( ϕ ( Q ) ϕ ( K ) T ) V = ϕ ( Q ) ( ϕ ( K ) T V ) \operatorname{Attention}(Q,K,V)=(\phi(Q)\phi(K)^T)V=\phi(Q)(\phi(K)^TV) Attention(Q,K,V)=(ϕ(Q)ϕ(K)T)V=ϕ(Q)(ϕ(K)TV)

深层理解:每个时间步的K和V可以提前计算并作为单个向量存储下来,在推理生成时直接用Q乘以每个时间步的KV,简单情况下KV cache的缓存向量都变少了,可以像RNN一样每次预测的时间都几乎是恒定的

注意:此时使用的一般都是绝对位置编码,Q、K矩阵没有乘以参数矩阵

此时 ϕ ( K ) T V \phi(K)^TV ϕ(K)TV的复杂度是 O ( d 2 ) O(d^2) O(d2),所以整体复杂度变成了 O ( n d 2 ) O(nd^2) O(nd2),随着序列长度n线性增长,此时就是线性注意力了。

(可选):通常线性注意力的公式还有如下形式

O = Δ − 1 ∗ ( Q ∗ K T ∗ V ) O = Δ^{-1} * (Q * K^T * V) O=Δ1(QKTV)

(可选)其中,Δ起到了归一化的作用。Δ的每个对角元素是 K T ∗ 1 K^T*1 KT1的值,这反映了每个键向量的重要程度。将 Δ − 1 Δ^{-1} Δ1乘到结果上,就相当于对注意力输出进行了逆归一化。相当于只对K归一化,Q本身就是一个合适的查询向量,不需要归一化。

因果模型存在的问题

注意上面的线性注意力是类bert模型的情况下,并没有与掩码矩阵相乘,此时可以顺畅的先右乘来降低复杂度。但现在的大模型都是生成模型,使用的因果模型结构,都是单向注意力,就必须要乘以掩码矩阵,所以不能顺畅的右乘了。
左乘线性注意力公式如下,输出为O,每个step的输出为当前的 q t q_t qt乘以前面的 k j k_j kj,再乘以 v j v_j vj累加求和。此时 Q K T QK^T QKT可以正常进行矩阵运算,然后使用 ⊙ \odot (Hadamard Product)进行逐元素相乘,得到掩码后的矩阵。

O = ( Q K T ⊙ M ) V O=(QK^T\odot M)V O=(QKTM)V

o t = ∑ j = 1 t ( q t T k j ) v j o_t=\sum_{j=1}^t(q_t^Tk_j)v_j ot=j=1t(qtTkj)vj

此时注意,上面公式的运算涉及 ⊙ \odot ,它不适用于矩阵乘法交换律和结合律,即无法 Q ( K T ⊙ M V ) Q(K^T\odot MV) Q(KTMV) ⊙ \odot 是逐元素相乘,所以两个矩阵的维度必须相同,即使将M的位置放到前面, K T V K^TV KTV的维度是[d, d],也无法与M逐元素相乘。

累加求和操作的限制

双向注意力模型(bert)中使用的线性注意力如下,可以先算KV

( ϕ ( Q ) ϕ ( K ) T ) V = ϕ ( Q ) ( ϕ ( K ) T V ) (\phi(Q)\phi(K)^T)V=\phi(Q)(\phi(K)^TV) (ϕ(Q)ϕ(K)T)V=ϕ(Q)(ϕ(K)TV)

QKV的维度都为[n, d],这里假设序列长度为4,双向和单向注意力如下图

在这里插入图片描述

  • 双向注意力计算
    K和V的矩阵如下,得到的 K T V K^TV KTV的维度是[d, d]

K T = [ k 1 T k 2 T k 3 T k 4 T ] = [ k 11 k 21 k 31 k 41 k 12 k 22 k 32 k 42 ⋮ ⋮ ⋮ ⋮ k 1 d k 2 d k 3 d k 4 d ] K^{T}= \begin{bmatrix} k_{1}^T & k_{2}^T & k_{3}^T & k_{4}^T \\ \end{bmatrix}= \begin{bmatrix} k_{11} & k_{21} & k_{31} & k_{41} \\ k_{12} & k_{22} & k_{32} & k_{42} \\ \vdots & \vdots & \vdots & \vdots \\ k_{1d} & k_{2d} & k_{3d} & k_{4d}\\ \end{bmatrix} KT=[k1Tk2Tk3Tk4T]= k11k12k1dk21k22k2dk31k32k3dk41k42k4d

V = [ v 1 v 2 v 3 v 4 ] = [ v 11 v 12 . . . v 1 d v 21 v 22 . . . v 2 d v 31 v 32 . . . v 3 d v 41 v 42 . . . v 4 d ] V= \begin{bmatrix} v_{1} \\ v_{2} \\ v_{3} \\ v_{4} \\ \end{bmatrix}= \begin{bmatrix} v_{11} & v_{12} & ... & v_{1d} \\ v_{21} & v_{22} & ... & v_{2d} \\ v_{31} & v_{32} & ... & v_{3d} \\ v_{41} & v_{42} & ... & v_{4d} \end{bmatrix} V= v1v2v3v4 = v11v21v31v41v12v22v32v42............v1dv2dv3dv4d

K T V = [ k 1 T v 1 + k 2 T v 2 + k 3 T v 3 + k 4 T v 4 ] = [ [ K T V ] 1 [ K T V ] 2 . . . [ K T V ] d ] K^{T}V= \begin{bmatrix} k_{1}^Tv_1 + k_{2}^Tv_2 + k_{3}^Tv_3 + k_{4}^Tv_4 \\ \end{bmatrix}= \begin{bmatrix} [K^{T}V]_{1} & [K^{T}V]_{2} & ... & [K^{T}V]_{d} \end{bmatrix} KTV=[k1Tv1+k2Tv2+k3Tv3+k4Tv4]=[[KTV]1[KTV]2...[KTV]d]

此时计算 q 3 q_3 q3的注意力输出就可以使用以下方法。注意这是点积,q3是一个向量, K T V K^{T}V KTV是一个矩阵,向量在与矩阵点积的时候会进行广播拓展,复制成多份分别与矩阵中的向量点积。 [ K T V ] 1 [K^{T}V]_{1} [KTV]1是一个向量, q 3 [ K T V ] 1 q_3[K^{T}V]_{1} q3[KTV]1点积后会得到一个值,所以 q 3 K T V q_3K^{T}V q3KTV最终的结果是一个向量,长度为隐层维度d。

q 3 K T V = q 3 [ [ K T V ] 1 [ K T V ] 2 . . . [ K T V ] d ] = [ q 3 [ K T V ] 1 q 3 [ K T V ] 2 . . . q 3 [ K T V ] d ] q_3K^{T}V= q_3 \begin{bmatrix} [K^{T}V]_{1} & [K^{T}V]_{2} & ... & [K^{T}V]_{d} \end{bmatrix}= \begin{bmatrix} q_3[K^{T}V]_{1} & q_3[K^{T}V]_{2} & ... & q_3[K^{T}V]_{d} \end{bmatrix} q3KTV=q3[[KTV]1[KTV]2...[KTV]d]=[q3[KTV]1q3[KTV]2...q3[KTV]d]

也可以使用以下代码测试

import torch
q3 = torch.tensor([1, 2, 3, 4, 5, 6])
print(q3)

# [n, d] = [4, 6]
kT = torch.tensor([[1, 1, 1, 1], 
                   [2, 2, 2, 2], 
                   [3, 3, 3, 3], 
                   [4, 4, 4, 4],
                   [5, 5, 5, 5],
                   [6, 6, 6, 6]])
v = torch.tensor([[1, 1, 1, 1, 1, 1], 
                  [1, 1, 1, 1, 1, 1], 
                  [1, 1, 1, 1, 1, 1], 
                  [1, 1, 1, 1, 1, 1]])

print('kT @ v', kT @ v)
# q与(k.T @ v)的点积
result = torch.matmul(q3, kT @ v)
print('result', result)

在这里插入图片描述

此时 K T V K^TV KTV的结果是双向的, k 3 k_3 k3的输出矩阵中使用了 v 4 v_4 v4,这样双向注意力就可以顺畅的右乘得到 K T V K^TV KTV结果再与Q相乘,得到所有token的输出。

但因果模型的注意力是单向的, K T V K^TV KTV在计算的时候前面的K不能与后面的V相乘,所以只能一个一个算然后累加求和。

o 1 = q 1 ( k 1 T v 1 ) o_1 = q_1(k_1^Tv_1) o1=q1(k1Tv1)

o 2 = q 2 ( k 1 T v 1 + k 2 T v 2 ) o_2 = q_2(k_1^Tv_1+k_2^Tv_2) o2=q2(k1Tv1+k2Tv2)

o 3 = q 3 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 ) o_3 = q_3(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3) o3=q3(k1Tv1+k2Tv2+k3Tv3)

o 4 = q 4 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 + k 4 T v 4 ) o_4 = q_4(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3+k_4^Tv_4) o4=q4(k1Tv1+k2Tv2+k3Tv3+k4Tv4)

这样的累加操作无法进行高效的矩阵乘法,虽然计算复杂度降低了,但实际运算的效率并不高。

Lightning Attention

到这里可以引出MiniMax-01 中所使用的Lightning Attention了,但其实这个注意力有两个版本,MiniMax-01中所提到的就是是Lightning Attention-2,那咱们先看看第一个版本做了什么。

Lightning Attention-1

源自:TransNormerLLM: A Faster and Better Large Language Model with Improved TransNormer

Lightning Attention-1针对于原始注意力取消了softmax,使用Swish激活函数代替。即先变成了
Attention ⁡ ( Q , K , V ) = ( ϕ ( Q ) ϕ ( K ) T ⊙ M ) V \operatorname{Attention}(Q,K,V)=(\phi(Q)\phi(K)^T\odot M)V Attention(Q,K,V)=(ϕ(Q)ϕ(K)TM)V
然后还是先左乘计算,并没有解决线性注意力的根本问题,但是借鉴了flash attention中的硬件加速。

其前向和反向传播流程如下,就是将QKV切块,放到高速SRAM中去计算。虽然变快了,但此时的复杂度还是 O ( n 2 d ) O(n^2d) O(n2d)
在这里插入图片描述
在这里插入图片描述

Lightning Attention-2

源自:Lightning Attention-2: A Free Lunch for Handling Unlimited Sequence Lengths in Large Language Models

Lightning Attention-2解决了因果模型在计算单向注意力时,需要进行累加求和操作导致无法矩阵运算的情况,实现了单向注意力先计算右乘,成功将复杂度降为 O ( n d 2 ) O(nd^2) O(nd2)
o 1 = q 1 ( k 1 T v 1 ) o_1 = q_1(k_1^Tv_1) o1=q1(k1Tv1)

o 2 = q 2 ( k 1 T v 1 + k 2 T v 2 ) o_2 = q_2(k_1^Tv_1+k_2^Tv_2) o2=q2(k1Tv1+k2Tv2)

o 3 = q 3 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 ) o_3 = q_3(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3) o3=q3(k1Tv1+k2Tv2+k3Tv3)

o 4 = q 4 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 + k 4 T v 4 ) o_4 = q_4(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3+k_4^Tv_4) o4=q4(k1Tv1+k2Tv2+k3Tv3+k4Tv4)

再将这个累加求和公式拿过来,配合下图观察发现,之前的问题是每次计算 Q K T QK^T QKT都在整个序列上计算,这样每次都是所有序列的token互相注意到。那如果在序列这个维度拆分成小份,比如图中右侧先计算 k 1 k_1 k1 k 2 k_2 k2,然后用于 q 3 q_3 q3的计算就完全没有问题, k 4 k_4 k4后面的就不计算了。这样就既能矩阵运算,又能符合单向掩码。

公式中也可以发现,当前step之前的k和v是可以相乘的,比如 q 3 q_3 q3在计算时,可以将 k 1 T v 1 + k 2 T v 2 + k 3 T v 3 k_1^Tv_1+k_2^Tv_2+k_3^Tv_3 k1Tv1+k2Tv2+k3Tv3使用矩阵操作运算。所以Lightning Attention-2将大矩阵拆开,类似flash attention拆成多个block。
在这里插入图片描述
这些 block 不能拆分成 n 份,这样block的意义就没有了,for循环计算反而更慢。所以每个 block 中会有多个时间步的token。

此时这些 block 就可以分为两类,一类是块内(intra block),一类是块间(inter block)。块内代表当前块 q 的序列下标和 kv 序列下标相同,块间即不同。

块内在计算 q i q_i qi时直接矩阵右乘很容易算上 k i + 1 v i + 1 k_{i+1}v_{i+1} ki+1vi+1,所以块内使用传统的左乘并与掩码矩阵相乘。块间计算时就可以先右乘计算 K t V K^tV KtV,因为之前的kv是可以双向注意力的。然后将之前的kv结果缓存下来并更新,用于下一个step计算。

下图是Lightning Attention-2的结构图, λ \lambda λ是它的模型所使用的位置编码,忽略即可。
在这里插入图片描述
以下是前向传播和反向传播流程。
在这里插入图片描述
在这里插入图片描述
问题:M矩阵维度是[B, B],相当于每一个块代表了多个序列步n,在对角线位置是1,那在这个块内前面的q就可以注意到后面的kv了

解答:M矩阵维度虽然是[B, B],但只是这么切割,其内部值仍然是下三角。

备注

个人理解,若有不对请指出,谢谢。


http://www.kler.cn/a/520966.html

相关文章:

  • 实现B-树
  • 【数据结构】 并查集 + 路径压缩与按秩合并 python
  • 算法题(48):反转链表
  • Vue3 30天精进之旅:Day01 - 初识Vue.js的奇妙世界
  • 云计算架构学习之LNMP架构部署、架构拆分、负载均衡-会话保持
  • 嵌入式基础 -- PCIe 控制器中断管理之MSI与MSI-X简介
  • API接口设计模板
  • Zotero中使用Deepseek翻译
  • 基于Python的哔哩哔哩综合热门数据分析系统的设计与实现
  • 小程序开发实战:记录一天的 Bug 修复历程
  • 绘制决策树尝试2 内含添加环境变量步骤
  • AIGC时代下的Vue组件开发深度探索
  • Centos7系统php8编译安装ImageMagick/Imagick扩展教程整理
  • 数据结构课设——模糊查询汉字和其位置
  • 机器学习2 (笔记)(朴素贝叶斯,集成学习,KNN和matlab运用)
  • 推箱子游戏
  • 第04章 17 实现一个逐步收缩球体的视觉效果
  • 分布式系统学习:小结
  • 从项目复查做一些TypeScript使用上的总结
  • 多模态论文笔记——VDT
  • ZooKeeper 数据模型
  • react-native网络调试工具Reactotron保姆级教程
  • java8-日期时间Api
  • 83,【7】BUUCTF WEB [MRCTF2020]你传你[特殊字符]呢
  • PyCharm接入DeepSeek实现AI编程
  • 【2024年华为OD机试】 (C卷,200分)- 机器人走迷宫(JavaScriptJava PythonC/C++)