FlashAttention v1 论文解读
论文标题:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
论文地址:https://arxiv.org/pdf/2205.14135
FlashAttention 是一种重新排序注意力计算的算法,它无需任何近似即可加速注意力计算并减少内存占用。所以作为目前LLM的模型加速它是一个非常好的解决方案,本文介绍经典的V1版本。
目前FlashAttention已经推出了V1~V3版本,遗憾的是,FlashAttention V3目前只支持Nvidia Hopper架构的GPU。目前transformers库已经集成了FlashAttention。
【注】穷人玩不起系列。
FlashAttention是用于在训练或推理时加速注意力计算的方法,参考其官方仓库可以看到对于训练精度和显卡还是有较大限制的:
https://github.com/Dao-AILab/flash-attention
带有 CUDA 的 FlashAttention-2 目前支持:
GPU架构 Ampere, Ada, or Hopper GPUs(例如 A100、RTX 3090、RTX 4090、H100)。对Turing GPU(T4、RTX 2080)的支持即将推出,目前请为Turing GPU 使用 FlashAttention 1.x。
数据类型 fp16 和 bf16(bf16 需要Ampere, Ada, or Hopper GPUs)。
标准注意力机制
在介绍FlashAttention前,一定要深入了解标准注意力机制计算原理。
在 Transformer 架构当中,Attention 是整个模型中最重要的运算,而这个 Attention 的运算示意图如下:
首先我们把 Q Q Q和 K K K做矩阵相乘,接下来就是除以隐藏层维度的开根号 d \sqrt{d} d,然后我们会把运算出来的结果 S(Score)丢进 Softmax 函数得到 P,最后 P 再和 V V V做矩阵相乘就会得到 Attention 的输出 O O O。
但实际上我们会发现这一连串的运算非常的耗时间,且会使用到非常大量的内存。
在我们的 GPU 架构中,可以把内存简单地分成 HBM(高带宽内存)和 SRAM(静态随机存取存储器)两个部分。
HBM 的内存空间虽然很大,但是它的带宽比较低。
SRAM 的内存空间虽然很小,但是它的带宽非常高。
所以我们常常看到 GPU 的参数,像是 Nvidia RTX 4090 24GB,就是这张 GPU 有大约 24GB 大小的 HBM。而 SRAM 这块又贵又小的内存,就是拿来做运算的。
因此我们可以看到今天你在GPU 上运行 标准Attention 的流程如下(N:序列长度、d 是 隐藏层维度):
首先我们会把 Q Q Q和 K K K从 HBM 拉到 SRAM 运算,接下来把算出来的结果 S S S写回 HBM,然后 GPU 又把 S S S拉到 SRAM 计算 S o f t m a x Softmax Softmax,算出来的 P P P又写回 HBM,最后 P P P和 V V V从 HBM 写到 SRAM 做矩阵运算,最后输出 O O O写回 HBM。
而实际情况当然没那么简单,我们知道 SRAM 这块内存又贵又小,所以当然不可能直接把整个 Q Q Q或是 K K K加载进 SRAM,而是一小块一小块地加载。所以这样大量的读写导致 Attention 运算速度很慢,而且会有内存碎片化问题。
【注】有了上面的背景之后,我们来看看FlashAttention V1是如何优化的,下面为大家带来FlashAttention V1论文精读。
Abstract
针对Transformer在处理长序列时速度慢、内存消耗大的问题,论文提出了FlashAttention,一种IO感知的精确注意力算法。该算法通过使用平铺(tiling)技术减少GPU内存(HBM)与SRAM之间的内存读写次数,从而降低计算复杂性。
分析显示,FlashAttention减少了HBM访问次数,并优化了SRAM使用。此外,本研究将FlashAttention扩展至块稀疏注意力,实现了比现有近似注意力方法更快的近似注意力算法,为长序列处理提供了高效解决方案。
【注】标准自注意力机制的时间复杂度是 O ( n 2 ∗ d ) O(n^2*d) O(n2∗d),其中 n n n是序列长度, d d d是隐藏层维度。多头注意力只是把 d d d进行了多头拆分,单头的时间复杂度是 O ( n 2 ∗ d h ) O(n^2*d_h) O(n2∗dh),其中 d h d_h dh是单头的隐藏层维度,虽然多头之间可以并行计算,但是仍然没有解决平方量的复杂度。
Introduction
目前许多优化 attention 的方法旨在降低 attention 的计算和内存需求。这些方法专注于减少 FLOP,并且倾向于忽略内存访问 (IO) 的开销。
但是本文认为attention的一个优化方向是使算法具有 IO 感知能力。
【注】也就是说,让求注意力的操作尽可能放在SRAM里,而不是频繁的让SRAM与HBM通信。
现代的GPU,计算速度超过了内存IO速度,当读取和写入数据可能占据运行时间的很大一部分时,IO 感知算法对于加速与降内存就变得很重要了。并且深度学习的常见 Python 库(如 PyTorch 和 Tensorflow)目前还不允许对内存访问进行精细控制。
因此,FlashAttention应运而生。
论文提到,为了实现计算注意力时多使用SRAM而少与HBM交换数据,需要克服两点:
- 在输入不完整的情况下,计算 S o f t m a x Softmax Softmax;
- 不存储用于反向传播的中间结果;
FlashAttention
第一招:内核融合(Kernel Fusion)
相信聪明的朋友立刻就能明白,何必这样反复加载和卸载,一次性在SRAM中完成所有计算不就好了?没错,这就是FlashAttention的精髓之一。
FlashAttention就是直接将 Q K V QKV QKV一次性加载到SRAM中完成所有计算,然后再将 O O O写回HBM。
这样大大减少了读写次数,这种一次性完成所有计算的流程被称为内核融合(Kernel Fusion)。
第二招:反向重计算(Backward Recomputation)
但是等一下,我们是不是忘了什么?我们直接计算出了 O O O,那么 P P P和 S S S难道就直接丢弃不存回HBM吗?在进行反向传播时,我们需要从 O O O推回 P P P,再从 P P P推回 S S S,它们都被我们丢弃了,怎么进行反向传播?没错,这就是FlashAttention的第二招,反向重计算(Backward Recomputation)。
因为 P P P和 S S S这两者实在太占用空间了,所以
在前向传播时, P P P和 S S S都不会被存储起来。当进行反向传播时,我们就会重新计算一次前向传播,重新计算出 P P P和 S S S,以便执行反向传播。
所以说:我们执行了2次前向传播和1次反向传播。
这里大家可能又会问:啊这样计算量不是更多了吗,怎么可能会更快?事实上,虽然我们重新计算了一次前向传播,但它不仅帮我们省下了存储P和S的内存空间,还省下了 P P P和 S S S在HBM和SRAM之间搬运的时间,让我们可以开启更大的batch size,所以总的来说,GPU每秒能处理的数据量依然是大幅增加的。
第三招:Softmax分块(Softmax Tiling)
最后是FlashAttention的最后一招分块(Tiling)。首先我们需要知道注意力机制中的最难搞的就是 S o f t m a x Softmax Softmax函数:
s o f t m a x ( { x 1 , . . . , x N } ) = { e x i ∑ j = 1 N e x j } i = 1 N (1) softmax(\{x_1, ..., x_N\}) = \left\{\frac{e^{x_i}}{\sum_{j=1}^N e^{x_j}}\right\}_{i=1}^N \tag1 softmax({x1,...,xN})={∑j=1Nexjexi}i=1N(1)
主要原因是在计算分母时,我们需要将所有位的exp值加总。但由于SRAM的大小限制,我们不可能一次性计算出所有数值的 S o f t m a x Softmax Softmax,一定是需要一块一块地丢进SRAM进行计算,所以需要将所有中间计算的数值存储在HBM中。
在FP16精度下,最大可以表示65536,而
e 12 = 162754 e^{12} = 162754 e12=162754
为了防止在计算 S o f t m a x Softmax Softmax产生数值溢出,引入了 S a f e − s o f t m a x Safe-softmax Safe−softmax概念,其公式如下:
S a f e − s o f t m a x ( { x 1 , . . . , x N } ) = { e x i − m ∑ j = 1 N e x j − m } i = 1 N (2) Safe-softmax(\{x_1, ..., x_N\}) = \left\{\frac{e^{x_i-m}}{\sum_{j=1}^N e^{x_j-m}}\right\}_{i=1}^N \tag2 Safe−softmax({x1,...,xN})={∑j=1Nexj−mexi−m}i=1N(2)
在公式(2)中,有如下定义:
x = [ x 1 , . . . , x N ] (3) x=[x_1,...,x_N] \tag3 x=[x1,...,xN](3)
m ( x ) : = m a x ( x ) (4) m(x):=max(x) \tag4 m(x):=max(x)(4)
p ( x ) : = [ e x 1 − m ( x ) , . . . , e x N − m ( x ) ] (5) p(x):=[e^{x_1-m(x)},...,e^{x_N-m(x)}] \tag5 p(x):=[ex1−m(x),...,exN−m(x)](5)
l ( x ) : = ∑ i p ( x ) i (6) l(x):=\sum_ip(x)_i \tag6 l(x):=i∑p(x)i(6)
s o f t m a x ( x ) : = p ( x ) l ( x ) (7) softmax(x):=\frac{p(x)}{l(x)} \tag7 softmax(x):=l(x)p(x)(7)
其原理就是,从 x x x中找出最大值 m m m,在计算 S o f t m a x Softmax Softmax时,分子分母同除以 e m e^m em,这样既可以防止数据溢出,也能保证 S o f t m a x Softmax Softmax值保持不变。
【注】类似于归一化。
x = [ x 1 , … , x N , … , x 2 N ] x 1 = [ x 1 , … , x N ] x 2 = [ x N + 1 , … , x 2 N ] m ( x 1 ) p ( x 1 ) l ( x 1 ) m ( x 2 ) p ( x 2 ) l ( x 2 ) m ( x ) : = max ( m ( x 1 ) , m ( x 2 ) ) p ( x ) : = [ e m ( x 1 ) − m ( x ) p ( x 1 ) , e m ( x 2 ) − m ( x ) p ( x 2 ) ] l ( x ) : = e m ( x 1 ) − m ( x ) l ( x 1 ) + e m ( x 2 ) − m ( x ) l ( x 2 ) s o f t m a x ( x ) : = p ( x ) l ( x ) (8) \begin{align*} & x = [x_1, \ldots, x_N, \ldots, x_{2N}] \\ & x^1 = [x_1, \ldots, x_N] \\ & x^2 = [x_{N+1}, \ldots, x_{2N}] \\ & m(x^1) \ p(x^1) \ l(x^1) \ m(x^2) \ p(x^2) \ l(x^2) \\ & m(x) := \max(m(x^1), m(x^2)) \\ & p(x) := [e^{m(x^1)-m(x)} p(x^1), e^{m(x^2)-m(x)} p(x^2)] \\ & l(x) := e^{m(x^1)-m(x)} l(x^1) + e^{m(x^2)-m(x)} l(x^2) \\ & softmax(x) := \frac{p(x)}{l(x)} \end{align*}\tag8 x=[x1,…,xN,…,x2N]x1=[x1,…,xN]x2=[xN+1,…,x2N]m(x1) p(x1) l(x1) m(x2) p(x2) l(x2)m(x):=max(m(x1),m(x2))p(x):=[em(x1)−m(x)p(x1),em(x2)−m(x)p(x2)]l(x):=em(x1)−m(x)l(x1)+em(x2)−m(x)l(x2)softmax(x):=l(x)p(x)(8)
而本文softmax分块的做法如公式(8)所示。
我们首先将一块数据 x x x中的第一块 x 1 x_1 x1丢进去计算出softmax,这里的 m 1 m_1 m1代表的是这一块加载到SRAM的最大值,所以我们称之为局部最大值。接下来,我们可以根据 m 1 m_1 m1计算出局部softmax。
接下来第二块数据进来时,我们将第一块的最大值 m 1 m_1 m1和第二块的最大值 m 2 m_2 m2取最大值,就可以得到这两块数据的最大值 m ( x ) m(x) m(x)。这个时候定义 p ( x ) : = [ e m ( x 1 ) − m ( x ) p ( x 1 ) , e m ( x 2 ) − m ( x ) p ( x 2 ) ] p(x) := [e^{m(x^1)-m(x)} p(x^1), e^{m(x^2)-m(x)} p(x^2)] p(x):=[em(x1)−m(x)p(x1),em(x2)−m(x)p(x2)],再与公式(5)结合,只会出现两种情况:
- 当 m ( x 1 ) m(x^1) m(x1)最大,最后可化简为 p ( x ) : = [ e x 1 − m ( x 1 ) , . . . , e x N − m ( x 1 ) ] p(x) := [e^{x_1-m(x^1)},...,e^{x_N-m(x^1)}] p(x):=[ex1−m(x1),...,exN−m(x1)]
- 当 m ( x 2 ) m(x^2) m(x2)最大,最后可化简为 p ( x ) : = [ e x 1 − m ( x 2 ) , . . . , e x N − m ( x 2 ) ] p(x) := [e^{x_1-m(x^2)},...,e^{x_N-m(x^2)}] p(x):=[ex1−m(x2),...,exN−m(x2)]
l ( x ) l(x) l(x)的计算化简也同理,所以我们只需要将第一块的局部softmax乘上这次更新的数值。如此一来,我们就得到了这两块的局部softmax。
没错!接下来依此类推,我们就可以将整个softmax计算完。而通过这种方式:
我们就不需要将每块计算出来的数值存储在HBM中,我们只需要存储当前的最大值 m ( x ) m(x) m(x)和分母加总值 l ( x ) l(x) l(x)就可以了。
而这两者都非常小,所以可以进一步帮我们节省更多内存空间。
另外,这里还有一个小细节,就是由于softmax计算出来后需要与value state进行矩阵相乘,但同样由于SRAM有限,我们一次只能加载一块进行内核融合运算,所以第一块QKV进去后,它计算出来的O是不准确的。但由于矩阵相乘就是数字相乘,所以同样道理,我们只要在计算到下一块时,使用l和m更新O就可以了。
我们可以看到实际的流程就是这样,蓝色的区域就是HBM,橙色虚线的区域就是SRAM。每次运算时,由于SRAM大小有限,所以我们只加载一部分的Key和Value。红色的字就是我们的第一个block的计算,蓝色的字就是我们的第二个block的计算。
这边我们可以更深入探讨算法和实现部分。静态随机存取存储器(SRAM)容量较小,当序列长度很长时,根本不可能一次性将如此庞大的查询(query)、键(key)、值(value)状态全部塞进SRAM。
一开始我们会把查询状态(Query State)切成 T r T_r Tr块,键/值状态(Key/Value State)切成 T c T_c Tc块,查询状态块的大小为 ( B r , d ) (B_r, d) (Br,d),键/值状态块的大小为 ( B c , d ) (B_c, d) (Bc,d)。切好的这些块再放入SRAM进行Flash Attention运算。 你可能会好奇 B r B_r Br和 B c B_c Bc是什么神奇的数字,其实非常简单, M M M是我们SRAM的大小,并且查询(Q)、键(K)、值(V)、输出(O)这四个矩阵大小完全相同,所以当然是 M / 4 d M/4d M/4d啦,这样Q、K、V、O四个矩阵的块加起来不就刚好是 M M M嘛,也就是说刚好填满SRAM。
比如说,假设M = 1000, d = 5。那么块大小为(1000/4*5)= 50。所以一次加载50个q, k, v, o个向量的块,这样可以减少HBM/SRAM之间的读/写次数。
性能
我们可以看到 FlashAttention 大大地加速了运算,达到 3 倍以上。