Flash Attention 算法简介
Flash Attention 算法简介
Flash Attention,是近几年 MLSys 领域最重要的工作之一。它考虑到 self attention 在 GPU 上计算时的 I/O 特性,通过 tiling 的思想对 self attention 过程中的矩阵乘法、softmax 等操作进行分块处理,使得每个块的计算都能在 GPU SRAM 内部完成,减少对 GPU HBM 的访存开销,大大提升了 self attention 的计算速度,并且能保证最终结果与标准 self-attention 一致。同时,采用 recompute 的方法,在模型前向时避免保存用于求梯度的大量中间结果,而是在反向传播利用高效的 tiling self attention 重新高效地计算出来中间值,节省大量的显存空间。
本文基于 Flash Attention 论文和 From online softmax to Flash Attention 手稿,简要介绍 Flash Attention 的思想。
标准 Self Attention
给定输入
Q
,
K
,
V
∈
R
N
×
D
Q,K,V\in\mathbb{R}^{N\times D}
Q,K,V∈RN×D,其中
N
N
N 是序列长度(一般几千到几十万),
D
D
D 是隐层维度(一般几百到几千),不考虑 normalization、mask 等,标准的 self attention 的计算方式为:
S
=
Q
K
T
∈
R
N
×
N
P
=
softmax
(
S
)
∈
R
N
×
N
O
=
P
V
∈
R
N
×
D
\begin{aligned} S=QK^T &\in\mathbb{R}^{N\times N} \\ P=\text{softmax}(S) &\in\mathbb{R}^{N\times N} \\ O=PV &\in\mathbb{R}^{N\times D} \end{aligned} \notag \\
S=QKTP=softmax(S)O=PV∈RN×N∈RN×N∈RN×D
在标准的 self attention 实现中,中间结果
S
,
P
S,P
S,P 的尺寸太大,无法放到容量较小但高速的 SRAM 中,需要暂存到(相对)低速的 HBM,而 self attention 过程中又存在大量的 IO 密集型操作(如 softmax,以及可能存在的 elementwise 的操作 mask、dropout 等),因此大量的访存操作导致标准 self attention 的处理速度很慢。
同时,为了在反向传播时计算梯度,中间结果 S , P S,P S,P 还需要一直保存在显存中,造成了大量的显存占用。
矩阵乘法分块计算
为了保证整个 self attention 操作能够在 SRAM 中完成,减少对 HBM 的访存开销,Flash Attention 提出我们可以对 self attention 中的各个操作进行矩阵分块计算,控制计算每个块所需的内存可以被 SRAM 的容量满足,这样,每个块都单独一次性完成,避免了大量的访存开销,即使完整的 self attention 操作非常大,我们都可以分成更多的块来计算。
一个计算操作能否分块进行的条件是它是否具有结合律。比如我们熟悉的矩阵乘法操作,就具有结合律,因此本身可以直接进行矩阵乘法分块计算。下图以 C = A × B C=A\times B C=A×B 为例,展示了矩阵乘法的分块计算。具体来说,我们将各个矩阵切分成 T × T T\times T T×T 个块(tiles),对于输出的每个块,我们从左到右扫过 A A A 中对应的各个块,从上到下扫过 C C C 中对应的各个块,将这些块中的数值从 HBM 加载到 SRAM 中(图中标为蓝色,SRAM 整体的容量为 O ( T 2 ) \mathcal{O}(T^2) O(T2))。然后我们进行逐块的部分矩阵乘法,对于位置 ( i , j ) (i,j) (i,j) 处,加载 A [ i , k ] A[i,k] A[i,k] 和 B [ k , j ] B[k,j] B[k,j](图中标位红色),将 A [ i , k ] × B [ k , j ] A[i,k]\times B[k,j] A[i,k]×B[k,j] 的计算结果写到 C [ i , j ] C[i,j] C[i,j] 中,在输出矩阵 C C C 一个块的计算完成后,将其写到 HBM 中,然后继续处理下一个块。
对于矩阵乘法的分块计算,已经很成熟了,上面只是大致说明其思想,实际上的具体实现比这复杂得多,具体可以参考 cutlass。然而,self attention 中可不止有矩阵乘法操作,还有让人头疼的 softmax,它可不具有结合律,如何对其进行分块计算呢?
在介绍 Flash Attention 具体如何对 softmax 进行分块计算之前,我们先介绍一个前置知识:online softmax。
Online Softmax
标准 softmax
对于一个长为
N
N
N 的序列
{
x
1
,
x
2
,
…
,
x
N
}
\{x_1,x_2,\dots,x_N\}
{x1,x2,…,xN}(常被称为 logits),我们最熟悉的 softmax 操作的标准公式如下:
s
i
=
e
x
i
∑
j
=
1
N
e
x
j
s_i=\frac{e^{x_i}}{\sum_{j=1}^Ne^{x_j}} \notag \\
si=∑j=1Nexjexi
safe softmax
Softmax 操作本身存在指数操作,容易出现数值上溢的情况。以我们常用的 fp16 精度为例,其能表示的最大数值为 65504,因此,当 logits 中的值达到 12 时,指数操作
e
12
≈
162
,
754.79
>
65504
e^{12}\approx162,754.79>65504
e12≈162,754.79>65504 ,就已经出现数值溢出了。为了避免这种情况,我们通常会在正式进行 softmax 操作之前,对 logits 序列中的值同时减去该序列的最大值,从而使该序列的值都是负数,再进行指数操作时就不会出现数值溢出了,就 “安全” 了,这称为 safe softmax。
s
i
′
=
e
x
i
−
m
N
∑
j
=
1
N
e
x
j
−
m
N
,
其中
m
N
=
max
(
x
1
,
x
2
,
…
,
x
N
)
s'_i=\frac{e^{x_i-m_N}}{\sum_{j=1}^Ne^{x_j-m_N}},\ \ \ 其中\ m_N=\text{max}(x_1,x_2,\dots,x_N) \notag \\
si′=∑j=1Nexj−mNexi−mN, 其中 mN=max(x1,x2,…,xN)
这里需要指出的是,safe softmax 操作的合理性在于 softmax 操作的一个性质:即序列同时加上或减去任一个常数值,softmax 的结果不变。这从 softmax 的形式上就能轻松看出来,比如我们对 logits 的值同时加上一个常数
c
c
c:
s
i
′
=
e
x
i
+
c
∑
j
=
1
N
e
x
j
+
c
=
e
x
i
⋅
e
c
∑
j
=
1
N
e
x
j
⋅
e
c
=
e
x
i
∑
j
=
1
N
e
x
j
=
s
i
s'_i=\frac{e^{x_i+c}}{\sum_{j=1}^Ne^{x_j+c}}=\frac{e^{x_i}\cdot e^c}{\sum_{j=1}^Ne^{x_j}\cdot e^c}=\frac{e^{x_i}}{\sum_{j=1}^Ne^{x_j}}=s_i \notag \\
si′=∑j=1Nexj+cexi+c=∑j=1Nexj⋅ecexi⋅ec=∑j=1Nexjexi=si
online softmax
online softmax,顾名思义,想要在原长度为
N
N
N 序列的 softmax 结果上,能够在线地增加新的值,比如增加一个新元素
x
N
+
1
x_{N+1}
xN+1,动态地求出此时
N
+
1
N+1
N+1 长度序列的 softmax 结果。为了实现 online softmax,我们需要在维护原 softmax 结果序列的基础上,额外维护
m
N
m_N
mN 和
d
N
d_N
dN 两项,其中前者表示原 logits 序列前
N
N
N 项的最大值,后者表示对长度为
N
N
N 的序列
x
1
,
x
2
,
…
,
x
N
x_1,x_2,\dots,x_N
x1,x2,…,xN 进行 safe softmax 操作时分母上的求和项,即序列指数和:
d
N
=
∑
j
=
1
N
e
x
j
−
m
N
d_N=\sum_{j=1}^Ne^{x_j-m_N} \notag \\
dN=j=1∑Nexj−mN
safe softmax 保证了 softmax 数值的稳定性,但是也同时引入了一个规约操作:求 max。这使得我们想要进行 online softmax 的时候除了需要考虑新引入元素
x
N
+
1
x_{N+1}
xN+1 本身,还需要考虑引入该元素之后序列最大值
m
N
m_N
mN 可能的变化,以及可能导致的指数和
d
N
d_N
dN 的变化。
在有新增元素 x N + 1 x_{N+1} xN+1 时,我们按照如下三步来进行 online softmax 更新:
第一步:更新最大值
m
N
+
1
m_{N+1}
mN+1:
m
N
+
1
=
max
(
m
N
,
x
N
+
1
)
m_{N+1}=\max(m_N,x_{N+1}) \notag \\
mN+1=max(mN,xN+1)
第二步:更新指数和
d
N
+
1
d_{N+1}
dN+1:
d
N
+
1
=
∑
j
=
1
N
+
1
e
x
j
−
m
N
+
1
=
(
∑
j
=
1
N
e
x
j
−
m
N
+
1
)
+
e
x
N
+
1
−
m
N
+
1
=
(
∑
j
=
1
N
e
x
j
−
m
N
e
m
N
−
m
N
+
1
)
+
e
x
N
+
1
−
m
N
+
1
=
d
N
⋅
e
m
N
−
m
N
+
1
+
e
x
N
+
1
−
m
N
+
1
\begin{aligned} d_{N+1}&=\sum_{j=1}^{N+1}e^{x_j-m_{N+1}} \\ &=(\sum_{j=1}^Ne^{x_j-m_{N+1}})+e^{x_{N+1}-m_{N+1}} \\ &=(\sum_{j=1}^Ne^{x_j-m_{N}}e^{m_N-m_{N+1}})+e^{x_{N+1}-m_{N+1}} \\ &=d_N\cdot e^{m_N-m_{N+1}}+e^{x_{N+1}-m_{N+1}} \end{aligned} \notag \\
dN+1=j=1∑N+1exj−mN+1=(j=1∑Nexj−mN+1)+exN+1−mN+1=(j=1∑Nexj−mNemN−mN+1)+exN+1−mN+1=dN⋅emN−mN+1+exN+1−mN+1
这里推导其实就是说新增了元素
x
N
+
1
x_{N+1}
xN+1,全局最大值可能改变,需要对原来的指数和补上一个系数
e
m
N
−
m
N
+
1
e^{m_N-m_{N+1}}
emN−mN+1,把之前计算时可能少减的部分补上。当然,如果引入新元素没有改变全局最大值,即
m
N
=
m
N
+
1
m_N=m_{N+1}
mN=mN+1,那这个系数就是 1 了。然后再加上新元素对应的
e
x
N
+
1
−
m
N
+
1
e^{x_{N+1}-m_{N+1}}
exN+1−mN+1。
第三步:根据
m
N
+
1
,
d
N
+
1
m_{N+1},d_{N+1}
mN+1,dN+1,重新计算长度为
N
+
1
N+1
N+1 的整个序列的 softmax 结果:
s
i
=
e
x
i
−
m
N
+
1
d
N
+
1
s_i=\frac{e^{x_i-m_{N+1}}}{d_{N+1}} \notag \\
si=dN+1exi−mN+1
block online softmax
Online Softmax 最大的意义不是向序列中逐个增加元素,而是我们可以将原序列分块各自进行处理,最后再归并,这样就能充分地并行,这也就是 block online softmax。
首先,为了进行适配分块,我们重新设置一下记号: m 1 : N m_{1:N} m1:N 表示序列索引从 1 到 N N N 中的最大值, d 1 : N d_{1:N} d1:N 表示从 1 到 N N N 子序列的分母上的求和项。
我们首先考虑分为两个 block 的情景。对于总长度为 2 N 2N 2N 的序列,我们将其分为等长的 x 1 , x 2 , … , x N x_1,x_2,\dots,x_N x1,x2,…,xN 和 x N + 1 , x N + 2 , … , x 2 N x_{N+1},x_{N+2},\dots,x_{2N} xN+1,xN+2,…,x2N。block online softmax 分为以下几步:
第一步:分别表示出各块最大值
m
m
m 和指数和
d
d
d:
m
1
:
N
=
max
(
x
1
,
x
2
,
…
,
x
N
)
d
1
:
N
=
∑
j
=
1
N
e
x
j
−
m
1
:
N
m
N
+
1
:
2
N
=
max
(
x
N
+
1
,
x
N
+
2
,
…
,
x
2
N
)
d
N
+
1
:
2
N
=
∑
j
=
N
+
1
2
N
e
x
j
−
m
N
+
1
:
2
N
\begin{aligned} m_{1:N}&=\max(x_1,x_2,\dots,x_N) \\ d_{1:N}&=\sum_{j=1}^Ne^{x_j-m_{1:N}} \\ m_{N+1:2N}&=\max(x_{N+1},x_{N+2},\dots,x_{2N}) \\ d_{N+1:2N}&=\sum_{j=N+1}^{2N}e^{x_j-m_{N+1:2N}} \end{aligned} \notag \\
m1:Nd1:NmN+1:2NdN+1:2N=max(x1,x2,…,xN)=j=1∑Nexj−m1:N=max(xN+1,xN+2,…,x2N)=j=N+1∑2Nexj−mN+1:2N
第二步:合并全局最大值
m
1
:
2
N
m_{1:2N}
m1:2N:
m
1
:
2
N
=
max
(
m
1
:
N
,
m
N
+
1
,
2
N
)
m_{1:2N}=\max(m_{1:N},m_{N+1,2N}) \notag \\
m1:2N=max(m1:N,mN+1,2N)
第三步:计算全局指数和
d
1
:
2
N
d_{1:2N}
d1:2N:
d
1
:
2
N
=
∑
j
=
1
2
N
e
x
j
−
m
1
:
2
N
=
…
=
d
1
:
N
e
m
1
:
N
−
m
1
:
2
N
+
d
N
+
1
:
2
N
e
m
N
+
1
:
2
N
−
m
1
:
2
N
\begin{aligned} d_{1:2N}&=\sum_{j=1}^{2N}e^{x_j-m_{1:2N}} \\ &=\dots \\ &=d_{1:N}\ e^{m_{1:N}-m_{1:2N}}+d_{N+1:2N}\ e^{m_{N+1:2N}-m_{1:2N}} \end{aligned} \notag \\
d1:2N=j=1∑2Nexj−m1:2N=…=d1:N em1:N−m1:2N+dN+1:2N emN+1:2N−m1:2N
这里的结果的意思和之前一样,就是用合并出的全局最大值对每个块的局部最大值进行补偿,推导细节就略过了。
第四步:块合并后的 softmax 结果:
s
i
=
e
x
i
−
m
1
:
2
N
d
1
:
2
N
s_i=\frac{e^{x_i-m_{1:2N}}}{d_{1:2N}} \notag \\
si=d1:2Nexi−m1:2N
Flash Attention
接下来,我们看 Flash Attention 如何对包含 softmax 操作的 self attention 进行 tiling 分块计算。
3-pass safe softmax
在最常规的 safe softmax 中,由于有 “max 求 m N m_N mN” 和 “sum 求 d N d_N dN” 两个规约操作,以及最后计算 softmax 各元素的值,我们总共需要计算三轮 for 循环,才能完成计算,这称为 3-pass safe softmax。
具体来说,我们记:
- m i m_i mi 为序列中前 i i i 个值中的最大值,即 m i = max j = 1 i { x j } m_i=\max_{j=1}^{i}\{x_j\} mi=maxj=1i{xj},初始值 m 0 = − ∞ m_0=-\infty m0=−∞;
- d i d_{i} di 为序列前 i i i 个值的指数求和结果,即 d i = ∑ j = 1 i e x j − m N d_i=\sum_{j=1}^ie^{x_j-m_N} di=∑j=1iexj−mN,初始值为 d 0 = 0 d_0=0 d0=0,当整体的求和完成,得到 d N d_N dN 即为 safe softmax 的分母;
- a i a_i ai 为 safe softmax 的结果
那么,最常规的 safe softmax 需要以下三个 for 循环来完成计算。
第一个循环,遍历 x 1 , … , x N x_1,\dots,x_N x1,…,xN,求出 m N = max ( x 1 , … , x N ) m_N=\max(x_1,\dots,x_N) mN=max(x1,…,xN);
第二个循环,遍历 x 1 , … , x N x_1,\dots,x_N x1,…,xN,求出 d N = ∑ j = 1 N e x j − m N d_N=\sum_{j=1}^Ne^{x_j-m_N} dN=∑j=1Nexj−mN;
第三个循环,遍历 x 1 , … , x N x_1,\dots,x_N x1,…,xN,求出最终的 softmax 结果 a i = e x i − m N d N a_i=\frac{e^{x_i-m_N}}{d_N} ai=dNexi−mN。
在 attention 中,这里的 softmax 需要这样三个循环,而我们不可能同时将所有的 logits 存放到 GPU 的 SRAM (~20MB)中,那我们就需要三次访问 Q , K Q,K Q,K 矩阵,造成大量的访存时间开销。
2-pass softmax
由于在 GPU 计算的时间开销分布中,访存是大头,因此如果能尽量减少访存时间,稍微增加一点计算复杂度也是非常值得的。现在我们想对上述需要 3 个 for 循环完成计算常规 softmax 实现进行优化,减少其循环次数,从而降低访存的时间开销。但是我们发现上述三个 for 循环存在着逐级数据依赖关系,第二个 for 循环计算 d i d_i di 时需要第一个 for 循环的最终 max 结果 m N m_N mN,第三个 for 循环则需要第二个 for 循环的最终 sum 结果 d N d_N dN,有什么办法能够融合上述的三个 for 循环呢?
首先出场的就是我们刚刚提到的 online softmax,借助它,我们可以实现对前两个 for-loop 的融合。具体来说,我们构造一个新的序列
d
i
′
:
=
∑
j
=
1
i
e
x
j
−
m
i
d'_i:=\sum_{j=1}^ie^{x_j-m_i}
di′:=j=1∑iexj−mi
作为原来
d
i
′
:
=
∑
j
=
1
i
e
x
j
−
m
N
d'_i:=\sum_{j=1}^ie^{x_j-m_N}
di′:=∑j=1iexj−mN 的替代(surrogate)。这样我们就消除了第二个 for-loop 对第一个 for-loop 最终结果
m
N
m_N
mN 的依赖,且两个序列的最后一个元素是相同的
d
N
′
=
d
N
d'_N=d_N
dN′=dN,因此第三个 for-loop 直接用
d
N
′
d'_N
dN′ 的结果来替代
d
N
d_N
dN 也是没问题的。但是注意,如同我们在 online softmax 中推导的那样,我们需要在每次添加一个新的元素后,更新序列的最大值,并补偿可能少减的部分。这里的推导和上面 online softmax 的推导完全一致:
d
i
′
=
∑
j
=
1
i
e
x
j
−
m
i
=
(
∑
j
=
1
i
−
1
e
x
j
−
m
i
)
+
e
x
i
−
m
i
=
(
∑
j
=
1
i
−
1
e
x
j
−
m
i
−
1
)
e
m
i
−
1
−
m
i
+
e
x
i
−
m
i
=
d
i
−
1
′
e
m
i
−
1
−
m
i
+
e
x
i
−
m
i
\begin{aligned} d'_i&=\sum_{j=1}^ie^{x_j-m_i} \\ &=\left(\sum_{j=1}^{i-1}e^{x_j-m_i}\right)+e^{x_i-m_i} \\ &=\left(\sum_{j=1}^{i-1}e^{x_j-m_{i-1}}\right)e^{m_{i-1}-m_i}+e^{x_i-m_i} \\ &=d'_{i-1}e^{m_{i-1}-m_i}+e^{x_i-m_i} \end{aligned} \notag \\
di′=j=1∑iexj−mi=(j=1∑i−1exj−mi)+exi−mi=(j=1∑i−1exj−mi−1)emi−1−mi+exi−mi=di−1′emi−1−mi+exi−mi
可以看到,这个式子只依赖于
m
i
−
1
,
m
i
m_{i-1},m_i
mi−1,mi,这样我们就借助 online softmax,把前两个循环融合到了一起,称为 2-pass online softmax。那么接下来,我们能否更进一步,优化到单个循环搞定,实现 1-pass softmax 呢?
flash attention
很遗憾,如果仅对于 softmax 来说,1-pass 是无法做到的。但是,在 self-attention 中,我们的最终目标不是 attention 矩阵 A A A,而是其与 value 矩阵相乘得到的输出 O = A × V O=A\times V O=A×V。即 self attention 中求完 softmax 之后,每一项的值会与 V V V 中向量相乘,然后累加。这里的累加很关键,有了这个累加的操作,就有成了上面我们构造 surrogate 替代的形式,就可以做成 1-pass 了。
先捋一捋,我们目前已经借助 online softmax 将 3-pass 压缩到了 2-pass,但是 softmax 本身没法再压缩到 1-pass 了,我们来看看 Flash Attention 是如何利用 Self-Attention 的特性,实现 1-pass Self-Attention 的。
首先,我们写出 2-pass self attention 的计算过程。记
- Q [ k , : ] Q[k,:] Q[k,:] 为矩阵 Q Q Q 的第 k k k 个行向量;
- K T [ : , i ] K^T[:,i] KT[:,i] 为矩阵 K T K^T KT 的第 i i i 个列向量;
- V [ i , : ] V[i,:] V[i,:] 为矩阵 V V V 的第 i i i 个行向量;
- O [ k , : ] O[k,:] O[k,:] 为输出矩阵 O O O 的第 k k k 个行向量;
- o i = ∑ j = 1 i a j V [ j , : ] \mathbf{o}_i=\sum_{j=1}^ia_jV[j,:] oi=∑j=1iajV[j,:] 是一个行向量。保存了部分求和结果 A [ k , : i ] × V [ : i , : ] A[k,:i]\times V[:i,:] A[k,:i]×V[:i,:]。
2-pass self attention 分为两个 for-loop,第一个 for-loop 计算
m
N
m_N
mN 和
d
N
′
d'_N
dN′ (也就是
d
N
d_N
dN):
for
i
in
[
1
,
…
,
N
]
x
i
=
Q
[
k
,
:
]
K
T
[
:
,
i
]
m
i
=
max
(
m
i
−
1
,
x
i
)
d
i
′
=
d
i
−
1
′
e
m
i
−
1
−
m
i
+
e
x
i
−
m
i
endfor
\text{for}\ i\ \text{in}\ [1,\dots,N] \hspace{5cm}\\ \begin{aligned} x_i&=Q[k,:]K^T[:,i] \\ m_i&=\max(m_{i-1},x_i) \\ d'_i&=d'_{i-1}e^{m_{i-1}-m_i}+e^{x_i-m_i} \\ \end{aligned} \\ \text{endfor} \hspace{7.2cm} \notag \\
for i in [1,…,N]ximidi′=Q[k,:]KT[:,i]=max(mi−1,xi)=di−1′emi−1−mi+exi−miendfor
第二个 for-loop 计算 softmax 结果,并与对应位置的
V
V
V 相乘,得到最终结果
for
i
in
[
1
,
…
,
N
]
a
i
=
e
x
i
−
m
N
d
N
′
o
i
=
o
i
−
1
+
a
i
V
[
i
,
:
]
endfor
\text{for}\ i\ \text{in}\ [1,\dots,N] \hspace{5cm}\\ \begin{aligned} a_i&=\frac{e^{x_i-m_N}}{d'_N} \\ \mathbf{o}_i&=\mathbf{o}_{i-1}+a_iV[i,:] \end{aligned} \\ \text{endfor} \hspace{7.2cm} \notag \\
for i in [1,…,N]aioi=dN′exi−mN=oi−1+aiV[i,:]endfor
第二个循环结束后,我们将结果写入输出矩阵
O
O
O:
O
[
k
,
:
]
=
o
N
O[k,:]=\mathbf{o}_N \notag \\
O[k,:]=oN
我们将第二个循环中的两个等式合并(将
a
i
a_i
ai 代入),得到:
o
i
:
=
∑
j
=
1
i
(
e
x
j
−
m
N
d
N
′
V
[
j
,
:
]
)
\mathbf{o}_i:=\sum_{j=1}^i(\frac{e^{x_j-m_N}}{d'_N}V[j,:]) \notag \\
oi:=j=1∑i(dN′exj−mNV[j,:])
这里可以看到,第二个循环的计算依赖于第一个循环结束后的结果
m
N
m_N
mN 和
d
N
′
d'_N
dN′。这里我们可以再次采用 surrogate 的技巧来实现 1-pass。具体来说,我们构造一个不依赖于
m
N
,
d
N
′
m_N,d'_N
mN,dN′ 的替代序列
o
i
′
\mathbf{o}'_i
oi′:
o
i
′
=
∑
j
=
1
i
e
x
j
−
m
i
d
i
′
V
[
j
,
:
]
\mathbf{o}'_i=\sum_{j=1}^i\frac{e^{x_j-m_i}}{d'_i}V[j,:] \notag \\
oi′=j=1∑idi′exj−miV[j,:]
同样的,这两个序列的最后一个元素相等
o
N
=
o
N
′
\mathbf{o_N}=\mathbf{o}'_N
oN=oN′。然后我们推导出
o
i
′
\mathbf{o}'_i
oi′ 和
o
i
−
1
′
\mathbf{o}'_{i-1}
oi−1′ 的迭代关系:
o
i
′
=
∑
j
=
1
i
e
x
j
−
m
i
d
i
′
V
[
j
:
]
=
(
∑
j
=
1
i
−
1
e
x
j
−
m
i
d
i
′
V
[
j
:
]
)
+
e
x
i
−
m
i
d
i
′
V
[
i
:
]
=
(
∑
j
=
1
i
−
1
e
x
j
−
m
i
−
1
d
i
−
1
′
e
x
j
−
m
i
e
x
j
−
m
i
−
1
d
i
−
1
′
d
i
′
V
[
j
,
:
]
)
+
e
x
i
−
m
i
d
i
′
V
[
i
:
]
=
(
∑
j
=
1
i
−
1
e
x
j
−
m
i
−
1
d
i
−
1
′
)
d
i
−
1
′
d
i
′
e
m
i
−
1
−
m
i
+
e
x
i
−
m
i
d
i
′
V
[
i
:
]
=
o
i
−
1
′
d
i
−
1
′
e
m
i
−
1
−
m
i
d
i
′
+
e
x
i
−
m
i
d
i
′
V
[
i
:
]
\begin{aligned} \mathbf{o}_i'&=\sum_{j=1}^i\frac{e^{x_j-m_i}}{d'_i}V[j:] \\ &=\left(\sum_{j=1}^{i-1}\frac{e^{x_j-m_i}}{d'_i}V[j:]\right)+\frac{e^{x_i-m_i}}{d'_i}V[i:] \\ &=\left(\sum_{j=1}^{i-1}\frac{e^{x_j-m_{i-1}}}{d_{i-1}'}\frac{e^{x_j-m_i}}{e^{x_j-m_{i-1}}}\frac{d'_{i-1}}{d'_i}V[j,:]\right)+\frac{e^{x_i-m_i}}{d'_i}V[i:]\\ &=\left(\sum_{j=1}^{i-1}\frac{e^{x_j-m_{i-1}}}{d'_{i-1}}\right)\frac{d'_{i-1}}{d'_i}e^{m_{i-1}-m_i}+\frac{e^{x_i-m_i}}{d'_i}V[i:] \\ &=\mathbf{o}'_{i-1}\frac{d'_{i-1}e^{m_{i-1}-m_i}}{d_i'}+\frac{e^{x_i-m_i}}{d'_i}V[i:] \end{aligned} \notag \\
oi′=j=1∑idi′exj−miV[j:]=(j=1∑i−1di′exj−miV[j:])+di′exi−miV[i:]=(j=1∑i−1di−1′exj−mi−1exj−mi−1exj−midi′di−1′V[j,:])+di′exi−miV[i:]=(j=1∑i−1di−1′exj−mi−1)di′di−1′emi−1−mi+di′exi−miV[i:]=oi−1′di′di−1′emi−1−mi+di′exi−miV[i:]
o
i
′
\mathbf{o}'_i
oi′ 的计算仅依赖于
d
i
′
,
d
i
−
1
′
,
m
i
,
m
i
−
1
,
x
i
d'_i,d'_{i-1},m_i,m_{i-1},x_i
di′,di−1′,mi,mi−1,xi,这样我们就将三个循环全都融合到了一起, 实现了 1-pass self attention,也就是 flash attention。在这一个循环中,我们做如下计算:
for
i
in
[
1
,
…
,
N
]
x
i
=
Q
[
k
,
:
]
K
T
[
:
,
i
]
m
i
=
max
(
m
i
−
1
,
x
i
)
d
i
′
=
d
i
−
1
′
e
m
i
−
1
−
m
i
+
e
x
i
−
m
i
o
i
′
=
o
i
−
1
′
d
i
−
1
′
e
m
i
−
1
−
m
i
d
i
′
+
e
x
i
−
m
i
d
i
′
V
[
i
,
:
]
endfor
O
[
k
,
:
]
=
o
N
′
\text{for}\ i\ \text{in}\ [1,\dots,N] \hspace{12cm}\\ \begin{aligned} x_i&=Q[k,:]K^T[:,i] \\ m_i&=\max(m_{i-1},x_i) \\ d'_i&=d'_{i-1}e^{m_{i-1}-m_i}+e^{x_i-m_i} \\ \mathbf{o}'_i&=\mathbf{o}'_{i-1}\frac{d'_{i-1}e^{m_{i-1}-m_i}}{d'_i}+\frac{e^{x_i-m_i}}{d'_i}V[i,:] \end{aligned} \\ \text{endfor} \hspace{14.2cm} \\ O[k,:]=\mathbf{o}'_N \notag \\
for i in [1,…,N]ximidi′oi′=Q[k,:]KT[:,i]=max(mi−1,xi)=di−1′emi−1−mi+exi−mi=oi−1′di′di−1′emi−1−mi+di′exi−miV[i,:]endforO[k,:]=oN′
状态 x i , m i , d i ′ , o i x_i,m_i,d'_i,\mathbf{o}_i xi,mi,di′,oi 这些尺寸都比较小,可以放到 GPU SRAM 里进行计算。由于这些操作都具有结合律,因此都可以进行 tiling 操作。如果我们逐 tile 进行计算,记
- b b b 为 block size
- x i x_i xi 为存储第 i i i 个 tile [ ( i − 1 ) b : i b ] [(i-1)b:ib] [(i−1)b:ib] 的 Q [ k ] K T Q[k]K^T Q[k]KT 值的向量
- m i ( local ) m_i^{(\text{local})} mi(local) 表示 x i x_i xi 内的局部最大值
tiling 版的 flash attention 可以写成如下形式:
for i in [ 1 , … , N / b ] x i = Q [ k , : ] K T [ : , ( i − 1 ) b : i b ] m i ( local ) = max j = 1 b ( x i [ j ] ) m i = max ( m i − 1 , m i ( local ) ) d i ′ = d i − 1 ′ e m i − 1 − m i + ∑ j = 1 b e x i [ j ] − m i o i ′ = o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + ∑ j = 1 b e x i − m i d i ′ V [ j + ( i − 1 ) b , : ] endfor O [ k , : ] = o N / b ′ \text{for}\ i\ \text{in}\ [1,\dots,N/b] \hspace{12cm}\\ \begin{aligned} x_i&=Q[k,:]K^T[:,(i-1)b:ib] \\ m_i^{(\text{local})}&=\max_{j=1}^b(x_i[j]) \\ m_i&=\max(m_{i-1},m_i^{(\text{local})}) \\ d'_i&=d'_{i-1}e^{m_{i-1}-m_i}+\sum_{j=1}^be^{x_i[j]-m_i} \\ \mathbf{o}'_i&=\mathbf{o}'_{i-1}\frac{d'_{i-1}e^{m_{i-1}-m_i}}{d'_i}+\sum_{j=1}^b\frac{e^{x_i-m_i}}{d'_i}V[j+(i-1)b,:] \end{aligned} \\ \text{endfor} \hspace{14.2cm} \\ O[k,:]=\mathbf{o}'_{N/b} \notag \\ for i in [1,…,N/b]ximi(local)midi′oi′=Q[k,:]KT[:,(i−1)b:ib]=j=1maxb(xi[j])=max(mi−1,mi(local))=di−1′emi−1−mi+j=1∑bexi[j]−mi=oi−1′di′di−1′emi−1−mi+j=1∑bdi′exi−miV[j+(i−1)b,:]endforO[k,:]=oN/b′
下图展示了 Flash Attention 在 GPU 上的计算方式。蓝色块表示驻留在 SRAM 中的 tiles,而红色块对应于第 i i i 行。 L L L 表示序列长度,可能非常大(例如 16k),D 表示隐层维度,在 Transformer 中通常较小(例如 GPT3 为 128), B B B 是可以控制的块大小。值得注意的是,整体 SRAM 内存占用仅取决于 B B B 和 D D D,与 L L L 无关。因此,该算法可以扩展到长上下文而不会遇到内存问题。在计算过程中,我们从左到右扫描 K T K^T KT 和 A A A 的 tiles,从上到下扫描 V V V 的 tiles,并相应地更新 m , d , O m,d ,O m,d,O 的状态。
Recomputation
以上我们介绍的是 Flash Attention 中最关键的 tiling 的技巧,通过将 self-attention 中各个操作(GEMM、softmax 等)进行分块计算,每个块用一个 kernel 直接算完,避免了大量的访存时间开销,从而加快了 self attention 的计算。那么 Flash Attention 又是如何节省显存的呢?
在标准的 self attention 实现中,会存储中间结果矩阵 S , P ∈ R N × N S,P\in\mathbb{R}^{N\times N} S,P∈RN×N,用于反向传播时计算梯度。这会带来 O ( N 2 ) \mathcal{O}(N^2) O(N2) 的空间复杂度,当序列长度 N N N 很大时,这里的显存占用还是很可观的。但是在 Flash Attention 中,还记得我们除了保存输出 O O O 之外还保存了 online softmax 计算过程中的一些标准化统计量:序列最大值 m N m_N mN 和序列指数和 d N d_N dN,这样我们就能在反向传播时很容易地计算出中间矩阵 S , P S,P S,P,因此我们可以在模型前向过程中丢掉这些中间结果以节省显存。实际上,这就是我们常用的 gradient checkpointing 技巧,但是一般的 gradient checkpointing 技巧由于丢掉了一些中间结果,在反向传播重新计算时会额外增加一些耗时,整体需要在显存和计算耗时之间进行 trade-off,而 Flash Attention 却得益于 tiling 带来的计算高效性,虽然额外增加了一些 Flops,但是整体时间和显存占用都大幅降低。
总结
在 Transformer 大模型时代,Flash Attention 的重要性不言而喻。它从 I/O aware 的角度,通过 tiling 的思想来对 self attention 进行分块计算,减少访问 HBM 的时间开销,又通过 recompute 来节省训练时的显存占用,实在是软硬件结合算法设计的典范,十分值得学习。