flash attention
Flash Attention 优化注意力计算原理解析
Flash Attention 通过分块处理和增量更新的方式优化注意力计算,有效减少内存占用。其核心在于逐步更新归一化因子并调整输出,确保结果正确性。具体步骤如下:
一、无最大值处理
1. 初始化
- 设置局部归一化因子
D_i=0
。 - 初始化输出累加值
O_i=0
。
2. 逐块处理
对每个键值块执行以下操作:
2.1 计算当前块的指数值
- 计算查询与当前块键的点积
e_ij
,并求其指数exp(e_ij)
。
2.2 更新归一化因子
- 累加当前块的指数和至
D_i
,即:
D n e w = D i + ∑ j exp ( e i j ) D_{new} = D_i + \sum_j \exp(e_{ij}) Dnew=Di+j∑exp(eij)
2.3 修正历史输出并累加当前贡献
-
调整之前的输出 O i O_i Oi 的比例:
O i ← O i × D i D n e w O_i \leftarrow O_i \times \frac{D_i}{D_{new}} Oi←Oi×DnewDi -
计算当前块的注意力权重:
α i j = exp ( e i j ) D n e w \alpha_{ij} = \frac{\exp(e_{ij})}{D_{new}} αij=Dnew