LLM - FlashAttention 的 Safe-Softmax 与 One-Pass Tiling 计算 教程
欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/144963870
FlashAttention 是高效的 注意力机制(Attention) 算法,加速 Transformer 模型中的自注意力计算,显著减少内存占用。通过将输入分块,在每个块上执行注意力操作,从而减少对高带宽内存(HBM)的读写次数。FlashAttention 使用底层硬件的内存层次结构,如 GPU 的内存层次,提高计算速度和减少内存访问开销,保持注意力机制的精度,通过减少内存访问次数,实现更快的计算速度。FlashAttention 已经被集成到 PyTorch 2.0 中,使得在大规模模型和长序列数据处理中更加高效。
推理过程,包括:
- Self-Attention 公式
- Safe-Softmax
- Softmax 流式计算
- FlashAttention 的 One-Pass
- FlashAttention 的 Tiling(分块)
- FlashAttention 的 MQA 和 GQA
1. Self-Attention 公式
在 Self-Attention 中,需要使用 s o f t m a x softmax softmax 公式,如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K ⊤ d k ) V Attention(Q,K,V)=softmax(\frac{QK^{\top}}{\sqrt d_{k}})V Attention(Q,K,V)=softmax(dkQK⊤)V
2. Safe-Softmax
在大模型中,使用 FP16,最大值是 65535,当 x i > 11 x_{i} > 11 xi>11 时,即发生越界,需要使用 Safe-Softmax 公式。Safe-Softmax 需要计算 3 步,即最大值m、计算分母、计算分子,需要 6 次通信 (3次写入,3次写出)
Safe-Softmax 计算公式:
X = [ x 1 , x 2 , . . . , x n ] s o f t m a x ( x ) = e x i ∑ i = 1 n e x i \begin{align} X &= [x_{1},x_{2},...,x_{n}] \\ softmax(x) &= \frac{e^{x_{i}}}{\sum_{i=1}^{n} e^{x_{i}}} \end{align} Xsoftmax(x)=[x1,x2,...,xn]=∑i=1nexiexi
由于 e x i e^{x_{i}} exi 溢出,使用 最大值max 进行优化 :
m a x ( X ) = m a x ( x 1 , x 2 , . . . , x n ) s o f t m a x ( x ) = e x i − m a x ( X ) ∑ i = 1 n e x i − m a x ( X ) \begin{align} max(X) &= max(x_{1},x_{2},...,x_{n}) \\ softmax(x) &= \frac{e^{x_{i}-max(X)}}{\sum_{i=1}^{n} e^{x_{i} - max(X)}} \end{align} max(X)softmax(x)=max(x1,x2,...,xn)=∑i=1nexi−max(X)exi−max(X)
优化之后,
e x i − m a x ( X ) ∈ ( 0 , 1 ] e^{x_{i}-max(X)} \in (0,1] exi−max(X)∈(0,1]
当 n 比较大时,GPU 缓存(SRAM) 不足,需要多次读写显存(HBM),即High Bandwidth Memory,高宽带内存。
3. Softmax 流式计算
使用 Softmax 的流式(分块)计算,使用 GPU 的 share memory 来存储中间结果,需要 2 次通信,1次写入数据,1次读取数据。
Softmax 的分块计算的逻辑:
X = [ X A , X B ] = [ x 1 a , x 2 a , . . . , x n 1 a , x 1 b , x 2 b , . . . , x n 2 b ] n = n 1 + n 2 m a x ( X ) = m a x ( x 1 a , x 2 a , . . . , x n 1 a , x 1 b , x 2 b , . . . , x n 2 b ) D ( X ) = ∑ i = 1 n 1 e x i a − m a x ( X ) + ∑ i = 1 n 2 e x i b − m a x ( X ) N ( x ) = e x i − m a x ( X ) s o f t m a x ( x ) = N ( x i ) D ( X ) \begin{align} X = [X_{A},X_{B}] &= [x_{1}^{a},x_{2}^{a},...,x_{n1}^{a},x_{1}^{b},x_{2}^{b},...,x_{n2}^{b}] \\ n &= n1 + n2 \\ max(X) &= max(x_{1}^{a},x_{2}^{a},...,x_{n1}^{a},x_{1}^{b},x_{2}^{b},...,x_{n2}^{b}) \\ D(X) &= \sum_{i=1}^{n1}e^{x_{i}^{a}-max(X)} + \sum_{i=1}^{n2}e^{x_{i}^{b}-max(X)} \\ N(x) &= e^{x_{i}-max(X)} \\ softmax(x) &= \frac{N(x_{i})}{D(X)} \end{align} X=[XA,XB]nmax(X)D(X)N(x)softmax(x)=[x1a,x2a,...,xn1a,x1b,x2b,...,xn2b]=n1+n2=max(x1a,x2a,...,xn1a,x1b,x2b,...,xn2b)=i=1∑n1exia−max(X)+i=1∑n2exib−max(X)=exi−max(X)=D(X)N(xi)
其中 m a x ( X ) max(X) max(X) 是 全局 最大值。
假设,通过分块 A A A 计算,已经缓存 3 个变量:局部最大值的 m a x ( X A ) max(X_{A}) max(XA) 、局部分母的 D ( X A ) D(X_{A}) D(XA) 、局部分子的 N ( x i a ) N(x_{i}^{a}) N(xia) ,具体步骤如下:
- 更新全局的最大值:
m a x ( X B ) = m a x ( x 1 b , x 2 b , . . . , x n 2 b ) m a x ( X ) = m a x ( m a x ( X A ) , m a x ( X B ) ) \begin{align} max(X_{B}) &= max(x_{1}^{b},x_{2}^{b},...,x_{n2}^{b}) \\ max(X) &= max(max(X_{A}), max(X_{B})) \end{align} max(XB)max(X)=max(x1b,x2b,...,xn2b)=max(max(XA),max(XB))
- 计算 局部分母的 D ( X B ) D(X_{B}) D(XB),结合 D ( X A ) D(X_{A}) D(XA) 更新全部的 D ( X ) D(X) D(X)
D ( X ) = ∑ i = 1 n 1 e x i a − m a x ( X ) + ∑ i = 1 n 2 e x i b − m a x ( X ) = ∑ i = 1 n 1 e x i a − m a x ( X A ) + m a x ( X A ) − m a x ( X ) + ∑ i = 1 n 2 e x i b − m a x ( X B ) + m a x ( X B ) − m a x ( X ) = e m a x ( X A ) − m a x ( X ) D ( X A ) + e m a x ( X B ) − m a x ( X ) D ( X B ) \begin{align} D(X) &= \sum_{i=1}^{n1}e^{x_{i}^{a}-max(X)} + \sum_{i=1}^{n2}e^{x_{i}^{b}-max(X)} \\ &= \sum_{i=1}^{n1}e^{x_{i}^{a} - max(X_{A}) + max(X_{A}) -max(X)} + \sum_{i=1}^{n2}e^{x_{i}^{b} - max(X_{B}) + max(X_{B}) -max(X)} \\ &= e^{max(X_{A})-max(X)} D(X_{A}) + e^{max(X_{B})-max(X)} D(X_{B}) \end{align} D(X)=i=1∑n1exia−max(X)+i=1∑n2exib−max(X)=i=1∑n1exia−max(XA)+max(XA)−max(X)+i=1∑n2exib−max(XB)+max(XB)−max(X)=emax(XA)−max(X)D(XA)+emax(XB)−max(X)D(XB)
- 更新 s o f t m a x ( x a ) softmax(x^{a}) softmax(xa),不需要重复计算,复用 局部分母的 D ( X A ) D(X_{A}) D(XA) 、局部分子的 N ( x i a ) N(x_{i}^{a}) N(xia)
s o f t m a x ( x a ) = e x i a − m a x ( X ) D ( X ) = e x i a + m a x ( X A ) − m a x ( X A ) − m a x ( X ) D ( X ) = e m ( X B ) − m ( X ) N ( x i a ) D ( X ) \begin{align} softmax(x^{a}) &= \frac{e^{x_{i}^{a} - max(X)}}{D(X)} \\ &= \frac{e^{x_{i}^{a} + max(X_{A}) - max(X_{A}) - max(X)}}{D(X)} \\ &= \frac{e^{m(X_{B})-m(X)}N(x_{i}^{a})}{D(X)} \end{align} softmax(xa)=D(X)exia−max(X)=D(X)exia+max(XA)−max(XA)−max(X)=D(X)em(XB)−m(X)N(xia)
- 同理,更新 s o f t m a x ( x b ) softmax(x^{b}) softmax(xb)
4. FlashAttention 的 One-Pass
类似 Online Softmax 的方法,将 Attention 所有的操作,都放到一个 for 循环里,一个 Kernel 就可以实现,实现累加操作,做到 one-pass 计算。
O u t p u t n = ∑ i = 1 n ( e x i − m n D n V i ) = ∑ i = 1 n − 1 ( e x i − m n D n V i ) + e x n − m n D n V n = ∑ i = 1 n − 1 ( e x i − m n − 1 D n − 1 e x i − m n e x i − m n − 1 D n − 1 D n V i ) + e x n − m n D n V n = O u t p u t n − 1 D n − 1 D n e m n − 1 − m n + e x n − m n D n V n \begin{align} Output_{n} &= \sum_{i=1}^{n} (\frac{e^{x_{i}-m_{n}}}{D_{n}}V_{i}) \\ &= \sum_{i=1}^{n-1} (\frac{e^{x_{i}-m_{n}}}{D_{n}}V_{i}) + \frac{e^{x_{n}-m_{n}}}{D_{n}}V_{n} \\ &= \sum_{i=1}^{n-1} (\frac{e^{x_{i}-m_{n-1}}}{D_{n-1}} \frac{e^{x_{i}-m_{n}}}{e^{x_{i}-m_{n-1}}} \frac{D_{n-1}}{D_{n}} V_{i}) + \frac{e^{x_{n}-m_{n}}}{D_{n}}V_{n} \\ &= Output_{n-1} \frac{D_{n-1}}{D_{n}} e^{m_{n-1}-m_{n}} + \frac{e^{x_{n}-m_{n}}}{D_{n}}V_{n} \end{align} Outputn=i=1∑n(Dnexi−mnVi)=i=1∑n−1(Dnexi−mnVi)+Dnexn−mnVn=i=1∑n−1(Dn−1exi−mn−1exi−mn−1exi−mnDnDn−1Vi)+Dnexn−mnVn=Outputn−1DnDn−1emn−1−mn+Dnexn−mnVn
因此, O u t p u t n Output_{n} Outputn 依赖于 O u t p u t n − 1 Output_{n-1} Outputn−1 、 m n m_{n} mn、 m n − 1 m_{n-1} mn−1,其他都是一次计算,因此实现递归计算。Flash Attention 计算过程,其实并没有减少 Attention 的计算量,也不影响精度,但是却比标准的 Attention 运算快 2~4 倍的运行速度,同时减少了 5~20 倍的内存使用量。
5. FlashAttention 的 Tiling(分块)
Tiling 即矩阵分块,通过将 大矩阵 分解为 更小的块,减少内存访问的开销,同时,提高计算效率。
矩阵分块策略:在处理大矩阵时,只加载和处理一部分数据,而不是一次性加载整个矩阵,减少内存带宽的压力。而具体到 Flash Attention 中,就是将 Q、K、V 分成更多个小块,其中 K、V 在外循环,Q 在内循环。在计算注意力分数的时候,通常需要进行 Softmax 操作。为了避免一次性计算整个 Softmax,FlashAttention 采用局部归一化策略。对于每个块,只计算这个块内部的 Softmax,在累加结果的时候进行归一化。通过逐块计算,减少全局内存的访问次数,降低了内存带宽的压力。这种策略特别适用于处理长序列的注意力机制,能够显著加速计算过程。
6. FlashAttention 的 MQA 和 GQA
MQA(Multi Query Attention),多查询注意力,GQA (Group Query Attention),分组查询注意力。
- MQA 只保留一个 KV Head,多个 Query Heads 共享相同的 KV Head。
- GQA 采取折中的做法,把 Query Heads 分组,每组 Query Heads 对应一个 KV Head。例如把 8 个 Query Heads 分成 4 组,每个 Group 包含 2 个 Query Heads,对应一个 KV Head 此时总共有 4 个 KV Heads。
对于 MQA 和 GQA,FlashAttention 采用了 Indexing 的方式,而不是直接复制多份 KV Head 的内容到显存然后再进行计算。Indexing 的思想,就是通过传入 KV Head 索引到 GPU Kernel 中,然后根据内存地址,直接从内存中读取 KV。
参考:
- CSDN - Self-Attention 机制的计算详解
- 知乎 - 从 Online-Softmax 到 FlashAttention V1/V2/V3