当前位置: 首页 > article >正文

LLM 优化技术(2)——paged_attention 原理

1 从 flash attention 到 paged attention

1.1 Flash Attention 基本原理

Flash Attention 计算 self-attention 的关键是有效地硬件利用。GPU 的典型操作方式是使用大量的线程来执行一个操作,这个操作被称为内核。输入从HBM加载到寄存器和SRAM,并在计算后写回HBM。

Flash attention 的优化思路:

考虑到 self-attention 计算过程中,对 HBM 的重复读写是一个主要瓶颈。针对这个问题,Flash Attention 通过减少对 HBM 的访问次数来提高性能:

  • 在不访问整个输入的情况下计算 softmax
  • 不为反向传播存储巨大的中间 attention 矩阵

为此 FlashAttention 提出了两种方法来分布解决上述问题:tiing 和 recomputation:

  • tiling,将输入分割成块,通过在输入块上进行多次传递来递增地计算 softmax
  • recomputation,存储来自前向的 softmax 归一化因子,以便在反向中快速重新计算芯片上的 attention,这比从HBM读取中间矩阵的标准注意力方法更快

通过重新计算,会导致总的 Flops 增加,但是这样可以减少大量的 HBM 访问,从而有效提高 Flash Attention 的计算速度。

Flash Attention 的核心思想是分割输入,将它们从慢速HBM加载到快速SRAM,然后计算这些块的 attention 输出。在将每个块的输出相加之前,将其按正确的归一化因子进行缩放,从而得到正确的结果。

1.2 Tiling 与前向计算

分块计算注意力的关键部分是 softmax 的分块计算。向量的 softmax 可以计算为:

m ( x ) = m a x ( x ) f ( x ) = [ e x 1 − m ( x ) , ⋯   , e x B − m ( x ) ] l ( x ) = ∑ i f ( x ) i s o f t m a x ( x ) = f ( x ) l ( x ) \begin{align*} & m(x) = max(x) \\ & f(x) = [e^{x_1 - m(x)}, \quad \cdots , \quad e^{x_B - m(x)}] \\ & l(x) = \sum _i f(x)_i \\ & softmax(x) = \frac{f(x)}{l(x)} & \end{align*} m(x)=max(x)f(x)=[ex1m(x),,exBm(x)]l(x)=if(x)isoftmax(x)=l(x)f(x)

其中 x x x 可以分解为 x = [ x ( 1 ) x ( 2 ) ] ∈ R 2 B , x ( 1 ) , x ( 2 ) ∈ R B x = [x^{(1)} x^{(2)}] \in \mathbb R ^{2B}, \quad x^{(1)}, x^{(2)} \in \mathbb R ^B x=[x(1)x(2)]R2B,x(1),x(2)RB,那么有:

m ( x ) = m ( [ x ( 1 ) x ( 2 ) ] ) = m a x ( m ( x ( 1 ) ) , m ( x ( 2 ) ) ) \begin{align*} m(x) = & m([x^{(1)} \quad x^{(2)}]) \\ = & max(m(x^{(1)}), \quad m(x^{(2)})) & \end{align*} m(x)==m([x(1)x(2)])max(m(x(1)),m(x(2)))

那么可以通过如下构造的方式,使得 f ( x ) f(x) f(x) 的结果与分块前保持统一:

f ( x ) = [ e m ( x ( 1 ) ) − m ( x ) f ( x ( 1 ) ) e m ( x ( 1 ) ) − m ( x ) f ( x ( 1 ) ) ] = [ e m ( x ( 1 ) ) − m ( x ) [ e x 1 ( 1 ) − m ( x ( 1 ) ) ⋯ e x B ( 1 ) − m ( x ( 1 ) ) ] e m ( x ( 2 ) ) − m ( x ) [ e x 1 ( 2 ) − m ( x ( 2 ) ) ⋯ e x B ( 2 ) − m ( x ( 2 ) ) ] ] = [ [ e x 1 ( 1 ) − m ( x ) ⋯ e x B ( 1 ) − m ( x ) ] [ e x 1 ( 2 ) − m ( x ) ⋯ e x B ( 2 ) − m ( x ) ] ] = [ e x 1 − m ( x ) ⋯ e x B − m ( x ) ] \begin{align*} f(x) = & [e^{m(x^{(1)}) - m(x)} f(x^{(1)}) \quad e^{m(x^{(1)}) - m(x)} f(x^{(1)})] \\ = & [e^{m(x^{(1)}) - m(x)} [e^{x_1 ^{(1)} - m(x^{(1)})} \quad \cdots \quad e^{x_B ^{(1)} - m(x^{(1)})}] \quad e^{m(x^{(2)}) - m(x)} [e^{x_1 ^{(2)} - m(x^{(2)})} \quad \cdots \quad e^{x_B ^{(2)} - m(x^{(2)})}]] \\ = & [[e^{x_1 ^{(1)} - m(x)} \quad \cdots \quad e^{x_B ^{(1)} - m(x)}] \quad [e^{x_1 ^{(2)} - m(x)} \quad \cdots \quad e^{x_B ^{(2)} - m(x)}]] \\ = & [e^{x_1 - m(x)} \quad \cdots \quad e^{x_B - m(x)}] & \end{align*} f(x)====[em(x(1))m(x)f(x(1))em(x(1))m(x)f(x(1))][em(x(1))m(x)[ex1(1)m(x(1))exB(1)m(x(1))]em(x(2))m(x)[ex1(2)m(x(2))exB(2)m(x(2))]][[ex1(1)m(x)exB(1)m(x)][ex1(2)m(x)exB(2)m(x)]][ex1m(x)exBm(x)]

l ( x ) l(x) l(x) 的构造方式同理:

l ( x ) = l ( [ x ( 1 ) x ( 2 ) ] ) = e m ( x ( 1 ) ) − m ( x ) l ( x ( 1 ) ) + e m ( x ( 2 ) ) − m ( x ) l ( x ( 2 ) ) \begin{align*} l(x) = & l([x^{(1)} \quad x^{(2)}]) \\ = & e^{m(x^{(1)}) - m(x)} l(x^{(1)}) + e^{m(x^{(2)}) - m(x)} l(x^{(2)}) & \end{align*} l(x)==l([x(1)x(2)])em(x(1))m(x)l(x(1))+em(x(2))m(x)l(x(2))

从而可以得到 s o f t m a x softmax softmax 结果的形式,也可保持与分块前一致:

s o f t m a x ( x ) = f ( x ) l ( x ) softmax(x) = \frac{f(x)}{l(x)} softmax(x)=l(x)f(x)

softmax 的分块计算解决后,其他部分的分块计算就简单了。

完整的 decoder 的 self-attention 的计算过程:

S = Q K T d ∈ R N × N S m a s k e d = M A S K ( S ) ∈ R N × N P = s o f t m a x ( S m a s k e d ) ∈ R N × N p d r o p p e d = d r o p o u t ( P , p d r o p ) ∈ R N × N O = p d r o p p e d V ∈ R N × d \begin{align*} S = & \frac{QK^T}{\sqrt{d}} \quad \in \mathbb R ^{N \times N} \\ S^{masked} = & MASK(S) \quad \in \mathbb R ^{N \times N} \\ P = & softmax(S^{masked}) \quad \in \mathbb R ^{N \times N} \\ p^{dropped} = & dropout(P, p_{drop}) \quad \in \mathbb R ^{N \times N} \\ O = & p^{dropped} V \quad \in \mathbb R ^{N \times d} & \end{align*} S=Smasked=P=pdropped=O=d QKTRN×NMASK(S)RN×Nsoftmax(Smasked)RN×Ndropout(P,pdrop)RN×NpdroppedVRN×d

flash attention 的算法没有增加额外的计算量,只是将大块的操作拆分成多个小块逐步计算,因此算法的复杂度不变,仍为 O ( N 2 d ) O(N^2 d) O(N2d)。另外,增加了中间变量 l , m l, m l,m,因此空间复杂度增加 O ( N ) O(N) O(N)

1.3 Recomputation 与反向计算

Flash Attention 的后向传播需要 S S S P P P 矩阵来计算 Q , K , V Q, K, V Q,K,V 的梯度,空间复杂度是 O ( N 2 ) O(N^2) O(N2) S S S P P P 没有显式存储。

解决方法是使用输出 O O O s o f t m a x softmax softmax 归一化统计 ( m , l ) (m, l) (m,l),可以利用 SRAM 中的 Q , K , V Q, K, V Q,K,V 重新计算 S S S P P P 矩阵。这个过程会增加额外的计算量,但是减少了对 HBM 的访存,整体上加快了后向传播的速度。

标准 self-attention 的反向的求导过程:

注意力机制 O = s o f t m a x ( Q K T d k ) V O = softmax(\frac{QK^T}{\sqrt{d_k}}) V O=softmax(dk QKT)V,反向传播时,需要根据损失函数 Φ \Phi Φ 对模块输出的导数 d O dO dO (即 ∂ Φ ∂ O \frac{\partial \Phi}{\partial O} OΦ),求出其对输入的导数 d Q , d K , d V dQ, dK, dV dQ,dK,dV(即 ∂ Φ ∂ Q , ∂ Φ ∂ K , ∂ Φ ∂ V \frac{\partial \Phi}{\partial Q}, \frac{\partial \Phi}{\partial K}, \frac{\partial \Phi}{\partial V} QΦ,KΦ,VΦ)。

P = s o f t m a x ( S ) , S = Q K T d k , O = P V P = softmax(S), \quad S = \frac{QK^T}{\sqrt{d_k}}, \quad O = PV P=softmax(S),S=dk QKT,O=PV,则其求导过程如下:

∂ O ∂ V = P T \frac{\partial O}{\partial V} = P^T VO=PT 得到:

∂ Φ ∂ V = P T ∂ Φ ∂ O \frac{\partial \Phi}{\partial V} = P^T \frac{\partial \Phi}{\partial O} VΦ=PTOΦ

即:

d V = P T d O dV = P^T dO dV=PTdO

其余导数计算类似。

Flash Attention 将 d V dV dV 的计算拆分成若干个子块 d V j dV_j dVj 的计算,从而可以使用 Tiling 技术:

d V j = ( P T ) J d O = ∑ i ( P T ) j i d O i = ∑ i p i j T d O i \begin{align*} dV_j = & (P^T)_J dO \\ = & \sum _i (P^T)_{ji} dO_i \\ = & \sum _i p^T _{ij} dO_i & \end{align*} dVj===(PT)JdOi(PT)jidOiipijTdOi

P P P 矩阵的大小是 O ( N 2 ) O(N^2) O(N2),为了减小内存的消耗和对 HBM 的访存,FlashAttention 在反向传播时重新计算
P P P 而非在前向传播时保存:

P i j = ( s o f t m a x ( S ) ) i j = e S i j − δ T ( S i m a x , C N ) δ T ( S i s u m , C N ) \begin{align*} P_{ij} = & (softmax(S))_{ij} \\ = & \frac{e^{S_{ij}} - \delta ^T (S_i ^{max}, C_N)}{\delta ^T (S_i ^{sum}, C_N)} & \end{align*} Pij==(softmax(S))ijδT(Sisum,CN)eSijδT(Simax,CN)

其中 S i m a x , S i s u m S_i^{max}, S_i^{sum} Simax,Sisum 是前向的时候保存的,且 S i j = Q i ( K T ) j d k = Q i K j T d k S_{ij} = \frac{Q_i (K^T)_j}{\sqrt{d_k}} = \frac{Q_i K_j^T}{\sqrt {d_k}} Sij=dk Qi(KT)j=dk QiKjT

根据前面的推导,同理可得:

d P = d O V T dP = dOV^T dP=dOVT

考虑到 s i = q i K T d k , p i = s o f t m a x ( s i ) s_i = \frac{q_i K^T}{\sqrt {d_k}}, \quad p_i = softmax(s_i) si=dk qiKT,pi=softmax(si),可以得到:

d q i = d s i ∂ s i ∂ q i = d s i K d k \begin{align*} dq_i = & ds_i \frac{\partial s_i}{\partial q_i} \\ = & ds_i \frac{K}{\sqrt{d_k}} & \end{align*} dqi==dsiqisidsidk K

d s i = d p i ∂ p i ∂ s i = d p i ( d i a g ( p i ) − p i T p i ) = d p i ⊙ p i − d o i o i T p i \begin{align*} ds_i = & dp_i \frac{\partial p_i}{\partial s_i} \\ = & dp_i (diag(p_i) - p_i^T p_i) \\ = & dp_i \odot p_i - do_i o_i^T p_i & \end{align*} dsi===dpisipidpi(diag(pi)piTpi)dpipidoioiTpi

拓展 d q i dq_i dqi d Q i dQ_i dQi:

d Q i = d S i K d k = ∑ j d S i j K j d k dQ_i = dS_i \frac{K}{\sqrt{d_k}} = \sum _j dS_{ij} \frac{K_j}{d_k} dQi=dSidk K=jdSijdkKj

计算 d S i j dS_{ij} dSij

d S i = d P i ⊙ P i − α ( β T ( d O i ⊙ O i ) , N ) ⊙ P i = ( d P i − α ( β T ( d O i ⊙ O i ) , N ) ) ⊙ P i d S i j = ( d P i j − α ( β T ( d O i ⊙ O i ) , C N ) ) ⊙ P i j \begin{align*} dS_i = & dP_i \odot P_i - \alpha (\beta ^T (dO_i \odot O_i), N) \odot P_i \\ = & (dP_i - \alpha (\beta ^T (dO_i \odot O_i), N)) \odot P_i \\ \\ dS_{ij} = & (dP_{ij} - \alpha (\beta ^T (dO_i \odot O_i), C_N))\odot P_{ij} & \end{align*} dSi==dSij=dPiPiα(βT(dOiOi),N)Pi(dPiα(βT(dOiOi),N))Pi(dPijα(βT(dOiOi),CN))Pij

这里, P i j P_{ij} Pij 是重计算得到的:

d P i j = d O i ( V T ) j = d O i V j T dP_{ij} = dO_i (V^T)_j = dO_iV_j^T dPij=dOi(VT)j=dOiVjT

计算 d K j dK_j dKj:

d K = ( d K T ) T = d S T Q d k d K j = ( d S T ) j Q d k = ∑ i ( d S T ) j i Q d k = ∑ i d S i j T Q d k \begin{align*} dK = & (dK^T)^T = dS^T \frac{Q}{\sqrt{d_k}} \\ dK_j = & (dS^T)_j \frac{Q}{\sqrt{d_k}} = \sum _i (dS^T)_{ji} \frac{Q}{\sqrt{d_k}} \\ = & \sum _i dS_{ij} ^T \frac{Q}{\sqrt{d_k}} \end{align*} dK=dKj==(dKT)T=dSTdk Q(dST)jdk Q=i(dST)jidk QidSijTdk Q

1.4 PagedAttention

vLLM 主要用于快速 LLM 推理和服务,其核心是 PagedAttention,这是一种新颖的注意力算法,将在操作系统的虚拟内存中分页的经典思想引入到 LLM 服务中。在无需任何模型架构修改的情况下,可以做到比 HuggingFace Transformers 提供高达 24 倍的 Throughput。

1.4.1 PagedAttention 的基本原理

vLLM 的特点:

  • 非常高的服务吞吐性能

  • 使用 paged Attention 优化 KV cache 内存管理

  • 动态 batch

  • 优化的 cuda Kernel

  • 各种 decoder 算法,包括并行采样、beam search 等

  • 张量并行(TP)以支持分布式推理

  • 流输出

  • 兼容 OpenAI 的 API 服务

KV cache 占用巨大的显存,在 LLaMA-13B 中,缓存单个序列最多需要 1.7GB 显存;另外,KV cache 动态变化,:KV 缓存的大小取决于序列长度,这是高度可变和不可预测的。由于碎片化和过度保留,现有系统浪费了 60% - 80% 的显存。为解决这样的问题,paged Attention 出现了。

PagedAttention 是一种受操作系统中虚拟内存和分页经典思想启发的注意力算法。与传统的注意力算法不同,PagedAttention 允许在非连续的内存空间中存储连续的 key 和 value 。

具体来说,PagedAttention 将每个序列的 KV cache 划分为块,每个块包含固定数量 token 的键和值。在注意力计算期间,PagedAttention 内核可以有效地识别和获取这些块。

因为块在内存中不需要连续,因而可以用一种更加灵活的方式管理 key 和 value ,就像在操作系统的虚拟内存中一样:可以将块视为页面,将 token 视为字节,将序列视为进程。序列的连续逻辑块通过块表映射到非连续物理块中。物理块在生成新 token 时按需分配。

在 PagedAttention 中,内存浪费只会发生在序列的最后一个块中。这使得在实践中可以实现接近最佳的内存使用,仅浪费不到 4%。这种内存效率的提升被证明非常有用,允许系统将更多序列进行批处理,提高 GPU 使用率,显著提升吞吐量。

PagedAttention 还有另一个关键优势 —— 高效的内存共享。例如在并行采样中,多个输出序列是由同一个 prompt 生成的。在这种情况下,prompt 的计算和内存可以在输出序列中共享。

并行采样:

在这里插入图片描述

PagedAttention 自然地通过其块表格来启动内存共享。与进程共享物理页面的方式类似,PagedAttention 中的不同序列可以通过将它们的逻辑块映射到同一个物理块的方式来共享块。为了确保安全共享,PagedAttention 会对物理块的引用计数进行跟踪,并实现写时复制(Copy-on-Write)机制。

PageAttention 的内存共享大大减少了复杂采样算法的内存开销,例如并行采样和集束搜索的内存使用量降低了 55%。这可以转化为高达 2.2 倍的吞吐量提升。这种采样方法也在 LLM 服务中变得实用起来。

1.4.2 attention_ops.single_query_cached_kv_attention op

vLLM 的核心是 PagedAttention ,PagedAttention 核心是 attention_ops.single_query_cached_kv_attention op,完整代码vllm/tests/kernels/test_attention.py 。


http://www.kler.cn/a/613232.html

相关文章:

  • 几种常见的.NET单元测试模拟框架介绍
  • leetcode day32 763+56
  • 【软件测试】:软件测试实战
  • I.MX6ULL 开发板上挂载NTFS格式 U 盘
  • vue将页面导出成word
  • Python_电商erp自动拆分组合编码
  • 规范Unity工程目录和脚本结构能有效提升开发效率、降低维护成本
  • Maven中为什么有些依赖不用引入版本号
  • 【ManiSkill】环境success条件和reward函数学习笔记
  • YOLOv8 中的损失函数解析
  • 构建可扩展、可靠的网络抓取、监控和自动化应用程序的终极指南
  • 【天梯赛】L2-004 这是二叉搜索树吗(经典问题C++)
  • Go语言中regexp模块详细功能介绍与示例
  • 什么是架构,以及当前市面主流架构类型有哪些?
  • X.509证书与证书请求生成原理及其应用(C/C++代码实现)
  • STM32基础教程——旋转编码器测速
  • Mysql的单表查询和多表查询
  • 记录一次TDSQL事务太大拆过binlog阈值报错
  • Python+requests实现接口自动化测试框架
  • JavaWeb——事务管理、AOP