MiniMax-01中Lightning Attention的由来(线性注意力进化史)
目录
- 引言
- 原始注意力
- 线性注意力
- 因果模型存在的问题
- 累加求和操作的限制
- Lightning Attention
- Lightning Attention-1
- Lightning Attention-2
- 备注
引言
MiniMax-01: Scaling Foundation Models with Lightning Attention表明自己是第一个将线性注意力应用到如此大规模的模型,他所使用的核心技术就是Lightning Attention。
那为什么线性注意力20年在文章Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention中就提出了,现在才出第一个线性注意力的大模型呢?
本文就从线性注意力机制入手,详细探讨其起源、存在的显著局限性,以及Lightning Attention的具体实现细节。
原始注意力
现在主流的有两类模型,一种是应用双向注意力的bert
类模型,另一种是应用单向注意力的gpt
类模型,他们所使用的注意力其实是有细微差别的。
- 双向注意力(bert类),就是传统认知中标准的注意力
Attention ( Q , K , V ) = softmax ( Q K T d ) V \operatorname{Attention}(Q,K,V)=\operatorname{softmax}(\frac{QK^T}{\sqrt{d_\text{}}})V Attention(Q,K,V)=softmax(dQKT)V
- 单向注意力(因果模型,
gpt
类),只能看到当前和前面的token
,所有要在softmax
之前乘上一个掩码矩阵,M
为单向掩码矩阵
Attention ( Q , K , V ) = softmax ( Q K T ⊙ M d ) V \operatorname{Attention}(Q,K,V)=\operatorname{softmax}(\frac{QK^T\odot M}{\sqrt{d_\text{}}})V Attention(Q,K,V)=softmax(dQKT⊙M)V
其中Q、K、V
每个矩阵的维度都是[n, d]
,即[序列长度,隐层维度]
,此时
Q
K
T
QK^T
QKT的维度是[n, n]
,所以整体复杂度是
O
(
n
2
d
)
O(n^2d)
O(n2d)。其中d是固定大小,
n
2
n^2
n2随着序列长度平方增加,就主导了整体的复杂度。
线性注意力
原始出处:Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
注意力计算可抽象成如下形式,sim表示可以使用任何的相似度函数,不一定非要softmax,可以是类似于多项式注意力或RBF核注意力等
V i ′ = ∑ j = 1 N sim ( Q i , K j ) V j ∑ j = 1 N sim ( Q i , K j ) . V_i^{^{\prime}}=\frac{\sum_{j=1}^{N}\operatorname{sim}\left(Q_i,K_j\right)V_j}{\sum_{j=1}^{N}\operatorname{sim}\left(Q_i,K_j\right)}. Vi′=∑j=1Nsim(Qi,Kj)∑j=1Nsim(Qi,Kj)Vj.
但相似度函数也是有要求的,需要施加唯一约束sim(·)非负。基于这个条件,给定这样一个内核 ϕ ( x ) \phi(x) ϕ(x),可以将上式重写为
V i = ∑ j = 1 N ϕ ( Q i ) T ϕ ( K j ) V j ∑ j = 1 N ϕ ( Q i ) T ϕ ( K j ) V_i = \frac{\sum_{j=1}^N \phi(Q_i)^T \phi(K_j) V_j}{\sum_{j=1}^N \phi(Q_i)^T \phi(K_j)} Vi=∑j=1Nϕ(Qi)Tϕ(Kj)∑j=1Nϕ(Qi)Tϕ(Kj)Vj
利用矩阵乘法结合律得到
V i = ϕ ( Q i ) T ∑ j = 1 N ϕ ( K j ) V j T ϕ ( Q i ) T ∑ j = 1 N ϕ ( K j ) V_i = \frac{ \phi(Q_i)^T \sum_{j=1}^N\phi(K_j) V_j^T}{ \phi(Q_i)^T \sum_{j=1}^N\phi(K_j)} Vi=ϕ(Qi)T∑j=1Nϕ(Kj)ϕ(Qi)T∑j=1Nϕ(Kj)VjT
为简化理解可写成如下形式
Attention ( Q , K , V ) = ( ϕ ( Q ) ϕ ( K ) T ) V = ϕ ( Q ) ( ϕ ( K ) T V ) \operatorname{Attention}(Q,K,V)=(\phi(Q)\phi(K)^T)V=\phi(Q)(\phi(K)^TV) Attention(Q,K,V)=(ϕ(Q)ϕ(K)T)V=ϕ(Q)(ϕ(K)TV)
深层理解:每个时间步的K和V可以提前计算并作为单个向量存储下来,在推理生成时直接用Q乘以每个时间步的KV,简单情况下KV cache的缓存向量都变少了,可以像RNN一样每次预测的时间都几乎是恒定的
注意:此时使用的一般都是绝对位置编码,Q、K矩阵没有乘以参数矩阵
此时
ϕ
(
K
)
T
V
\phi(K)^TV
ϕ(K)TV的复杂度是
O
(
d
2
)
O(d^2)
O(d2),所以整体复杂度变成了
O
(
n
d
2
)
O(nd^2)
O(nd2),随着序列长度n
线性增长,此时就是线性注意力了。
(可选):通常线性注意力的公式还有如下形式
O = Δ − 1 ∗ ( Q ∗ K T ∗ V ) O = Δ^{-1} * (Q * K^T * V) O=Δ−1∗(Q∗KT∗V)
(可选)其中,Δ起到了归一化的作用。Δ的每个对角元素是 K T ∗ 1 K^T*1 KT∗1的值,这反映了每个键向量的重要程度。将 Δ − 1 Δ^{-1} Δ−1乘到结果上,就相当于对注意力输出进行了逆归一化。相当于只对K归一化,Q本身就是一个合适的查询向量,不需要归一化。
因果模型存在的问题
注意上面的线性注意力是类bert模型的情况下,并没有与掩码矩阵相乘,此时可以顺畅的先右乘来降低复杂度。但现在的大模型都是生成模型,使用的因果模型结构,都是单向注意力,就必须要乘以掩码矩阵,所以不能顺畅的右乘了。
左乘线性注意力公式如下,输出为O,每个step的输出为当前的
q
t
q_t
qt乘以前面的
k
j
k_j
kj,再乘以
v
j
v_j
vj累加求和。此时
Q
K
T
QK^T
QKT可以正常进行矩阵运算,然后使用
⊙
\odot
⊙(Hadamard Product)进行逐元素相乘,得到掩码后的矩阵。
O = ( Q K T ⊙ M ) V O=(QK^T\odot M)V O=(QKT⊙M)V
o t = ∑ j = 1 t ( q t T k j ) v j o_t=\sum_{j=1}^t(q_t^Tk_j)v_j ot=j=1∑t(qtTkj)vj
此时注意,上面公式的运算涉及 ⊙ \odot ⊙,它不适用于矩阵乘法交换律和结合律,即无法 Q ( K T ⊙ M V ) Q(K^T\odot MV) Q(KT⊙MV)。 ⊙ \odot ⊙是逐元素相乘,所以两个矩阵的维度必须相同,即使将M的位置放到前面, K T V K^TV KTV的维度是[d, d],也无法与M逐元素相乘。
累加求和操作的限制
双向注意力模型(bert)中使用的线性注意力如下,可以先算KV
( ϕ ( Q ) ϕ ( K ) T ) V = ϕ ( Q ) ( ϕ ( K ) T V ) (\phi(Q)\phi(K)^T)V=\phi(Q)(\phi(K)^TV) (ϕ(Q)ϕ(K)T)V=ϕ(Q)(ϕ(K)TV)
QKV的维度都为[n, d],这里假设序列长度为4,双向和单向注意力如下图
- 双向注意力计算
K和V的矩阵如下,得到的 K T V K^TV KTV的维度是[d, d]
K T = [ k 1 T k 2 T k 3 T k 4 T ] = [ k 11 k 21 k 31 k 41 k 12 k 22 k 32 k 42 ⋮ ⋮ ⋮ ⋮ k 1 d k 2 d k 3 d k 4 d ] K^{T}= \begin{bmatrix} k_{1}^T & k_{2}^T & k_{3}^T & k_{4}^T \\ \end{bmatrix}= \begin{bmatrix} k_{11} & k_{21} & k_{31} & k_{41} \\ k_{12} & k_{22} & k_{32} & k_{42} \\ \vdots & \vdots & \vdots & \vdots \\ k_{1d} & k_{2d} & k_{3d} & k_{4d}\\ \end{bmatrix} KT=[k1Tk2Tk3Tk4T]= k11k12⋮k1dk21k22⋮k2dk31k32⋮k3dk41k42⋮k4d
V = [ v 1 v 2 v 3 v 4 ] = [ v 11 v 12 . . . v 1 d v 21 v 22 . . . v 2 d v 31 v 32 . . . v 3 d v 41 v 42 . . . v 4 d ] V= \begin{bmatrix} v_{1} \\ v_{2} \\ v_{3} \\ v_{4} \\ \end{bmatrix}= \begin{bmatrix} v_{11} & v_{12} & ... & v_{1d} \\ v_{21} & v_{22} & ... & v_{2d} \\ v_{31} & v_{32} & ... & v_{3d} \\ v_{41} & v_{42} & ... & v_{4d} \end{bmatrix} V= v1v2v3v4 = v11v21v31v41v12v22v32v42............v1dv2dv3dv4d
K T V = [ k 1 T v 1 + k 2 T v 2 + k 3 T v 3 + k 4 T v 4 ] = [ [ K T V ] 1 [ K T V ] 2 . . . [ K T V ] d ] K^{T}V= \begin{bmatrix} k_{1}^Tv_1 + k_{2}^Tv_2 + k_{3}^Tv_3 + k_{4}^Tv_4 \\ \end{bmatrix}= \begin{bmatrix} [K^{T}V]_{1} & [K^{T}V]_{2} & ... & [K^{T}V]_{d} \end{bmatrix} KTV=[k1Tv1+k2Tv2+k3Tv3+k4Tv4]=[[KTV]1[KTV]2...[KTV]d]
此时计算 q 3 q_3 q3的注意力输出就可以使用以下方法。注意这是点积,q3是一个向量, K T V K^{T}V KTV是一个矩阵,向量在与矩阵点积的时候会进行广播拓展,复制成多份分别与矩阵中的向量点积。 [ K T V ] 1 [K^{T}V]_{1} [KTV]1是一个向量, q 3 [ K T V ] 1 q_3[K^{T}V]_{1} q3[KTV]1点积后会得到一个值,所以 q 3 K T V q_3K^{T}V q3KTV最终的结果是一个向量,长度为隐层维度d。
q 3 K T V = q 3 [ [ K T V ] 1 [ K T V ] 2 . . . [ K T V ] d ] = [ q 3 [ K T V ] 1 q 3 [ K T V ] 2 . . . q 3 [ K T V ] d ] q_3K^{T}V= q_3 \begin{bmatrix} [K^{T}V]_{1} & [K^{T}V]_{2} & ... & [K^{T}V]_{d} \end{bmatrix}= \begin{bmatrix} q_3[K^{T}V]_{1} & q_3[K^{T}V]_{2} & ... & q_3[K^{T}V]_{d} \end{bmatrix} q3KTV=q3[[KTV]1[KTV]2...[KTV]d]=[q3[KTV]1q3[KTV]2...q3[KTV]d]
也可以使用以下代码测试
import torch
q3 = torch.tensor([1, 2, 3, 4, 5, 6])
print(q3)
# [n, d] = [4, 6]
kT = torch.tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[4, 4, 4, 4],
[5, 5, 5, 5],
[6, 6, 6, 6]])
v = torch.tensor([[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1]])
print('kT @ v', kT @ v)
# q与(k.T @ v)的点积
result = torch.matmul(q3, kT @ v)
print('result', result)
此时 K T V K^TV KTV的结果是双向的, k 3 k_3 k3的输出矩阵中使用了 v 4 v_4 v4,这样双向注意力就可以顺畅的右乘得到 K T V K^TV KTV结果再与Q相乘,得到所有token的输出。
但因果模型的注意力是单向的, K T V K^TV KTV在计算的时候前面的K不能与后面的V相乘,所以只能一个一个算然后累加求和。
o 1 = q 1 ( k 1 T v 1 ) o_1 = q_1(k_1^Tv_1) o1=q1(k1Tv1)
o 2 = q 2 ( k 1 T v 1 + k 2 T v 2 ) o_2 = q_2(k_1^Tv_1+k_2^Tv_2) o2=q2(k1Tv1+k2Tv2)
o 3 = q 3 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 ) o_3 = q_3(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3) o3=q3(k1Tv1+k2Tv2+k3Tv3)
o 4 = q 4 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 + k 4 T v 4 ) o_4 = q_4(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3+k_4^Tv_4) o4=q4(k1Tv1+k2Tv2+k3Tv3+k4Tv4)
这样的累加操作无法进行高效的矩阵乘法,虽然计算复杂度降低了,但实际运算的效率并不高。
Lightning Attention
到这里可以引出MiniMax-01
中所使用的Lightning Attention
了,但其实这个注意力有两个版本,MiniMax-01
中所提到的就是是Lightning Attention-2
,那咱们先看看第一个版本做了什么。
Lightning Attention-1
源自:TransNormerLLM: A Faster and Better Large Language Model with Improved TransNormer
Lightning Attention-1针对于原始注意力取消了softmax,使用Swish激活函数代替。即先变成了
Attention
(
Q
,
K
,
V
)
=
(
ϕ
(
Q
)
ϕ
(
K
)
T
⊙
M
)
V
\operatorname{Attention}(Q,K,V)=(\phi(Q)\phi(K)^T\odot M)V
Attention(Q,K,V)=(ϕ(Q)ϕ(K)T⊙M)V
然后还是先左乘计算,并没有解决线性注意力的根本问题,但是借鉴了flash attention
中的硬件加速。
其前向和反向传播流程如下,就是将QKV切块,放到高速SRAM中去计算。虽然变快了,但此时的复杂度还是
O
(
n
2
d
)
O(n^2d)
O(n2d)。
Lightning Attention-2
源自:Lightning Attention-2: A Free Lunch for Handling Unlimited Sequence Lengths in Large Language Models
Lightning Attention-2
解决了因果模型在计算单向注意力时,需要进行累加求和操作导致无法矩阵运算的情况,实现了单向注意力先计算右乘,成功将复杂度降为
O
(
n
d
2
)
O(nd^2)
O(nd2)。
o
1
=
q
1
(
k
1
T
v
1
)
o_1 = q_1(k_1^Tv_1)
o1=q1(k1Tv1)
o 2 = q 2 ( k 1 T v 1 + k 2 T v 2 ) o_2 = q_2(k_1^Tv_1+k_2^Tv_2) o2=q2(k1Tv1+k2Tv2)
o 3 = q 3 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 ) o_3 = q_3(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3) o3=q3(k1Tv1+k2Tv2+k3Tv3)
o 4 = q 4 ( k 1 T v 1 + k 2 T v 2 + k 3 T v 3 + k 4 T v 4 ) o_4 = q_4(k_1^Tv_1+k_2^Tv_2+k_3^Tv_3+k_4^Tv_4) o4=q4(k1Tv1+k2Tv2+k3Tv3+k4Tv4)
再将这个累加求和公式拿过来,配合下图观察发现,之前的问题是每次计算 Q K T QK^T QKT都在整个序列上计算,这样每次都是所有序列的token互相注意到。那如果在序列这个维度拆分成小份,比如图中右侧先计算 k 1 k_1 k1和 k 2 k_2 k2,然后用于 q 3 q_3 q3的计算就完全没有问题, k 4 k_4 k4后面的就不计算了。这样就既能矩阵运算,又能符合单向掩码。
公式中也可以发现,当前step之前的k和v是可以相乘的,比如
q
3
q_3
q3在计算时,可以将
k
1
T
v
1
+
k
2
T
v
2
+
k
3
T
v
3
k_1^Tv_1+k_2^Tv_2+k_3^Tv_3
k1Tv1+k2Tv2+k3Tv3使用矩阵操作运算。所以Lightning Attention-2将大矩阵拆开,类似flash attention拆成多个block。
这些 block 不能拆分成 n 份,这样block的意义就没有了,for循环计算反而更慢。所以每个 block 中会有多个时间步的token。
此时这些 block 就可以分为两类,一类是块内(intra block),一类是块间(inter block)。块内代表当前块 q 的序列下标和 kv 序列下标相同,块间即不同。
块内在计算 q i q_i qi时直接矩阵右乘很容易算上 k i + 1 v i + 1 k_{i+1}v_{i+1} ki+1vi+1,所以块内使用传统的左乘并与掩码矩阵相乘。块间计算时就可以先右乘计算 K t V K^tV KtV,因为之前的kv是可以双向注意力的。然后将之前的kv结果缓存下来并更新,用于下一个step计算。
下图是Lightning Attention-2
的结构图,
λ
\lambda
λ是它的模型所使用的位置编码,忽略即可。
以下是前向传播和反向传播流程。
问题:M矩阵维度是[B, B],相当于每一个块代表了多个序列步n,在对角线位置是1,那在这个块内前面的q就可以注意到后面的kv了
解答:M矩阵维度虽然是[B, B],但只是这么切割,其内部值仍然是下三角。
备注
个人理解,若有不对请指出,谢谢。