分布式推理框架 xDit
1. xDiT 简介
xDiT 是一个为大规模多 GPU 集群上的 Diffusion Transformers(DiTs)设计的可扩展推理引擎。它提供了一套高效的并行方法和 GPU 内核加速技术,以满足实时推理需求。
1.1 DiT 和 LLM
DiT(Diffusion Transformers)是文生图与文生视频的核心网络结构。
DiT 和 LLM 推理任务特点不一样,这导致二者软件设计哲学也有差异:
1. LLM 有 Prefill 和 Decode 两阶段,分别是计算密集和访存密集的;而 DiT 的计算和 Prefill 相似是计算密集的。所以 Batching、Paged Attention 等为了增加 Decode 阶段计算密度的设计在 DiT 中不需要,其实也没有 KVCache 不断 Append 增长的现象。
2. LLM 模型很大,而序列长度有限,比如 ShareGPT 的任务生成顶多几百个 token。而 DiT 正好反过来,模型不大,推理时序列长度很长,甚至 Million 级别。这导致张量并行这种 vLLM 中第一优先级的并行方法,在 DiT 中性能很低。
3. LLM 模型大多收敛到微调 Llama 架构,而 DiT 架构则呈现出较大的差异性。例如,SD3、Latte 等网络与 Meta DiT 论文中的 DiT Block 在算子设计上存在显著差异,且 DiT Block 之间的连接方式也各不相同,常采用类似 U-Net 的 U 型结构进行链接。因此,DiT 模型的代码相对 LLM 不够简洁,往往涉及很多文件和函数,对其进行模型层面的改造也更具挑战性。
1.2 背景与挑战
DiT 模型在生成长视觉序列时的推理延迟显著增加,例如,Sora 模型的推理延迟可达数分钟,单个 GPU 无法满足实际应用的延迟要求。因此,DiT 模型的部署必然是多卡并行处理。Attention 机制导致的计算时间与序列长度呈二次增长,为了达到同样延迟指标,GPU 数应该随着序列长度平方项增长的,所有部署 Sore 很大概率需要多机多卡大集群执行单个视频生成,DiT 的并行推理扩展性至关重要。尽管大语言模型(LLMs)中常用的并行化技术,如张量并行和序列并行,通信量和激活参数正相关,在扩散模型中由于大激活小参数特点,这些技术效率不高。通信成本往往超过了并行计算的好处。因此,DiT 模型的部署仍然需要配备高带宽互连(如 NVLink+RDMA)的 GPU 集群。
随着 DiTs 中输入上下文长度的增加,Attention 机制的计算需求呈二次方增长,因此多 GPU 和多机器部署对于满足在线服务的实时需求至关重要。
1.3 Diffusion Model 推理原理和特性
尽管 DiT 和 LLM 的结构相似都是 Transformers 架构,看起来工程挑战类似。
扩散模型的训练过程:给一个图片,经过很多步骤,每一步加噪声最后变成一个全是噪音的图片。训练过程就是预测一个 Noise Predictor 的监督学习任务,输入是中间状态图片、输入指令和 timestep 来 ground truth 就是该 timestep 的噪声。
扩散模型的推理过程:给一个噪声,通过 Noise Predictor 来通过多次去噪,最后变成一个有意义的图片。每一次去噪就是一个 Diffusion Step。
通过描述可知,扩散模型推理和 LLM 的有一个很大的差异点,就是扩散模型是重复计算很多相同的 Diffusion Step,而且连续的 Diffusion Step 之间输入数据和激活状态之间存在的高度相似性,称之为输入时间冗余(Input Temporal Redundancy)。这种冗余性表明,在进行相邻时间 Diffusion Step 去噪时,可以利用前一步骤的过期 Activation 来顶替当前步骤生成的 Activation 做一些计算,可能结果也不会差太多。
在 DiT 推理中,鉴于结构的相似性,可以用 LLM 中非常成熟的张量并行(TP)和序列并行(SP),也可以使用一些利用扩散模型 Input Temporal Redundancy 特性的并行方法,比如 DistriFusion 和 PipeFusion。
1.4 xDiT Overview
正如 vLLM 让 transformers 具备 Batching 的能力,xDiT 让 diffusers 有并行的能力。考虑到 DiT 的特点,首先提出了一系列创新的并行方式,并可以相互混合。更重要的是,xDiT 提供了一套优雅的开发接口,针对性解决了 DiT 模型更改难度高的问题,这套开发接口尽可能可能复用 diffusers 的现有逻辑,开发者通过一些 wrapper 实现复杂的混合并行,实现高效地扩展方法。将之前的 PipeFusion 项目升级成了通用的 DiT 推理引擎,命名为 xDiT。(向 vLLM 致敬,X表示Scalable,希望 xDiT 成为 DiT 推理领域的 vLLM)
vLLM=huggingface transformers+Batching,xDiT=huggingface diffusers+Parallel
xDiT 是一个为大规模多 GPU 集群上的 Diffusion Transformers(DiTs)设计的可扩展推理引擎。它提供了一套高效的并行方法和 GPU 内核加速技术,以满足实时推理需求。
xDiT 有如下核心能力:
第一,DiT 主干网络混合并行:xDiT 支持任意混合四种基础并行策略方法。基础并行方法有:
-
Pipefusion Parallel:TeraPipe 方式流水线并行方法,在弱互联,比如 pcie/以太网的网络硬件上有优势,比如以太网互联的多机多卡情境下效率远高于其他方法;
-
Sequence Parallel:混合序列并行,使用 ring & ulysses sequence parallel 混合方式 USP;
-
Data Parallel:在输入多个 prompt 或单个 prompt 生成多张图片时,在 image 间并行处理;
-
CFG Parallel, a.k.a Split Batch:在模型使用 classifier free guidance(CFG)时可以开启,并行度恒定为 2。
xDiT 中这四种并行方式可以以任何形式排列组合,混合策略在通信量和通信拓扑两方面可以取得最佳效果,使 xDiT 可达到近似线性的扩展性。
第二,Parallel VAE:针对扩散模型后处理的解码 VAE 模块在高分辨率图像生成时 OOM 问题,xDiT 实现了 Patch Parallel 版本的 VAE。
第三,简单灵活的开发接口:xDiT 轻松帮助用户支持新的 DiT 模型。它尽可能复用开源社区 diffuser,完成对模型推理的并行化改造。这套接口也是让 xDiT 从之前的不同并行方法集合变成一个灵活开发工具。通过 xDiT 接口,实测 15 分钟可以让 Pixart-Sigma 模型并行起来!
1.5 xDit 并行方法
- 多GPU并行方法:
- PipeFusion:利用扩散模型的特性,通过位移Patch实现Patch级流水线并行
- Unified Sequence Parallel (USP):将 DeepSpeed-Ulysses 和 Ring-Attention 结合起来的并行方法
- Hybrid Parallel:混合并行,可配置多种并行方法以优化通信模式,适应底层网络硬件
- CFG Parallel:也称为 Split Batch,在使用无分类器指导(CFG)时激活,保持恒定的并行度为2
- Parallel VAE:并行 VAE
- 单GPU加速方法:
- 编译加速:xDiT 提供了 GPU 内核加速,以提高推理速度,torch.compile 和 onediff 这些编译加速与并行化方法结合使用
- torch.compile:PyTorch 的一个特性,利用自动图优化、操作融合以及内核选择等技术,加速PyTorch模型的执行效率。在 xDiT 中,torch.compile 用于提升 GPU 上的单 GPU 性能
- onediff:用于 PyTorch 的自动微分编译器,通过优化编译过程来提高模型的执行速度和效率
- DiTFastAttn:xDiT还为单GPU加速提供了DiTFastAttn,它可以通过利用扩散模型不同步骤之间的冗余来降低注意力层的计算成本
- 编译加速:xDiT 提供了 GPU 内核加速,以提高推理速度,torch.compile 和 onediff 这些编译加速与并行化方法结合使用
xDiT 支持多种 Diffusion Transformers 模型,包括 CogVideoX、Flux.1、HunyuanDiT、SD3、Pixart、Latte 等。
针对以上不同的模型,xDiT目前实现了不同的方式对它进行并行推理加速。例如,Latte是文生视频模型,xDiT目前实现了USP方式对它进行并行推理加速,PipeFusion还在开发中,使用混合序列并行(ulysses_degree=2, ring_degree=4)可以获得最佳性能。CogVideo 是一个文本到视频的模型,xDiT 目前整合了 USP 技术(包括 Ulysses 注意力和 Ring 注意力)和 CFG 并行来提高推理速度,同时 PipeFusion 的工作正在开发中。
在使用不同GPU数目时,最佳的并行方案都是不同的,这说明了多种并行和混合并行的重要性。 例如,HunyuanDiT最佳的并行策略在不同GPU规模时分别是:在2个GPU上,使用ulysses_degree=2;在4个GPU上,使用cfg_parallel=2, ulysses_degree=2;在8个GPU上,使用cfg_parallel=2, pipefusion_parallel=4。torch.compile带来的加速效果也很可观。stable-diffusion-3最佳的并行策略在不同GPU规模时分别是:在2个GPU上,使用cfg_parallel=2;在4个GPU上,使用cfg_parallel=2, pipefusion_parallel=2;在8个GPU上,使用cfg_parallel=2, pipefusion_parallel=4。
2. Parallel Methods
2.1 PipeFusion
流水线并行
流水线并行属于层间并行,对模型不同的 Transformer 层间进行分割,模型被分割并分布不同的设备上,每一个设备只保存模型的一部分参数,在各层之间进行并行计算。
将模型的不同层放置到不同的计算设备,可以降低单个计算设备的显存消耗。 如下图所示,例如把模型分成四个模型层,分别放置到四个不同的计算设备,即第 1 层放置到GPU 0,第 2 层放置到GPU 1,第 3 层放置到GPU 2,第 4 层放置到GPU 3。
前向计算过程中,输入数据首先在GPU 0 上通过第 1 层的计算得到中间结果,并将中间结果传输到GPU 1,然后在GPU 1 上计算得到第 2 层的输出,并将模型第 2 层的输出结果传输到GPU 2,在GPU 2 上计算得到第 3 层的输出,将模型第 3 层的输出结果传输到GPU 3,GPU 3 经由最后一层的计算最后得到前向计算结果。反向传播过程类似。
以上是训练过程,推理过程中没有反向传播,只有前向传播,最后一层的输出结果即为整个模型的推理输出。
流水线并行训练的一个明显缺点是训练设备容易出现空闲状态(因为后一个阶段需要等待前一个阶段执行完毕),导致计算资源的浪费。
PipeFusion
论文地址:https://arxiv.org/pdf/2405.14430
PipeFusion 将输入图像分割成 M 个不重叠的 Patch,并将 DiT 网络均匀分成 N 个阶段,并按顺序分配到 N 个计算设备上。
例如,下图表示在 N = 4 和 M = 4 的情况(将输入图像分割成了4个不重叠的Patch,P0、P1、P2、P3,并且将 DiT 网络均匀分成 4 个阶段,step0、1、2、3),反向扩散过程(推理)的流水线工作流程,横轴分为两个时间步:时间步 T(深灰色)和时间步 T+1(浅灰色),回顾SD的原理可知反向扩散从T+1步开始,纵轴 device 0 到 device 3 表示并行计算的 GPU。
流水线处理的概念是在一个设备完成当前任务并将结果传递给下一个设备时,它可以同时开始处理下一个Patch。例如,GPU0先处理patch P0,完成后将其传递给GPU1,自己开始处理patch P1,此时GPU1继续处理从GPU0接收的Patch0,以此类推。
不同Patch(P0、P1、P2、P3)的计算在不同设备上同时进行,在每个时间步内,不同设备处理不同的任务,整个计算过程是并行化的。
时间冗余与流水线优化:在并行计算的流水线架构中,存在一个流水线初始化阶段,即当开始计算时,设备之间需要传递数据以保持同步。这个过程会引入一些“空隙”或“延迟”,因为每个设备都需要等待上一个设备完成部分计算并传递给它。
为了克服这一延迟,采用了输入时间冗余的策略,即每个设备使用前一个时间步的旧输出作为计算的上下文,避免了当前时间步的完全等待。这使得流水线能够尽早开始,而不必等待所有输出完全计算完成。
流水线的有效计算比例:
-
M:输入图像被划分为的不重叠patches数量
-
N:DiT 网络被划分为的阶段(stages)数量
-
S:扩散时间步的数量(扩散模型通常需要多步迭代来逐渐去噪)
-
M⋅S:表示总的计算量,M⋅S 表示所有Patch在所有时间步中需要计算次数
-
N-1:流水线前面几步需要填补流水线中的空隙
假设 M = N = 4,扩散时间步 S = 50,带入公式:
说明流水线的有效计算比例为98.5%,当扩散时间步S增加时,这个空隙占总计算时间的比例变小,整体效率提升。
2.2 USP: Unified Sequence Parallelism
序列并行
注意力机制的基本工作原理:Q 与 K 的转置进行点积计算,生成一个大小为 S×S 的注意力矩阵,注意力矩阵经过softmax操作后与 V 相乘,生成最终输出,这整个过程的内存复杂度是与序列长度呈二次方关系
近年来,随着生成性AI模型上下文长度的不断增长,例如,Claude在大型语言模型(LLMs)中将序列长度扩展到了100K个token,而OpenAI的GPT-4则将上下文长度扩展到了128K个token,为了能够训练超长上下文,通常需要使用一些复杂的并行策略。
序列并行(Sequence Parallelism,SP)是一种将输入张量的序列维度分到多个计算设备上的技术,将输入序列分割成多个子序列,并将每个子序列输入到其相应的设备(即 GPU)中。在序列并行中,模型通常不会被直接切分,而是每个设备都运行同一模型的一个完整副本,但输入数据是不同的序列片段。
上图展示了序列并行(Sequence Parallelism)的工作原理。在这种并行方式中,微批次(Micro Batches)被沿着序列维度均匀地分布到不同设备上,每个设备负责处理不同的序列片段,这样能够减少单个设备的内存开销,同时加速长序列的处理。每层的输出在设备之间同步传递。
上图展示了如何将序列数据划分到不同的设备上,并在每个设备上并行地执行注意力机制:
- 输入序列被拆分成不同的部分(Qi, Ki, Vi等),Q、K、V分别代表查询(Query)、键(Key)、值(Value),序列中的每个元素的Q、K、V向量被分配到不同的设备上,分别在不同设备上进行计算
- 图中有多个设备(Device i-1, i, i+1, i+2等),每个设备负责一部分序列的计算
- 注意力计算:每个设备需要访问其他序列元素的K和V来计算注意力,因此,每个设备都会与相邻设备交换KV信息
- 每个设备计算其对应序列元素的注意力输出
通过将序列数据分割并行处理,多个设备可以同时处理不同部分的数据,提高模型的性能和效率。
综上,序列并行的关键在于它能够有效地分配和管理大规模输入序列的计算任务,让模型能够处理更长的序列,不会受到单个设备内存限制的束缚。序列并行已成为训练和推理更长序列的一种有效方法。DeepSpeed-Ulysses和Ring-Attention标志着序列并行技术的成熟,这是目前两种最重要的序列并行方法。
DeepSpeed-Ulysses
论文地址:https://arxiv.org/pdf/2309.14509
DeepSpeed-Ulysses 是微软提出的简单、易用且高效用于支持具有极长序列长度的高效可扩展LLM训练的方法
多头注意力机制计算过程:
-
输入向量 X:输入序列,每个元素的维度为 [N, d],其中 N 是序列长度,d 是隐藏层的维度
-
线性变换:输入 X 通过三个不同的线性变换(使用权重矩阵$ W_Q, W_K, W_V$)生成查询(Q)、键(K)和值(V),这三个矩阵的维度都是 [N, d]。
-
分头处理:生成的 Q, K, V 各自被分成多个“头”(这里的头数 $h_c$ = 4),每个头处理输入数据的一个子集,这可以帮助模型从多个子空间中并行学习信息,每个头的维度是 [N, $h_s$],其中 hs 是每个头的大小,等于 $d/h_c$
-
点积注意力:
-
计算点积:对于每个头,计算查询 Q 和键 K’(K 的转置)的点积
-
缩放点积:通常,这个点积还会除以一个缩放因子( $h_s$ 的平方根)
-
应用 softmax 函数:对缩放后的点积结果应用 softmax 函数,生成一个注意力权重矩阵 S
-
-
生成输出:
-
加权和:使用注意力权重矩阵 S 对值 V 进行加权求和,得到每个头的输出 $P_h$
- 连接:将所有头的输出连接起来,形成维度为 [N, d] 的矩阵 P
- 最终线性变换:P 通过另一个权重矩阵$W_0$ 进行线性变换,得到最终的多头注意力输出
-
DeepSpeed- Ulysses 计算过程:
-
输入向量 X:输入数据 X 的维度为 [N, d],其中 N 是输入序列长度,d 是隐藏层的维度,输入数据被分割成 P 个部分,每部分的维度为 [N/P, d],分别送到 P 个处理器(GPU),这里假设 P = hc = 4,即 GPU 数量等于注意力头数
-
线性变换:每个GPU上的数据通过三个不同的[d, d]维度的线性变换矩阵$W_Q, W_K, W_V$生成局部的查询(Q)、键(K)和值(V)
-
all-to-all 通信:
-
发送:每个GPU需要将其计算得到的局部 Q、$K^T$和 V 发送到其他所有GPU,每个GPU都有完整的Q、$K^T$和 V
-
接收:每个GPU将接收其他GPU发送过来的Q、$K^T$和 V 分片,组合成全局 Q、$K^T$和 V
-
-
计算注意力:
-
计算点积:每个GPU并行计算不同的注意力头的 Q 和 $K^T$的点积
-
缩放点积:除以一个缩放因子( hs 的平方根)
-
应用 softmax 函数:对缩放后的点积结果应用 softmax 函数,生成局部的注意力权重矩阵 S
-
-
生成输出:
-
加权和:使用局部的注意力权重矩阵 S 对局部的值 V 进行加权求和,得到每个头的输出 $P_h$
-
all-to-all 通信:再次使用 all-to-all 通信,将所有处理器上的$P_h$聚集起来,形成完整的输出 P
-
最终线性变换:聚集后的输出 P 通过权重矩阵 $W_0$ 进行线性变换,得到最终的多头注意力输出
-
对于每一个GPU送入处理的序列长度只是 N/P,这样大大减少了单个GPU的运算量,但同时,因为 All-To-All 通信方式的存在,做 Attention 时还是考虑了整体的长度。DeepSpeed-Ulysses 方法的特点是当序列长度和计算设备成比例增加时,通信量保持不变。
局限性:每个GPU并行计算不同的注意力头,Ulysses的SP并行度受限于注意力头数量(并行度不能超过头的数量)
Ring-Attention
GitHub:https://github.com/gpu-mode/ring-attention
💍Ring Attention,顾名思义,环形注意力,是另一种序列并行方法。
- 数据块的计算顺序可以任意:在环形结构中,各个数据块(如Q、K、V)的计算顺序可以灵活安排,不必严格按照序列顺序
- 将QKV序列分割到N个主机上:序列中的查询(Q)、键(K)和值(V)被分割并分配到多个处理单元(图中例为四个GPU)
- 构成环状结构交换KV数据:每个GPU不仅计算分配到的数据块,各个GPU形成一个环形网络,其中每个GPU会向相邻的GPU发送和接收K和V数据。这样,每个GPU都能够获取到整个序列的K和V信息来计算注意力(所需的K和V块不在本GPU,则使用点对点(P2P)通信从其他GPU获取)
- 单次循环完成所有KV部分的交换:通过环形网络的一次循环,每个GPU都能看到并处理所有的K和V数据
- 长序列处理的零开销:同时进行计算和通信,在处理较长序列时,没有额外开销,这意味着随着序列长度的增加,系统的效率不会下降
Ring Attention Algorithm 详细算法步骤:
-
输入和初始化:
-
输入序列 x
-
主机数 Nh:指的是参与计算的设备数量(GPU),这些设备将形成一个环状结构来交换数据
-
-
分割输入序列:输入序列被分割成 Nh 个块,每个块由一个GPU处理
-
计算查询、键、值(Q, K, V):每个GPU计算其对应块的查询(Q)、键(K)和值(V)
-
循环处理每个Transformer层:
对于Transformer中的每一层,执行以下步骤:-
发送和接收K和V:每个GPU将其K和V块发送给下一个GPU,并从上一个GPU接收K和V块,这个过程在所有GPU中并发执行
-
计算注意力:每个GPU使用本地的Q、K、V块来计算注意力
-
-
计算前馈网络:使用注意力输出,每个GPU并发地计算feedforword前馈网络
-
结束:经过所有Transformer层的处理后,算法结束
USP
论文地址:https://arxiv.org/pdf/2405.07719
GitHub:https://github.com/feifeibear/long-context-attention?tab=readme-ov-file
USP(Unified Sequence Parallel) 结合了DeepSpeed-Ulysses和Ring-Attention两种序列并行方法,提出了“2D混合序列并行”的概念,这种混合并行方式不仅继承了两种方法各自的优点,还克服了它们各自的局限性。通过在序列维度上的混合并行策略,允许在不同的计算设备上分别运行DeepSpeed-Ulysses和Ring-Attention,从而优化了通信模式和计算效率。对比主流的DeepSpeed-Ulysses和Ring-Attention序列并行方式,USP在DiT场景下生图的性能提升到24%,在LLAMA2-7B场景的性能提升2倍以上。
USP方法核心思想是将计算过程分为两个正交的过程组:SP-Ring过程组和SP-Ulysses过程组,形成二维网格结构,它的算法流程包括:
1) 使用DeepSpeed-Ulysses方法的all-to-all交换query、key、value;
2) 各组内进行独立的Ring-Attention计算;
3) all-to-all通信使各组具有完整的Head Output结果
算法流程如下:
对image的切分逻辑如下:
- 即在外层维度切分ring blocks,在ring block的内部切分ulysses block,每个device持有对应的ulysses block
- 在进行计算时,同一个ulysses group所属的devices先进行all2all,相当于沿hidden dim方向切分image hidden state,保证每个device上有attention中一部分heads所需要的,属于目前ulysses group的所有kv
- 随后,属于同一ring group的devices进行ring attention,通过online softmax的特性完成对整个attention结果的计算
- 值得注意的是,ulysses的通信中需要对query,key,value三者均进行all2all,而ring attention中仅有kv需要进行p2p通信。并且在ring attention中,kv通信和attn技术是可重叠的。因此在低带宽连接的设备中,采用较高的ring degree与较低的ulysses degree通常能取得更好的加速效果
将上述算法展示成图的方式如下图所示:
-
输入数据分布:输入数据按序列维度拆分到4个GPU上,每个GPU上都有部分的序列,并且每个GPU上处理的注意力头(head)也是不一样的。比如,GPU0和GPU2负责head0,GPU1和GPU3负责head1(相当于GPU0和GPU1之间进行Ulysess并行,GPU0和GPU2之间进行Ring Attention并行)
-
AlltoAll q k v:在进行多头自注意力计算之前,需要对query (q)、key (k)、value (v) 进行AlltoAll操作。在这个过程中,GPU0和GPU1之间会交换部分序列数据,GPU2和GPU3之间也会交换。这样,GPU0和GPU1都拥有相同的部分序列,但它们分别处理不同的head,GPU2和GPU3同理
-
Ring Attention计算:这一阶段的目的是在多个GPU之间计算Attention分数。Ring Attention是通过在GPU0和GPU2之间、GPU1和GPU3之间的通信完成的;从图中可以看出,GPU0和GPU2具有所有数据(GPU0有一半数据,GPU2有另一半数据),相当于它们共享的序列数据是完整的,但是只有head0数据,所以它们会一起计算head0的Attention结果,同样,GPU1和GPU3有完整token数据,但是只有head1的数据,故GPU1和GPU3一起计算head1的attention结果
-
AlltoAll输出:在完成Ring Attention计算之后,再次执行AlltoAll操作,将不同GPU上的head结果进行交换。GPU0和GPU1交换它们的计算结果,GPU0上就有了完整序列的head0和head1的结果,类似地,GPU2和GPU3也完成结果交换,每张卡上都有完整的head和序列的计算结果。例如,序列0-7拆分成是0-3和4-7,那么GPU0有0-3且head为0的计算结果,GPU1有0-3且head为1的计算结果,GPU2有4-7且head为0的计算结果,GPU3有4-7且head为1的计算结果,此时GPU0和GPU1的结果进行AlltoAll,GPU2和GPU3的结果进行AlltoAll,所以卡均获得完整head的数据,且序列完全拆分
2.3 Parallel VAE
项目代码:https://github.com/xdit-project/DistVAE
在高分辨率图片生成中,最后用于从潜在空间到像素空间转换的VAE开销很大,是整个生成过程的显存尖峰(Memory Spike),在主干网络生成结束后,添加了并行VAE的实现以支持高分辨率图片解码,提出DistVAE,避免OOM,并行高效地处理高分辨率图像
DistVAE 结合了两种关键策略:
-
序列并行:将潜在空间中的特征映射划分为多个Patch,并在不同设备上执行序列并行VAE解码,将中间激活所需的峰值内存减少到1/N,其中N是所使用的设备数量
-
Chunked Input Processing,分块输入处理:与MIT patch conv类似,将输入特征图分割成块,并将其顺序输入卷积算子,这种方法最大限度地减少了临时内存消耗
通过协同这两种方法,极大地扩展了VAE解码的能力。成功处理了高达 10240px 的图像分辨率,与默认 VAE 实现相比,提升了 11 倍。