当前位置: 首页 > article >正文

flash_attention简要笔记

优化效果

原来,attention部分的计算量和中间激活占用显存的复杂度都是 O ( N 2 ) O(N^2) O(N2)

计算量部分原来QK矩阵乘和attn_score@V矩阵乘的计算量,复杂度都是 O ( N 2 ) O(N^2) O(N2);中间激活因为中间有一个attn_score,所以复杂度也是 O ( N 2 ) O(N^2) O(N2)

现在,attention部分的中间激活占用显存的复杂度变为 O ( N ) O(N) O(N),计算量的复杂度没有变但是通过减少访存加快了计算速度,而且fa与原attention完全等价

具体过程

flash-attention还是基于kernel融合的思想,将QK矩阵乘法、mask、softmax、dropout合并成一个kernel,这样不仅减少了中间变量对显存的占用,而且也减少了计算过程中的访存

一些符号表示:

  • S i j = Q i × K j T S_{ij}=Q_i \times K_j^T Sij=Qi×KjT,Q分块和K分块的乘积,形状为 [ B r , B c ] [B_r, B_c] [Br,Bc]

  • m ~ i j = r o w m a x ( S i j ) \widetilde{m}_{ij}=rowmax(S_{ij}) m ij=rowmax(Sij):对分块 S i j S_{ij} Sij而言,得到其每行的最大值,形状为 [ B r , 1 ] [B_r, 1] [Br,1]

  • P ~ i j = e S i j − m ~ i j = e S i j − r o w m a x ( S i j ) \widetilde{P}_{ij}=e^{S_{ij}-\widetilde{m}_{ij}}=e^{S_{ij}-rowmax(S_{ij})} P ij=eSijm ij=eSijrowmax(Sij):每个分块 S i j S_{ij} Sij减去其局部rowmax m ~ i j \widetilde{m}_{ij} m ij,形状为 [ B r , B c ] [B_r, B_c] [Br,Bc]

  • l ~ i j = r o w s u m ( P ~ i j ) = r o w s u m ( e S i j − r o w m a x ( S i j ) ) \widetilde{l}_{ij}=rowsum(\widetilde{P}_{ij})=rowsum(e^{S_{ij}-rowmax(S_{ij})}) l ij=rowsum(P ij)=rowsum(eSijrowmax(Sij)):对 P ~ i j \widetilde{P}_{ij} P ij而言,按行求和,形状为 [ B r , 1 ] [B_r, 1] [Br,1]

  • m i n e w = m a x ( m ~ i 0 , m ~ i 1 , . . . , m ~ i j ) = r o w m a x ( c o n c a t ( S i 0 , S i 1 , . . . , S i j ) ) m^{new}_i=max(\widetilde{m}_{i0}, \widetilde{m}_{i1}, ... , \widetilde{m}_{ij})=rowmax(concat(S_{i0}, S_{i1}, ... , S_{ij})) minew=max(m i0,m i1,...,m ij)=rowmax(concat(Si0,Si1,...,Sij)):即 c o n t c a t ( S i 0 , S i 1 , . . . , S i j ) contcat(S_{i0}, S_{i1}, ... , S_{ij}) contcat(Si0,Si1,...,Sij)这j+1个分块的每行的最大值,形状为 [ B r , 1 ] [Br, 1] [Br,1]

  • m i m_i mi m i n e w m_i^{new} minew位于SRAM上,将 m i n e w m_i^{new} minew写回到HBM就是 m i m_i mi,初始化 m = − ∞ m=-\infty m=

  • l i n e w = e m i − m i n e w l i + e m ~ i j − m i n e w l ~ i j = r o w s u m [ e S 00 − m a x ( m ~ 00 , . . . , m ~ 0 j ) ] + . . . + r o w s u m [ e S 0 j − m a x ( m ~ 00 , . . . , m ~ 0 j ) ] l^{new}_i=e^{m_i-m_i^{new}}l_i + e^{\widetilde{m}_{ij}-m_i^{new}} \widetilde{l}_{ij}=rowsum[e^{S_{00}-max(\widetilde{m}_{00},...,\widetilde{m}_{0j})}] + ... + rowsum[e^{S_{0j}-max(\widetilde{m}_{00},...,\widetilde{m}_{0j})}] linew=emiminewli+em ijminewl ij=rowsum[eS00max(m 00,...,m 0j)]+...+rowsum[eS0jmax(m 00,...,m 0j)]

  • l i l_i li l i n e w l_i^{new} linew位于SRAM上,将 l i n e w l_i^{new} linew写回到HBM就是 l i l_i li,初始化 l = 0 l=0 l=0

如果不使用flash-attention,具体过程为:

  1. S = Q K T S = Q K ^T S=QKT
  2. P = s o f t m a x ( S + m a s k ) P = softmax(S+mask) P=softmax(S+mask)
  3. O = P V O = P V O=PV

如果使用flash-attention,前向过程为:

在这里插入图片描述

大致过程为:

在这里插入图片描述

  1. 首先对QKV进行分块,K、V分块方法相同(V的分块图中没画出来),首先可以计算 S i j = Q i × K j T S_{ij}=Q_i\times K_j^T Sij=Qi×KjT。因为对QKV进行了分块,所以每次SRAM上能保留 S i j S_{ij} Sij P ~ i j \widetilde{P}_{ij} P ij(橙黄色表示存储在SRAM上;橙红色表示虽然也存储在SRAM上,但是这些部分每次outer loop会写回到HBM中)
  2. 如果有mask,此时对 S i j S_{ij} Sij进行mask
  3. 使用一个局部变量 m ~ i j \widetilde{m}_{ij} m ij和一个全局变量 m m m(或者说 m n e w m^{new} mnew m n e w m^{new} mnew的值在SRAM上,但是每次outer loop会写回到HBM中)来记录分块 S i j S_{ij} Sij局部rowmax和中间遍历过的分块 S i : S_{i:} Si:的历史rowmax
  4. 然后基于分块 S i j S_{ij} Sij计算局部的safe softmax的分子部分,即 e S i j − r o w m a x ( S i j ) e^{S_{ij}-rowmax(S_{ij})} eSijrowmax(Sij),safe softmax的分子部分累加就是分母部分,这样,就得到了一个针对分块 S i j S_{ij} Sij的、局部的safe softmax的分母 l ~ i j \widetilde{l}_{ij} l ij,和 一个 遍历过的历史分块 S i : S_{i:} Si:的 safe softmax分子部分的 累加和 l n e w l^{new} lnew(注意断句,写公式有点晦涩难懂,用语言描述又不太好描述),局部的 l ~ i j \widetilde{l}_{ij} l ij就是用来更新全局的 l l l(或者说 l n e w l^{new} lnew l n e w l^{new} lnew的值在SRAM上,但是每次outer loop会写回到HBM中),对 l ~ i j \widetilde{l}_{ij} l ij举一个例子:
    • 当j=0,i=0时, l 0 n e w = e m 0 − m 0 n e w l 0 + e m ~ 00 − m 0 n e w l ~ 00 = l ~ 00 l_0^{new}=e^{m_0-m_0^{new}} l_0+e^{\widetilde{m}_{00}-m_0^{new}} \widetilde{l}_{00}=\widetilde{l}_{00} l0new=em0m0newl0+em 00m0newl 00=l 00
    • 当j=1,i=0时, l 0 n e w = r o w s u m ( e S 00 − m a x ⁡ ( m ~ 00 , m ~ 01 ) ) + r o w s u m ( e S 01 − m a x ⁡ ( m ~ 00 , m ~ 01 ) ) l_0^{new} = rowsum(e^{S_{00}-max⁡(\widetilde{m}_{00}, \widetilde{m}_{01})})+rowsum(e^{S_{01}-max⁡(\widetilde{m}_{00}, \widetilde{m}_{01})}) l0new=rowsum(eS00max(m 00,m 01))+rowsum(eS01max(m 00,m 01))
  5. 然后对 P ~ i j \widetilde{P}_{ij} P ij进行dropout
  6. 然后相当于要进行 O + = P ~ i j V i O+=\widetilde{P}_{ij} V_i O+=P ijVi了,对于算法的第15行,可以使用分配律拆开看,其中有两个操作:
    1. 后半部分:对于当前的 P ~ i j V i \widetilde{P}_{ij} V_i P ijVi相乘, P ~ i j \widetilde{P}_{ij} P ij中减去的是分块 S i j S_{ij} Sij局部的rowmax,需要调整到 此时已经见过的、所有分块 S i : S_{i:} Si:的rowmax,就是第15行后半部分中 e m ~ i j − m i n e w e^{\widetilde{m}_{ij}-m_i^{new}} em ijminew的意思
    2. 前半部分:调整上一次的 O O O,先乘旧的 l i l_i li恢复到safe softmax的分子部分,然后乘以 e m i − m i n e w e^{m_i-m_i^{new}} emiminew更新一下safe softmax分子部分中减去的全局rowmax,最后再除以当前的safe softmax的分母

(反向过程还是看别的博客吧)

简要分析

首先分析一下fa的FLOPs(只分析大块的矩阵乘法,其他小的操作就不计算了):

  • 一开始的 Q i K j T Q_i K^T_j QiKjT矩阵相乘,其中 Q i Q_i Qi的形状为 [ B r , d ] [B_r, d] [Br,d] K j t K_j^t Kjt的形状为 [ d , B c ] [d, B_c] [d,Bc],此时FLOPs= 2 d × B r × B c 2d \times B_r \times B_c 2d×Br×Bc
  • 后面计算O的时候有一个 P ~ i j V i \widetilde{P}_{ij} V_i P ijVi矩阵相乘,其中 P ~ i j \widetilde{P}_{ij} P ij的形状为 [ B r , B c ] [B_r, B_c] [Br,Bc] V i V_i Vi的形状为 [ B c , d ] [B_c, d] [Bc,d],此时FLOPs= 2 B c × B r × d 2B_c \times B_r \times d 2Bc×Br×d一共进行了 N B r × N B c \frac{N}{B_r} \times \frac{N}{B_c} BrN×BcN次上面的循环,所以FLOPs= 4 N 2 d 4N^2d 4N2d,如果d远小于N,则计算复杂度就变成了 O ( N 2 ) O(N^2) O(N2),计算复杂度相比于standard attention没有变化

然后再分析一下显存占用(显存占用说的是HBM上的显存占用,假设计算精度为 w w w Bytes)

  • HBM上需要维护一个全局的rowmax和expsum,占用显存为 w × N w\times N w×N
  • 然后还要存储一个最后的输出 O O O,占用显存为 w N d wNd wNd,但是这个部分是必须的
  • 因此,显存占用的复杂度为 O ( N d ) O(Nd) O(Nd)(或者 O ( N ) O(N) O(N),如果不考虑 O O O的话)。standard attention需要保存中间的 S , P S, P S,P,显存占用复杂度为 O ( N 2 ) O(N^2) O(N2)

fa相对于standard attention一个优势,在于减小了计算过程中的访存量,最后来分析一下访存次数:

  • standard attention
    • 从HBM中读取Q,K(形状都是 [ N , d ] [N, d] [N,d]),访存量= w N d wNd wNd,计算 S = Q K T S=QK^T S=QKT,然后向HBM中写回S(形状为 [ N , N ] [N, N] [N,N]),访存量= w N 2 wN^2 wN2
    • 从HBM中读取S,访存量= w N 2 w N^2 wN2,计算 P = s o f t m a x ( S ) P=softmax(S) P=softmax(S),向HBM中写回P,访存量= w N 2 w N^2 wN2
    • 从HBM中读取P(形状为 [ N , N ] [N, N] [N,N])、V(形状为 [ N , d ] [N, d] [N,d]),访存量= w N 2 + w N d w N^2 + wNd wN2+wNd,计算 O = P V O=PV O=PV,向HBM中写回O(形状为 [ N , d ] [N, d] [N,d]),访存量= w N d wNd wNd
    • 总的访存量= w ( 3 N d + 4 N 2 ) w(3Nd+4N^2) w(3Nd+4N2),如果d远小于N,则访存量的复杂度变成了 O ( N 2 ) O(N^2) O(N2)
  • flash attention(分析时将inner loop作为一个整体进行分析,就像上面示意图画的那样)
    • 从HBM中读取分块 Q i , i = 0 , . . . , T r − 1 Q_i, i=0, ..., T_r -1 Qi,i=0,...,Tr1,读取分块 K j K_j Kj,访存量= w ( N d + B c d ) w(Nd+B_c d) w(Nd+Bcd);后面 S i j , P ~ i j S_{ij}, \widetilde{P}_{ij} Sij,P ij不需要写回HBM; m , l m, l m,l只是一个向量,数据量很少,忽略;再后面读取和写入分块 O i , i = 0 , . . . , T r = 1 O_i, i = 0, ...,T_r =1 Oi,i=0,...,Tr=1,访存量= w ( 2 × N d ) w(2\times Nd) w(2×Nd)
    • outer loop共有 N B c = T c \frac{N}{B_c}=T_c BcN=Tc次,总的访存量= w × N B c × ( N d + B c d + 2 N d ) = w ( N d + 3 N 2 d B c ) = w ( T c + 1 ) N d w\times \frac{N}{B_c} \times (Nd + B_cd + 2Nd)=w(Nd+\frac{3N^2d}{B_c})=w(T_c+1)Nd w×BcN×(Nd+Bcd+2Nd)=w(Nd+Bc3N2d)=w(Tc+1)Nd
    • 比如N=1024,d=64,B=64,standard_attention访存量-flash_attention访存量= w ( 3 N d + 4 N 2 − N d − 3 N 2 d B c ) = w ( 2 N d + ( 4 − 3 d B c ) N 2 ) = w ( 2 N d + N 2 ) w(3Nd+4N^2-Nd-\frac{3N^2d}{B_c})=w(2Nd+(4-\frac{3d}{B_c})N^2)=w(2Nd+N^2) w(3Nd+4N2NdBc3N2d)=w(2Nd+(4Bc3d)N2)=w(2Nd+N2),可以看出少了很多访存

实际使用

接口返回值

flash-attention开源代码中,针对不同qkv、是否是varlen、是否需要kv_cache等不同需求封装了不同的接口,这里说一下返回值。这些接口的返回值都相同,除了返回输出的 O O O之外,如果设置了return_attn_probs=True,还会返回softmax_lse和S_dmask:

  • softmax_lse(形状 [ n h e a d s , s e q l e n ] [nheads, seqlen] [nheads,seqlen]):在计算 S = Q K T s c a l e S=\frac{QK^T}{scale} S=scaleQKT之后,会得到形状为 [ b s , s e q l e n , s e q l e n ] [bs, seqlen, seqlen] [bs,seqlen,seqlen]的方阵S,在计算softmax的过程中,需要按行求和,得到一个列向量,然后再取log,写成表达式即为: s o f t m a x _ l s e = l o g [ ∑ j e S i j ] softmax\_lse=log[\sum_je^{S_{ij}}] softmax_lse=log[jeSij],注意不是 s o f t m a x _ l s e = l o g [ ∑ j e S i j − r o w m a x ( S i j ) ] softmax\_lse=log[\sum_je^{S_{ij}-rowmax(S_{ij})}] softmax_lse=log[jeSijrowmax(Sij)],参考issue:What’s the exactly formula of softmax_lse? #404
  • S_dmask(形状 [ b s , n h e a d s , s e q l e n , s e q l e n ] [bs, nheads, seqlen, seqlen] [bs,nheads,seqlen,seqlen]):就是返回 P = s o f t m a x ( Q K T s c a l e + m a s k ) P=softmax(\frac{QK^T}{scale}+mask) P=softmax(scaleQKT+mask)的这个P矩阵

varlen attention

特别的,这里再说一下flash_attn_varlen_func等一些支持varlen的接口,其函数形参中还有cu_seqlens_qcu_seqlens_kmax_seqlen_qmax_seqlen_k等特有的参数。这里介绍一些varlen是什么。

varlen即变长序列,产生的背景是”数据拼接“,即LLM使用的训练数据集中,长度较短的序列占大多数,这些短序列为了能够符合Transformer固定长度的输入,就要进行padding,序列越短,padding越多,而我们不太想要padding,padding只是无奈之举。此时,我们可以使用varlen特性,简单来说就是将多个短序列拼接成一个长序列,但是还是每个短序列自己内部计算注意力,短序列之间是隔离的,这样减少了padding,节省计算量和显存。

这里举个例子(参考),比如一些短序列长度分别是:70,300,180, …,260,120,1200,…等,attention固定输入长度是4096,此时我们将这些短序列拼接起来,使用varlen_attn后,就像右图所示,每个短序列自己内部计算attention,短序列之间不计算attention(否则就像左图这样,白白多了很多浪费的计算)

在这里插入图片描述

为了实现varlen特性,需要对接口有一些调整。比如不使用varlen的flash_attn接口中,传入的Q、K、V的形状一般为 [ b s , s e q l e n , n h e a d s , h e a d _ d i m ] [bs, seqlen, nheads, head\_dim] [bs,seqlen,nheads,head_dim](K和V的nheads可以少于Q的nheads,此时就是GQA/MQA)。在使用varlen的flash_attn接口中,主要有两点变化:

  • Q、K、V的形状一般为 [ t o t a l _ s e q , n h e a d s , h e a d _ d i m ] [total\_seq, nheads, head\_dim] [total_seq,nheads,head_dim],这里将多个batch拼接起来,拼起来的长度为 t o t a l _ s e q total\_seq total_seq
  • 多了cu_seqlens_qcu_seqlens_kmax_seqlen_qmax_seqlen_k等特有的参数
    • cu_seqlens_q是对每个短序列的Q的长度的exclusive_scan,作用就是找到原来每个batch的起始点(offset),比如上面的例子,此时cu_seqlens_q=[0, 70, 370, 550, ... ],如果cu_seqlens_q的形状为 [ b a t c h _ s i z e + 1 ] [batch\_size+1] [batch_size+1],则需要在最后拼接上序列Q的总长度
    • max_seqlen_q好理解,就是短序列的Q的最长长度

在具体实现中,对每个序列的每个head分别launch kernel,来实现并行计算,这个过程中要通过cu_seqlens_q来确定对应Q的start_idx和end_idx。

参考:

Flash attention变长batching API使用

How did flash-attn compute attention for cu_seqlens #850

参考

图解大模型计算加速系列:FlashAttention V1,从硬件到计算逻辑

优质好文:

[Attention优化][2w字]🔥原理&图解: 从Online-Softmax到FlashAttention V1/V2/V3


http://www.kler.cn/a/308449.html

相关文章:

  • 【2024软考架构案例题】你知道 Es 的几种分词器吗?Standard、Simple、WhiteSpace、Keyword 四种分词器你知道吗?
  • uniCloud云对象调用第三方接口,根据IP获取用户归属地的免费API接口,亲测可用
  • 文件夹被占用了无法删除怎么办?强制粉碎文件夹你可以这样操作
  • 蔚来Java面试题及参考答案
  • 普通电脑上安装属于自己的Llama 3 大模型和对话客户端
  • 【初阶数据结构与算法】链表刷题之移除链表元素、反转链表、找中间节点、合并有序链表、链表的回文结构
  • QT程序的安装包制作教程
  • 第二十三章 加密安全标头元素
  • go-zero的快速实战(完整)
  • udp的广播,多播,单播 demo
  • 沉浸式利用自然语言无代码开发工具生成式AI产品应用(下)
  • leetcode 42 接雨水
  • 【SQL】百题计划:SQL内置函数“LENGTH“的使用
  • c++ 线程库
  • 汽车英文单词缩写汇总
  • C++学习笔记(27)
  • Rust: Warp RESTful API 如何得到客户端IP?
  • Notepad++中提升编码效率的关键快捷键
  • C++:opencv计算轮廓周长--cv::arcLength
  • 如何快速入门 Vue 3
  • MySQL基础篇(黑马程序员2022-01-18)
  • xilinx hbm ip运用
  • 自定义类型:联合和枚举
  • Java零基础-Java对象详解
  • 5. Python之数据类型
  • JVM字节码与局部变量表