当前位置: 首页 > article >正文

Flash Attention 算法简介

Flash Attention 算法简介

Flash Attention,是近几年 MLSys 领域最重要的工作之一。它考虑到 self attention 在 GPU 上计算时的 I/O 特性,通过 tiling 的思想对 self attention 过程中的矩阵乘法、softmax 等操作进行分块处理,使得每个块的计算都能在 GPU SRAM 内部完成,减少对 GPU HBM 的访存开销,大大提升了 self attention 的计算速度,并且能保证最终结果与标准 self-attention 一致。同时,采用 recompute 的方法,在模型前向时避免保存用于求梯度的大量中间结果,而是在反向传播利用高效的 tiling self attention 重新高效地计算出来中间值,节省大量的显存空间。

本文基于 Flash Attention 论文和 From online softmax to Flash Attention 手稿,简要介绍 Flash Attention 的思想。

标准 Self Attention

给定输入 Q , K , V ∈ R N × D Q,K,V\in\mathbb{R}^{N\times D} Q,K,VRN×D,其中 N N N 是序列长度(一般几千到几十万), D D D 是隐层维度(一般几百到几千),不考虑 normalization、mask 等,标准的 self attention 的计算方式为:
S = Q K T ∈ R N × N P = softmax ( S ) ∈ R N × N O = P V ∈ R N × D \begin{aligned} S=QK^T &\in\mathbb{R}^{N\times N} \\ P=\text{softmax}(S) &\in\mathbb{R}^{N\times N} \\ O=PV &\in\mathbb{R}^{N\times D} \end{aligned} \notag \\ S=QKTP=softmax(S)O=PVRN×NRN×NRN×D
在标准的 self attention 实现中,中间结果 S , P S,P S,P 的尺寸太大,无法放到容量较小但高速的 SRAM 中,需要暂存到(相对)低速的 HBM,而 self attention 过程中又存在大量的 IO 密集型操作(如 softmax,以及可能存在的 elementwise 的操作 mask、dropout 等),因此大量的访存操作导致标准 self attention 的处理速度很慢。

同时,为了在反向传播时计算梯度,中间结果 S , P S,P S,P 还需要一直保存在显存中,造成了大量的显存占用。

矩阵乘法分块计算

为了保证整个 self attention 操作能够在 SRAM 中完成,减少对 HBM 的访存开销,Flash Attention 提出我们可以对 self attention 中的各个操作进行矩阵分块计算,控制计算每个块所需的内存可以被 SRAM 的容量满足,这样,每个块都单独一次性完成,避免了大量的访存开销,即使完整的 self attention 操作非常大,我们都可以分成更多的块来计算。

一个计算操作能否分块进行的条件是它是否具有结合律。比如我们熟悉的矩阵乘法操作,就具有结合律,因此本身可以直接进行矩阵乘法分块计算。下图以 C = A × B C=A\times B C=A×B 为例,展示了矩阵乘法的分块计算。具体来说,我们将各个矩阵切分成 T × T T\times T T×T 个块(tiles),对于输出的每个块,我们从左到右扫过 A A A 中对应的各个块,从上到下扫过 C C C 中对应的各个块,将这些块中的数值从 HBM 加载到 SRAM 中(图中标为蓝色,SRAM 整体的容量为 O ( T 2 ) \mathcal{O}(T^2) O(T2))。然后我们进行逐块的部分矩阵乘法,对于位置 ( i , j ) (i,j) (i,j) 处,加载 A [ i , k ] A[i,k] A[i,k] B [ k , j ] B[k,j] B[k,j](图中标位红色),将 A [ i , k ] × B [ k , j ] A[i,k]\times B[k,j] A[i,k]×B[k,j] 的计算结果写到 C [ i , j ] C[i,j] C[i,j] 中,在输出矩阵 C C C 一个块的计算完成后,将其写到 HBM 中,然后继续处理下一个块。

在这里插入图片描述

对于矩阵乘法的分块计算,已经很成熟了,上面只是大致说明其思想,实际上的具体实现比这复杂得多,具体可以参考 cutlass。然而,self attention 中可不止有矩阵乘法操作,还有让人头疼的 softmax,它可不具有结合律,如何对其进行分块计算呢?

在介绍 Flash Attention 具体如何对 softmax 进行分块计算之前,我们先介绍一个前置知识:online softmax。

Online Softmax

标准 softmax

对于一个长为 N N N 的序列 { x 1 , x 2 , … , x N } \{x_1,x_2,\dots,x_N\} {x1,x2,,xN}(常被称为 logits),我们最熟悉的 softmax 操作的标准公式如下:
s i = e x i ∑ j = 1 N e x j s_i=\frac{e^{x_i}}{\sum_{j=1}^Ne^{x_j}} \notag \\ si=j=1Nexjexi

safe softmax

Softmax 操作本身存在指数操作,容易出现数值上溢的情况。以我们常用的 fp16 精度为例,其能表示的最大数值为 65504,因此,当 logits 中的值达到 12 时,指数操作 e 12 ≈ 162 , 754.79 > 65504 e^{12}\approx162,754.79>65504 e12162,754.79>65504 ,就已经出现数值溢出了。为了避免这种情况,我们通常会在正式进行 softmax 操作之前,对 logits 序列中的值同时减去该序列的最大值,从而使该序列的值都是负数,再进行指数操作时就不会出现数值溢出了,就 “安全” 了,这称为 safe softmax。
s i ′ = e x i − m N ∑ j = 1 N e x j − m N ,    其中  m N = max ( x 1 , x 2 , … , x N ) s'_i=\frac{e^{x_i-m_N}}{\sum_{j=1}^Ne^{x_j-m_N}},\ \ \ 其中\ m_N=\text{max}(x_1,x_2,\dots,x_N) \notag \\ si=j=1NexjmNeximN,   其中 mN=max(x1,x2,,xN)

这里需要指出的是,safe softmax 操作的合理性在于 softmax 操作的一个性质:即序列同时加上或减去任一个常数值,softmax 的结果不变。这从 softmax 的形式上就能轻松看出来,比如我们对 logits 的值同时加上一个常数 c c c
s i ′ = e x i + c ∑ j = 1 N e x j + c = e x i ⋅ e c ∑ j = 1 N e x j ⋅ e c = e x i ∑ j = 1 N e x j = s i s'_i=\frac{e^{x_i+c}}{\sum_{j=1}^Ne^{x_j+c}}=\frac{e^{x_i}\cdot e^c}{\sum_{j=1}^Ne^{x_j}\cdot e^c}=\frac{e^{x_i}}{\sum_{j=1}^Ne^{x_j}}=s_i \notag \\ si=j=1Nexj+cexi+c=j=1Nexjecexiec=j=1Nexjexi=si

online softmax

online softmax,顾名思义,想要在原长度为 N N N 序列的 softmax 结果上,能够在线地增加新的值,比如增加一个新元素 x N + 1 x_{N+1} xN+1,动态地求出此时 N + 1 N+1 N+1 长度序列的 softmax 结果。为了实现 online softmax,我们需要在维护原 softmax 结果序列的基础上,额外维护 m N m_N mN d N d_N dN 两项,其中前者表示原 logits 序列前 N N N 项的最大值,后者表示对长度为 N N N 的序列 x 1 , x 2 , … , x N x_1,x_2,\dots,x_N x1,x2,,xN 进行 safe softmax 操作时分母上的求和项,即序列指数和:
d N = ∑ j = 1 N e x j − m N d_N=\sum_{j=1}^Ne^{x_j-m_N} \notag \\ dN=j=1NexjmN
safe softmax 保证了 softmax 数值的稳定性,但是也同时引入了一个规约操作:求 max。这使得我们想要进行 online softmax 的时候除了需要考虑新引入元素 x N + 1 x_{N+1} xN+1 本身,还需要考虑引入该元素之后序列最大值 m N m_N mN 可能的变化,以及可能导致的指数和 d N d_N dN 的变化。

在有新增元素 x N + 1 x_{N+1} xN+1 时,我们按照如下三步来进行 online softmax 更新:

第一步:更新最大值 m N + 1 m_{N+1} mN+1
m N + 1 = max ⁡ ( m N , x N + 1 ) m_{N+1}=\max(m_N,x_{N+1}) \notag \\ mN+1=max(mN,xN+1)
第二步:更新指数和 d N + 1 d_{N+1} dN+1
d N + 1 = ∑ j = 1 N + 1 e x j − m N + 1 = ( ∑ j = 1 N e x j − m N + 1 ) + e x N + 1 − m N + 1 = ( ∑ j = 1 N e x j − m N e m N − m N + 1 ) + e x N + 1 − m N + 1 = d N ⋅ e m N − m N + 1 + e x N + 1 − m N + 1 \begin{aligned} d_{N+1}&=\sum_{j=1}^{N+1}e^{x_j-m_{N+1}} \\ &=(\sum_{j=1}^Ne^{x_j-m_{N+1}})+e^{x_{N+1}-m_{N+1}} \\ &=(\sum_{j=1}^Ne^{x_j-m_{N}}e^{m_N-m_{N+1}})+e^{x_{N+1}-m_{N+1}} \\ &=d_N\cdot e^{m_N-m_{N+1}}+e^{x_{N+1}-m_{N+1}} \end{aligned} \notag \\ dN+1=j=1N+1exjmN+1=(j=1NexjmN+1)+exN+1mN+1=(j=1NexjmNemNmN+1)+exN+1mN+1=dNemNmN+1+exN+1mN+1
这里推导其实就是说新增了元素 x N + 1 x_{N+1} xN+1,全局最大值可能改变,需要对原来的指数和补上一个系数 e m N − m N + 1 e^{m_N-m_{N+1}} emNmN+1,把之前计算时可能少减的部分补上。当然,如果引入新元素没有改变全局最大值,即 m N = m N + 1 m_N=m_{N+1} mN=mN+1,那这个系数就是 1 了。然后再加上新元素对应的 e x N + 1 − m N + 1 e^{x_{N+1}-m_{N+1}} exN+1mN+1

第三步:根据 m N + 1 , d N + 1 m_{N+1},d_{N+1} mN+1,dN+1,重新计算长度为 N + 1 N+1 N+1 的整个序列的 softmax 结果:
s i = e x i − m N + 1 d N + 1 s_i=\frac{e^{x_i-m_{N+1}}}{d_{N+1}} \notag \\ si=dN+1eximN+1

block online softmax

Online Softmax 最大的意义不是向序列中逐个增加元素,而是我们可以将原序列分块各自进行处理,最后再归并,这样就能充分地并行,这也就是 block online softmax。

首先,为了进行适配分块,我们重新设置一下记号: m 1 : N m_{1:N} m1:N 表示序列索引从 1 到 N N N 中的最大值, d 1 : N d_{1:N} d1:N 表示从 1 到 N N N 子序列的分母上的求和项。

我们首先考虑分为两个 block 的情景。对于总长度为 2 N 2N 2N 的序列,我们将其分为等长的 x 1 , x 2 , … , x N x_1,x_2,\dots,x_N x1,x2,,xN x N + 1 , x N + 2 , … , x 2 N x_{N+1},x_{N+2},\dots,x_{2N} xN+1,xN+2,,x2N。block online softmax 分为以下几步:

第一步:分别表示出各块最大值 m m m 和指数和 d d d
m 1 : N = max ⁡ ( x 1 , x 2 , … , x N ) d 1 : N = ∑ j = 1 N e x j − m 1 : N m N + 1 : 2 N = max ⁡ ( x N + 1 , x N + 2 , … , x 2 N ) d N + 1 : 2 N = ∑ j = N + 1 2 N e x j − m N + 1 : 2 N \begin{aligned} m_{1:N}&=\max(x_1,x_2,\dots,x_N) \\ d_{1:N}&=\sum_{j=1}^Ne^{x_j-m_{1:N}} \\ m_{N+1:2N}&=\max(x_{N+1},x_{N+2},\dots,x_{2N}) \\ d_{N+1:2N}&=\sum_{j=N+1}^{2N}e^{x_j-m_{N+1:2N}} \end{aligned} \notag \\ m1:Nd1:NmN+1:2NdN+1:2N=max(x1,x2,,xN)=j=1Nexjm1:N=max(xN+1,xN+2,,x2N)=j=N+12NexjmN+1:2N
第二步:合并全局最大值 m 1 : 2 N m_{1:2N} m1:2N
m 1 : 2 N = max ⁡ ( m 1 : N , m N + 1 , 2 N ) m_{1:2N}=\max(m_{1:N},m_{N+1,2N}) \notag \\ m1:2N=max(m1:N,mN+1,2N)
第三步:计算全局指数和 d 1 : 2 N d_{1:2N} d1:2N
d 1 : 2 N = ∑ j = 1 2 N e x j − m 1 : 2 N = … = d 1 : N   e m 1 : N − m 1 : 2 N + d N + 1 : 2 N   e m N + 1 : 2 N − m 1 : 2 N \begin{aligned} d_{1:2N}&=\sum_{j=1}^{2N}e^{x_j-m_{1:2N}} \\ &=\dots \\ &=d_{1:N}\ e^{m_{1:N}-m_{1:2N}}+d_{N+1:2N}\ e^{m_{N+1:2N}-m_{1:2N}} \end{aligned} \notag \\ d1:2N=j=12Nexjm1:2N==d1:N em1:Nm1:2N+dN+1:2N emN+1:2Nm1:2N
这里的结果的意思和之前一样,就是用合并出的全局最大值对每个块的局部最大值进行补偿,推导细节就略过了。

第四步:块合并后的 softmax 结果:
s i = e x i − m 1 : 2 N d 1 : 2 N s_i=\frac{e^{x_i-m_{1:2N}}}{d_{1:2N}} \notag \\ si=d1:2Nexim1:2N

Flash Attention

接下来,我们看 Flash Attention 如何对包含 softmax 操作的 self attention 进行 tiling 分块计算。

3-pass safe softmax

在最常规的 safe softmax 中,由于有 “max 求 m N m_N mN” 和 “sum 求 d N d_N dN” 两个规约操作,以及最后计算 softmax 各元素的值,我们总共需要计算三轮 for 循环,才能完成计算,这称为 3-pass safe softmax。

具体来说,我们记:

  • m i m_i mi 为序列中前 i i i 个值中的最大值,即 m i = max ⁡ j = 1 i { x j } m_i=\max_{j=1}^{i}\{x_j\} mi=maxj=1i{xj},初始值 m 0 = − ∞ m_0=-\infty m0=
  • d i d_{i} di 为序列前 i i i 个值的指数求和结果,即 d i = ∑ j = 1 i e x j − m N d_i=\sum_{j=1}^ie^{x_j-m_N} di=j=1iexjmN,初始值为 d 0 = 0 d_0=0 d0=0,当整体的求和完成,得到 d N d_N dN 即为 safe softmax 的分母;
  • a i a_i ai 为 safe softmax 的结果

那么,最常规的 safe softmax 需要以下三个 for 循环来完成计算。

第一个循环,遍历 x 1 , … , x N x_1,\dots,x_N x1,,xN,求出 m N = max ⁡ ( x 1 , … , x N ) m_N=\max(x_1,\dots,x_N) mN=max(x1,,xN)

第二个循环,遍历 x 1 , … , x N x_1,\dots,x_N x1,,xN,求出 d N = ∑ j = 1 N e x j − m N d_N=\sum_{j=1}^Ne^{x_j-m_N} dN=j=1NexjmN

第三个循环,遍历 x 1 , … , x N x_1,\dots,x_N x1,,xN,求出最终的 softmax 结果 a i = e x i − m N d N a_i=\frac{e^{x_i-m_N}}{d_N} ai=dNeximN

在 attention 中,这里的 softmax 需要这样三个循环,而我们不可能同时将所有的 logits 存放到 GPU 的 SRAM (~20MB)中,那我们就需要三次访问 Q , K Q,K Q,K 矩阵,造成大量的访存时间开销。

2-pass softmax

由于在 GPU 计算的时间开销分布中,访存是大头,因此如果能尽量减少访存时间,稍微增加一点计算复杂度也是非常值得的。现在我们想对上述需要 3 个 for 循环完成计算常规 softmax 实现进行优化,减少其循环次数,从而降低访存的时间开销。但是我们发现上述三个 for 循环存在着逐级数据依赖关系,第二个 for 循环计算 d i d_i di 时需要第一个 for 循环的最终 max 结果 m N m_N mN,第三个 for 循环则需要第二个 for 循环的最终 sum 结果 d N d_N dN,有什么办法能够融合上述的三个 for 循环呢?

首先出场的就是我们刚刚提到的 online softmax,借助它,我们可以实现对前两个 for-loop 的融合。具体来说,我们构造一个新的序列
d i ′ : = ∑ j = 1 i e x j − m i d'_i:=\sum_{j=1}^ie^{x_j-m_i} di:=j=1iexjmi
作为原来 d i ′ : = ∑ j = 1 i e x j − m N d'_i:=\sum_{j=1}^ie^{x_j-m_N} di:=j=1iexjmN 的替代(surrogate)。这样我们就消除了第二个 for-loop 对第一个 for-loop 最终结果 m N m_N mN 的依赖,且两个序列的最后一个元素是相同的 d N ′ = d N d'_N=d_N dN=dN,因此第三个 for-loop 直接用 d N ′ d'_N dN 的结果来替代 d N d_N dN 也是没问题的。但是注意,如同我们在 online softmax 中推导的那样,我们需要在每次添加一个新的元素后,更新序列的最大值,并补偿可能少减的部分。这里的推导和上面 online softmax 的推导完全一致:
d i ′ = ∑ j = 1 i e x j − m i = ( ∑ j = 1 i − 1 e x j − m i ) + e x i − m i = ( ∑ j = 1 i − 1 e x j − m i − 1 ) e m i − 1 − m i + e x i − m i = d i − 1 ′ e m i − 1 − m i + e x i − m i \begin{aligned} d'_i&=\sum_{j=1}^ie^{x_j-m_i} \\ &=\left(\sum_{j=1}^{i-1}e^{x_j-m_i}\right)+e^{x_i-m_i} \\ &=\left(\sum_{j=1}^{i-1}e^{x_j-m_{i-1}}\right)e^{m_{i-1}-m_i}+e^{x_i-m_i} \\ &=d'_{i-1}e^{m_{i-1}-m_i}+e^{x_i-m_i} \end{aligned} \notag \\ di=j=1iexjmi=(j=1i1exjmi)+eximi=(j=1i1exjmi1)emi1mi+eximi=di1emi1mi+eximi
可以看到,这个式子只依赖于 m i − 1 , m i m_{i-1},m_i mi1,mi,这样我们就借助 online softmax,把前两个循环融合到了一起,称为 2-pass online softmax。那么接下来,我们能否更进一步,优化到单个循环搞定,实现 1-pass softmax 呢?

flash attention

很遗憾,如果仅对于 softmax 来说,1-pass 是无法做到的。但是,在 self-attention 中,我们的最终目标不是 attention 矩阵 A A A,而是其与 value 矩阵相乘得到的输出 O = A × V O=A\times V O=A×V。即 self attention 中求完 softmax 之后,每一项的值会与 V V V 中向量相乘,然后累加。这里的累加很关键,有了这个累加的操作,就有成了上面我们构造 surrogate 替代的形式,就可以做成 1-pass 了。

先捋一捋,我们目前已经借助 online softmax 将 3-pass 压缩到了 2-pass,但是 softmax 本身没法再压缩到 1-pass 了,我们来看看 Flash Attention 是如何利用 Self-Attention 的特性,实现 1-pass Self-Attention 的。

首先,我们写出 2-pass self attention 的计算过程。记

  • Q [ k , : ] Q[k,:] Q[k,:] 为矩阵 Q Q Q 的第 k k k 个行向量;
  • K T [ : , i ] K^T[:,i] KT[:,i] 为矩阵 K T K^T KT 的第 i i i 个列向量;
  • V [ i , : ] V[i,:] V[i,:] 为矩阵 V V V 的第 i i i 个行向量;
  • O [ k , : ] O[k,:] O[k,:] 为输出矩阵 O O O 的第 k k k 个行向量;
  • o i = ∑ j = 1 i a j V [ j , : ] \mathbf{o}_i=\sum_{j=1}^ia_jV[j,:] oi=j=1iajV[j,:] 是一个行向量。保存了部分求和结果 A [ k , : i ] × V [ : i , : ] A[k,:i]\times V[:i,:] A[k,:i]×V[:i,:]

2-pass self attention 分为两个 for-loop,第一个 for-loop 计算 m N m_N mN d N ′ d'_N dN (也就是 d N d_N dN):
for  i  in  [ 1 , … , N ] x i = Q [ k , : ] K T [ : , i ] m i = max ⁡ ( m i − 1 , x i ) d i ′ = d i − 1 ′ e m i − 1 − m i + e x i − m i endfor \text{for}\ i\ \text{in}\ [1,\dots,N] \hspace{5cm}\\ \begin{aligned} x_i&=Q[k,:]K^T[:,i] \\ m_i&=\max(m_{i-1},x_i) \\ d'_i&=d'_{i-1}e^{m_{i-1}-m_i}+e^{x_i-m_i} \\ \end{aligned} \\ \text{endfor} \hspace{7.2cm} \notag \\ for i in [1,,N]ximidi=Q[k,:]KT[:,i]=max(mi1,xi)=di1emi1mi+eximiendfor

第二个 for-loop 计算 softmax 结果,并与对应位置的 V V V 相乘,得到最终结果
for  i  in  [ 1 , … , N ] a i = e x i − m N d N ′ o i = o i − 1 + a i V [ i , : ] endfor \text{for}\ i\ \text{in}\ [1,\dots,N] \hspace{5cm}\\ \begin{aligned} a_i&=\frac{e^{x_i-m_N}}{d'_N} \\ \mathbf{o}_i&=\mathbf{o}_{i-1}+a_iV[i,:] \end{aligned} \\ \text{endfor} \hspace{7.2cm} \notag \\ for i in [1,,N]aioi=dNeximN=oi1+aiV[i,:]endfor
第二个循环结束后,我们将结果写入输出矩阵 O O O
O [ k , : ] = o N O[k,:]=\mathbf{o}_N \notag \\ O[k,:]=oN
我们将第二个循环中的两个等式合并(将 a i a_i ai 代入),得到:
o i : = ∑ j = 1 i ( e x j − m N d N ′ V [ j , : ] ) \mathbf{o}_i:=\sum_{j=1}^i(\frac{e^{x_j-m_N}}{d'_N}V[j,:]) \notag \\ oi:=j=1i(dNexjmNV[j,:])
这里可以看到,第二个循环的计算依赖于第一个循环结束后的结果 m N m_N mN d N ′ d'_N dN。这里我们可以再次采用 surrogate 的技巧来实现 1-pass。具体来说,我们构造一个不依赖于 m N , d N ′ m_N,d'_N mN,dN 的替代序列 o i ′ \mathbf{o}'_i oi
o i ′ = ∑ j = 1 i e x j − m i d i ′ V [ j , : ] \mathbf{o}'_i=\sum_{j=1}^i\frac{e^{x_j-m_i}}{d'_i}V[j,:] \notag \\ oi=j=1idiexjmiV[j,:]
同样的,这两个序列的最后一个元素相等 o N = o N ′ \mathbf{o_N}=\mathbf{o}'_N oN=oN。然后我们推导出 o i ′ \mathbf{o}'_i oi o i − 1 ′ \mathbf{o}'_{i-1} oi1 的迭代关系:
o i ′ = ∑ j = 1 i e x j − m i d i ′ V [ j : ] = ( ∑ j = 1 i − 1 e x j − m i d i ′ V [ j : ] ) + e x i − m i d i ′ V [ i : ] = ( ∑ j = 1 i − 1 e x j − m i − 1 d i − 1 ′ e x j − m i e x j − m i − 1 d i − 1 ′ d i ′ V [ j , : ] ) + e x i − m i d i ′ V [ i : ] = ( ∑ j = 1 i − 1 e x j − m i − 1 d i − 1 ′ ) d i − 1 ′ d i ′ e m i − 1 − m i + e x i − m i d i ′ V [ i : ] = o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + e x i − m i d i ′ V [ i : ] \begin{aligned} \mathbf{o}_i'&=\sum_{j=1}^i\frac{e^{x_j-m_i}}{d'_i}V[j:] \\ &=\left(\sum_{j=1}^{i-1}\frac{e^{x_j-m_i}}{d'_i}V[j:]\right)+\frac{e^{x_i-m_i}}{d'_i}V[i:] \\ &=\left(\sum_{j=1}^{i-1}\frac{e^{x_j-m_{i-1}}}{d_{i-1}'}\frac{e^{x_j-m_i}}{e^{x_j-m_{i-1}}}\frac{d'_{i-1}}{d'_i}V[j,:]\right)+\frac{e^{x_i-m_i}}{d'_i}V[i:]\\ &=\left(\sum_{j=1}^{i-1}\frac{e^{x_j-m_{i-1}}}{d'_{i-1}}\right)\frac{d'_{i-1}}{d'_i}e^{m_{i-1}-m_i}+\frac{e^{x_i-m_i}}{d'_i}V[i:] \\ &=\mathbf{o}'_{i-1}\frac{d'_{i-1}e^{m_{i-1}-m_i}}{d_i'}+\frac{e^{x_i-m_i}}{d'_i}V[i:] \end{aligned} \notag \\ oi=j=1idiexjmiV[j:]=(j=1i1diexjmiV[j:])+dieximiV[i:]=(j=1i1di1exjmi1exjmi1exjmididi1V[j,:])+dieximiV[i:]=(j=1i1di1exjmi1)didi1emi1mi+dieximiV[i:]=oi1didi1emi1mi+dieximiV[i:]
o i ′ \mathbf{o}'_i oi 的计算仅依赖于 d i ′ , d i − 1 ′ , m i , m i − 1 , x i d'_i,d'_{i-1},m_i,m_{i-1},x_i di,di1,mi,mi1,xi,这样我们就将三个循环全都融合到了一起, 实现了 1-pass self attention,也就是 flash attention。在这一个循环中,我们做如下计算:
for  i  in  [ 1 , … , N ] x i = Q [ k , : ] K T [ : , i ] m i = max ⁡ ( m i − 1 , x i ) d i ′ = d i − 1 ′ e m i − 1 − m i + e x i − m i o i ′ = o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + e x i − m i d i ′ V [ i , : ] endfor O [ k , : ] = o N ′ \text{for}\ i\ \text{in}\ [1,\dots,N] \hspace{12cm}\\ \begin{aligned} x_i&=Q[k,:]K^T[:,i] \\ m_i&=\max(m_{i-1},x_i) \\ d'_i&=d'_{i-1}e^{m_{i-1}-m_i}+e^{x_i-m_i} \\ \mathbf{o}'_i&=\mathbf{o}'_{i-1}\frac{d'_{i-1}e^{m_{i-1}-m_i}}{d'_i}+\frac{e^{x_i-m_i}}{d'_i}V[i,:] \end{aligned} \\ \text{endfor} \hspace{14.2cm} \\ O[k,:]=\mathbf{o}'_N \notag \\ for i in [1,,N]ximidioi=Q[k,:]KT[:,i]=max(mi1,xi)=di1emi1mi+eximi=oi1didi1emi1mi+dieximiV[i,:]endforO[k,:]=oN

状态 x i , m i , d i ′ , o i x_i,m_i,d'_i,\mathbf{o}_i xi,mi,di,oi 这些尺寸都比较小,可以放到 GPU SRAM 里进行计算。由于这些操作都具有结合律,因此都可以进行 tiling 操作。如果我们逐 tile 进行计算,记

  • b b b 为 block size
  • x i x_i xi 为存储第 i i i 个 tile [ ( i − 1 ) b : i b ] [(i-1)b:ib] [(i1)b:ib] Q [ k ] K T Q[k]K^T Q[k]KT 值的向量
  • m i ( local ) m_i^{(\text{local})} mi(local) 表示 x i x_i xi 内的局部最大值

tiling 版的 flash attention 可以写成如下形式:

for  i  in  [ 1 , … , N / b ] x i = Q [ k , : ] K T [ : , ( i − 1 ) b : i b ] m i ( local ) = max ⁡ j = 1 b ( x i [ j ] ) m i = max ⁡ ( m i − 1 , m i ( local ) ) d i ′ = d i − 1 ′ e m i − 1 − m i + ∑ j = 1 b e x i [ j ] − m i o i ′ = o i − 1 ′ d i − 1 ′ e m i − 1 − m i d i ′ + ∑ j = 1 b e x i − m i d i ′ V [ j + ( i − 1 ) b , : ] endfor O [ k , : ] = o N / b ′ \text{for}\ i\ \text{in}\ [1,\dots,N/b] \hspace{12cm}\\ \begin{aligned} x_i&=Q[k,:]K^T[:,(i-1)b:ib] \\ m_i^{(\text{local})}&=\max_{j=1}^b(x_i[j]) \\ m_i&=\max(m_{i-1},m_i^{(\text{local})}) \\ d'_i&=d'_{i-1}e^{m_{i-1}-m_i}+\sum_{j=1}^be^{x_i[j]-m_i} \\ \mathbf{o}'_i&=\mathbf{o}'_{i-1}\frac{d'_{i-1}e^{m_{i-1}-m_i}}{d'_i}+\sum_{j=1}^b\frac{e^{x_i-m_i}}{d'_i}V[j+(i-1)b,:] \end{aligned} \\ \text{endfor} \hspace{14.2cm} \\ O[k,:]=\mathbf{o}'_{N/b} \notag \\ for i in [1,,N/b]ximi(local)midioi=Q[k,:]KT[:,(i1)b:ib]=j=1maxb(xi[j])=max(mi1,mi(local))=di1emi1mi+j=1bexi[j]mi=oi1didi1emi1mi+j=1bdieximiV[j+(i1)b,:]endforO[k,:]=oN/b

下图展示了 Flash Attention 在 GPU 上的计算方式。蓝色块表示驻留在 SRAM 中的 tiles,而红色块对应于第 i i i 行。 L L L 表示序列长度,可能非常大(例如 16k),D 表示隐层维度,在 Transformer 中通常较小(例如 GPT3 为 128), B B B 是可以控制的块大小。值得注意的是,整体 SRAM 内存占用仅取决于 B B B D D D,与 L L L 无关。因此,该算法可以扩展到长上下文而不会遇到内存问题。在计算过程中,我们从左到右扫描 K T K^T KT A A A 的 tiles,从上到下扫描 V V V 的 tiles,并相应地更新 m , d , O m,d ,O m,d,O 的状态。

在这里插入图片描述

Recomputation

以上我们介绍的是 Flash Attention 中最关键的 tiling 的技巧,通过将 self-attention 中各个操作(GEMM、softmax 等)进行分块计算,每个块用一个 kernel 直接算完,避免了大量的访存时间开销,从而加快了 self attention 的计算。那么 Flash Attention 又是如何节省显存的呢?

在标准的 self attention 实现中,会存储中间结果矩阵 S , P ∈ R N × N S,P\in\mathbb{R}^{N\times N} S,PRN×N,用于反向传播时计算梯度。这会带来 O ( N 2 ) \mathcal{O}(N^2) O(N2) 的空间复杂度,当序列长度 N N N 很大时,这里的显存占用还是很可观的。但是在 Flash Attention 中,还记得我们除了保存输出 O O O 之外还保存了 online softmax 计算过程中的一些标准化统计量:序列最大值 m N m_N mN 和序列指数和 d N d_N dN,这样我们就能在反向传播时很容易地计算出中间矩阵 S , P S,P S,P,因此我们可以在模型前向过程中丢掉这些中间结果以节省显存。实际上,这就是我们常用的 gradient checkpointing 技巧,但是一般的 gradient checkpointing 技巧由于丢掉了一些中间结果,在反向传播重新计算时会额外增加一些耗时,整体需要在显存和计算耗时之间进行 trade-off,而 Flash Attention 却得益于 tiling 带来的计算高效性,虽然额外增加了一些 Flops,但是整体时间和显存占用都大幅降低。

总结

在 Transformer 大模型时代,Flash Attention 的重要性不言而喻。它从 I/O aware 的角度,通过 tiling 的思想来对 self attention 进行分块计算,减少访问 HBM 的时间开销,又通过 recompute 来节省训练时的显存占用,实在是软硬件结合算法设计的典范,十分值得学习。


http://www.kler.cn/a/583215.html

相关文章:

  • openai-cua-sample-app - 使用计算机的 Agent示例应用
  • 【C语言系列】字符函数和字符串函数
  • golang算法二分查找
  • Qt开源控件库(qt-material-widgets)的编译及使用
  • 【原创】springboot+vue音乐教育培训管理系统设计与实现
  • 2.angular指令
  • AI驱动的数字供应链安全情报预警服务:云脉XSBOM
  • Token登录授权、续期和主动终止的方案(Redis+Token(非jwtToken))
  • 点云深度学习系列:PVRCNN——point-voxel融合的分割模型
  • 攻防世界web:NewsCenter(含sqlmap基本参数讲解)
  • 引入其他 YML 配置源 —— Spring Boot 中的 `import` 功能
  • Axios简单说明,快速上手
  • 3.12-2 html
  • 电商数据分析 电商平台销售数据分析 电商平台数据库设计 揭秘电商怎么做数据分析
  • hadoop框架与核心组件刨析(五)ZOOKEEPER及选举深度刨析
  • llamaindex实现企业级RAG应用(一)
  • stm32-RTC时实时钟
  • C++复试笔记(二)
  • 下载文件,文件名乱码问题
  • Java高频面试之集合-10