浅谈FlashAttention优化原理
FlashAttention原理简介
背景:
在Transformer 结构中,自注意力机制的时间和存储复杂度与序列的长度呈平方的关系,因此占用了大量的计算设备内存和并消耗大量计算资源。如何优化自注意力机制的时空复杂度、增强计算效率是大语言模型需要面临的重要问题
1.FlashAttention
NVIDIA GPU 中的内存(显存)按照它们物理上是在GPU 芯片内部还是板卡RAM 存储芯片上,决定了它们的速度、大小以及访问限制。
GPU 显存分为:全局内存(Global memory)、本地内存(Local memory)、共享内存(Shared memory,SRAM)、寄存器内存(Register memory)、常量内存(Constant memory)、纹理内存(Texture memory)等六大类。
下图给出了NVIDIA GPU 内存的整体结构。其中全局内存、本地内存、共享内存和寄存器内存具有读写能力。全局内存和本地内存使用的高带宽显存(High Bandwidth Memory,HBM
)位于板卡RAM 存储芯片上,该部分内存容量很大。
全局内存是所有线程都可以访问,而本地内存则只能当前线程访问。NVIDIA H100 中全局内存有80GB 空间,其访问速度虽然可以达到3.35TB/s,但是如果全部线程同时访问全局内存时,其平均带宽仍然很低。
共享内存和寄存器位于GPU 芯片上,因此容量很小,并且共享内存只有在同一个GPU 线程块(Thread Block
)内的线程才可以共享访问,而寄存器仅限于同一个线程内部才能访问。NVIDIA H100 中每个GPU 线程块在流式多处理器(Stream Multi-processor,SM
)可以使用的共享存储容量仅有228KB,但是其速度非常快,远高于全局内存的访问速度。
通过自注意力机制的原理可知,在GPU 中进行计算时,传统的方法还需要引入两个中间矩阵S 和P
并存储到全局内存中。具体计算过程如下:
S = Q ×K, P = Softmax(S), O = P × V (2.27)
按照上述计算过程:
-
需要首先从全局内存中读取
矩阵Q 和K
,并将计算好的矩阵S
再写入全局内存 -
再从全局内存中获取
矩阵S
,计算Softmax 得到矩阵P
,再写入全局内容 -
读取
矩阵P
和矩阵V
,计算得到矩阵O
。
这样的过程会极大占用显存的带宽。在自注意力机制中,计算速度比内存速度快得多,因此计算效率越来越多地受到全局内存访问的瓶颈。
FlashAttention就是通过利用GPU 硬件中的特殊设计,针对全局内存和共享存储的I/O 速度的不同,尽可能的避免HBM 中读取或写入注意力矩阵。
FlashAttention 目标:是尽可能高效地使用SRAM 来加快计算速度,避免从全局内存中读取和写入注意力矩阵。达成该目标需要能做到在不访问整个输入的情况下计算Softmax 函数,并且后向传播中不能存储中间注意力矩阵。
标准Attention 算法中:
-
Softmax
计算按行进行,即在与V 做矩阵乘法之前,需要将Q、K
的各个分块完成一整行的计算。 -
在得到
Softmax
的结果后,再与矩阵V 分块做矩阵乘。
而在FlashAttention 中:
- 将输入分割成块,并在输入块上进行多次传递,从而以增量方式执行
Softmax
计算。
自注意力算法的标准实现将计算过程中的矩阵S、P 写入全局内存中,而这些中间矩阵的大小与输入的序列长度有关且为二次型。
因此,FlashAttention 就提出了不使用中间注意力矩阵,通过存储归一化因子来减少全局内存的消耗。FlashAttention 算法并没有将S、P 整体写入全局内存,而是通过分块写入,存储前向传递的Softmax 归一化因子,在后向传播中快速重新计算片上注意力,这比从全局内容中读取中间注意力矩阵的标准方法更快。由于大幅度减少了全局内存的访问量,即使重新计算导致FLOPs 增加,但其运行速度更快并且使用更少的内存。
PyTorch 2.0 中已经可以支持FlashAttention,使用“torch.backends.cuda.enable_flash_sdp()
”启
用或者关闭FlashAttention 的使用。