Stable Diffusion 3 论文
Stable Diffusion 3 论文
文章目录
- Stable Diffusion 3 论文
- 摘要
- Abstract
- 一、论文阅读
- 1. 核心贡献
- 2. 整流模型简介
- 3. 非均匀训练噪声采样
- 4. 网络整体架构
- 5. 提升自编码器通道数
- 6. 多模态 DiT (MM-DiT)
- 7. 比例可变的位置编码
- 8. 训练数据预处理
- 9. 用 QK 归一化提升训练稳定度
- 10. 哪种扩散模型训练目标最适合文生图任务?
- 11. 参数扩增实验结果
- 总结
摘要
本周主要阅读了Stable Diffusion 3(SD3)的核心论文《Scaling Rectified Flow Transformers for High-Resolution Image Synthesis》。文章提出了一种基于整流模型(Rectified Flow)和Diffusion Transformer(DiT)的新型生成模型,并进行了大规模实验验证其性能。在网络架构上,SD3 引入了多模态 DiT(MM-DiT),并通过提升自编码器通道数、引入比例可变的位置编码等技术优化生成效果。实验结果表明,整流模型的生成质量在多个数据集上优于传统扩散模型,并且模型性能随着参数量扩增呈现出一致的增长趋势。
Abstract
This week focused on the core paper of Stable Diffusion 3 (SD3), titled Scaling Rectified Flow Transformers for High-Resolution Image Synthesis. The paper introduces a novel generative model based on rectified flow and Diffusion Transformer (DiT) and conducts large-scale experiments to validate its performance. In terms of architecture, SD3 incorporates a Multimodal DiT (MM-DiT) and optimizes image generation through enhancements such as increasing the latent space channel count and introducing scalable positional encoding. Experimental results demonstrate that rectified flow outperforms traditional diffusion models across several datasets, with model performance consistently improving as parameter scale increases.
一、论文阅读
1. 核心贡献
Stable Diffusion 3 (SD3) 的文章标题为 Scaling Rectified Flow Transformers for High-Resolution Image Synthesis。正如其标题所示,这篇文章的内容很简明,就是用整流 (rectified flow) 生成模型、Transformer 神经网络做了模型参数扩增实验,以实现高质量文生图大模型。文章的核心贡献如下:
从方法设计上:
- 首次在大型文生图模型上使用了整流模型。
- 用一种新颖的 Diffusion Transformer (DiT) 神经网络来更好地融合文本信息。
- 使用了各种小设计来提升模型的能力。如使用二维位置编码来实现任意分辨率的图像生成。
从实验上:
- 开展了一场大规模、系统性的实验,以验证哪种扩散模型/整流模型的学习目标最优。
- 开展了扩增模型参数的实验 (scaling study),以证明提升参数量能提升模型的效果。
2. 整流模型简介
由于 SD3 最后用了整流模型来建模图像生成,所以文章是从一种称为流匹配 (Flow Matching) 的角度而非更常见的扩散模型的角度来介绍各种训练目标。鉴于 SD3 并没有对其他论文中提出的整流模型做太多更改,我们在阅读本文时可以主要关注整流的想法及其与扩散模型的关系,后续再从其他论文中学习整流的具体原理。在此,我们来大致认识一下流匹配与整流的想法。
所谓图像生成,其实就是让神经网络模型学习一个图像数据集所表示的分布,之后从分布里随机采样。比如我们想让模型生成人脸图像,就是要让模型学习一个人脸图像集的分布。为了直观理解,我们可以用二维点来表示一张图像的数据。比如在下图中我们希望学习红点表示的分布,即我们希望随机生成点,生成的点都落在红点处,而不是落在灰点处。
我们很难表示出一个适合采样的复杂分布。因此,我们会把学习一个分布的问题转换成学习一个简单好采样的分布到复杂分布的映射。一般这个简单分布都是标准正态分布。如下图所示,我们可以用简单的算法采样在原点附近的来自标准正态分布的蓝点,我们要想办法得到蓝点到红点的映射方法。
学习这种映射依然是很困难的。而近年来包括扩散模型在内的几类生成模型用一种巧妙的方法来学习这种映射:从纯噪声(标准正态分布里的数据)到真实数据的映射很难表示,但从真实数据到纯噪声的逆映射很容易表示。所以,我们先人工定义从图像数据集到噪声的变换路线(红线),再让模型学习逆路线(蓝线)。让噪声数据沿着逆路线走,就实现了图像生成。
我们又可以用一种巧妙的方法间接学习图像生成路线。知道了预定义的数据到噪声的路线后,我们其实就知道了数据在路线上每一位置的速度(红箭头)。那么,我们可以以每一位置的反向速度(蓝箭头)为真值,学习噪声到真实数据的速度场。这样的学习目标被称为流匹配。
对于不同的扩散模型及流匹配模型,其本质区别在于图像到噪声的路线的定义方式。在扩散模型中,图像到噪声的路线是由一个复杂的公式表示的。而整流模型将图像到噪声的路线定义为了直线。比如根据论文的介绍,整流中 t t t时刻数据 z t {z_t} zt由真实图像 x 0 {x_0} x0变换成纯噪声 ϵ \epsilon ϵ的位置为:
z t = ( 1 − t ) x 0 + t ϵ {z_t} = \left( {1 - t} \right){x_0} + t\epsilon zt=(1−t)x0+tϵ
而较先进的扩散模型EDM(Elucidating the Design Space of Diffusion-Based Generative Models)提出的路线公式为( b t {b_t} bt是一个形式较为复杂的变量):
z t = x 0 + b t ϵ {z_t} = {x_0} + {b_t}\epsilon zt=x0+btϵ
由于整流最后学习出来的生成路线近乎是直线,这种模型在设计上就支持少步数生成。
虽然整流模型是这样宣传的,但实际上 SD3 还是默认用了 28 步来生成图像。单看这篇文章,原整流论文里的很多设计并没有用上。对整流感兴趣的话,可以去阅读原论文 Flow straight and fast: Learning to generate and transfer data with rectified flow
流匹配模型和扩散模型的另一个区别是,流匹配模型天然支持 image2image 任务。从纯噪声中生成图像只是流匹配模型的一个特例。
3. 非均匀训练噪声采样
在学习这样一种生成模型时,会先随机采样一个时刻 t ∈ [ 0 , 1 ] t \in \left[ {0,1} \right] t∈[0,1],根据公式获取此时刻对应位置在生成路线上的速度,再让神经网络学习这个速度。直观上看,刚开始和快到终点的路线很好学,而路线的中间处比较难学。因此,在采样时刻 t t t时,SD3 使用了一种非均匀采样分布。
如下图所示,SD3 主要考虑了两种公式: mode(左)和 logit-norm (右)。二者的共同点是中间多,两边少。mode 相比 logit-norm,在开始和结束时概率不会过分接近 0。
4. 网络整体架构
以上内容都是和训练相关的理论基础,下面我们来看比较熟悉的文生图架构。
从整体架构上来看,和之前的 SD 一样,SD3 主要基于隐扩散模型(latent diffusion model, LDM)。这套方法是一个两阶段的生成方法:先用一个 LDM 生成隐空间低分辨率的图像,再用一个自编码器把图像解码回真实图像。
扩散模型 LDM 会使用一个神经网络模型来对噪声图像去噪。为了实现文生图,该去噪网络会以输入文本为额外约束。相比之前多数扩散模型,SD3 的主要改进是把去噪模型的结构从 U-Net 变为了 DiT。
DiT 的论文为 Scalable Diffusion Models with Transformers。如果只是对 DiT 的结构感兴趣的话,可以去直接通过读 SD3 的源码来学习。读 DiT 论文时只需要着重学习 AdaLayerNormZero 模块。
5. 提升自编码器通道数
在当时设计整套自编码器 + LDM 的生成架构时,SD 的开发者并没有仔细改进自编码器,用了一个能把图像下采样 8 倍,通道数变为 4 的隐空间图像。比如输入 512×512×3 的图像会被自编码器编码成 64×64×4。而近期有些工作发现,这个自编码器不够好,提升隐空间的通道数能够提升自编码器的重建效果。因此,SD3 把隐空间图像的通道数从 4 改为了 16。
6. 多模态 DiT (MM-DiT)
SD3 的去噪模型是一个 Diffusion Transformer (DiT)。如果去噪模型只有带噪图像这一种输入的话,DiT 则会是一个结构非常简单的模型,和标准 ViT 一样:图像过图块化层 (Patching) 并与位置编码相加,得到序列化的数据。这些数据会像标准 Transformer 一样,经过若干个子模块,再过反图块层得到模型输出。DiT 的每个子模块 DiT-Block 和标准 Transformer 块一样,由 LayerNorm, Self-Attention, 一对一线性层 (Pointwise Feedforward, FF) 等模块构成。
图块化层会把 2×2 个像素打包成图块,反图块化层则会把图块还原回像素。
然而,扩散模型中的去噪网络一定得支持带约束生成。这是因为扩散模型约束于去噪时刻 t t t。此外,作为文生图模型,SD3 还得支持文本约束。DiT 及本文的 MM-DiT 把模型设计的重点都放在了处理额外约束上。
我们先看一下模块是怎么处理较简单的时刻约束的。此处,如下图所示,SD3 的模块保留了 DiT 的设计,用自适应 LayerNorm (Adaptive LayerNorm, AdaLN) 来引入额外约束。具体来说,过了 LayerNorm 后,数据的均值、方差会根据时刻约束做调整。另外,过完 Attention 层或 FF 层后,数据也会乘上一个和约束相关的系数。
我们再来看文本约束的处理。文本约束以两种方式输入进模型:与时刻编码拼接、在注意力层中融合。具体数据关联细节可参见下图。如图所示,为了提高 SD3 的文本理解能力,描述文本 (“Caption”) 经由三种编码器编码,得到两组数据。一组较短的数据会经由 MLP 与文本编码加到一起;另一组数据会经过线性层,输入进 Transformer 的主模块中。
将约束编码与时刻编码相加是一种很常见的做法。此前 U-Net 去噪网络中处理简单约束(如 ImageNet 类型约束)就是用这种方法。
SD3 的 DiT 的子模块结构图如下所示。我们可以分几部分来看它。先看时刻编码 y y y 的那些分支。和标准 DiT 子模块一样, y y y 通过修改 LayerNorm 后数据的均值、方差及部分层后的数据大小来实现约束。再看输入的图像编码 x x x 和文本编码 c c c。二者以相同的方式做了 DiT 里的 LayerNorm, FF 等操作。不过,相比此前多数基于 DiT 的模型,此模块用了一种特殊的融合注意力层。具体来说,在过注意力层之前, x x x 和 c c c 对应的 Q , K , V Q,K,V Q,K,V 会分别拼接到一起,而不是像之前的模型一样, Q Q Q 来自图像, K , V K,V K,V 来自文本。过完注意力层,输出的数据会再次拆开,回到原本的独立分支里。由于 Transformer 同时处理了文本、图像的多模态信息,所以作者将模型取名为 MM-DiT (Multimodal DiT)。
论文里讲:「这个结构可以等价于两个模态各有一个 Transformer,但是在注意力操作时做了拼接,使得两种表示既可以在独自的空间里工作也可以考虑到另一个表示。」然而,仅从数据来源来看,过了一个注意力层后,图像信息和文本信息就混在了一起。你很难说,也很难测量,之后的 x x x 主要是图像信息, c c c 主要是文本信息。只能说 x , c x,c x,c 都蕴含了多模态的信息。之前 SD U-Net 里的 x , c x,c x,c 可以认为是分别包含了图像信息和文本信息,因为之前的 x x x 保留了二维图像结构,而 c c c 仅由文本信息决定。
7. 比例可变的位置编码
此前多数方法在使用类 ViT 架构时,都会把图像的图块从左上到右下编号,把二维图块拆成一维序列,再用这种一维位置编码来对待图块。
这样做有一个很大的坏处:生成的图像的分辨率是无法修改的。比如对于上图,假如采样时输入大小不是 4×4,而是 4×5,那么 0 号图块的下面就是 5 而不是 4 了,模型训练时学习到的图块之间的位置关系全部乱套。
解决此问题的方法很简单,只需要将一维的编码改为二维编码。这样 Transformer 就不会搞混二维图块间的关系了。
SD3 的 MM-DiT 一开始是在 25 6 2 {256^2} 2562 固定分辨率上训练的。之后在高分辨率图像上训练时,开发者用了一些巧妙的位置编码设置技巧,让不同比例的高分辨率图像也能共享之前学到的这套位置编码。详细公式请参见原论文。
8. 训练数据预处理
看完了模块设计,我们再来看一下 SD3 在训练中的一些额外设计。在大规模训练前,开发者用三个方式过滤了数据:
- 用了一个 NSFW 过滤器过滤图片,似乎主要是为了过滤色情内容。
- 用美学打分器过滤了美学分数太低的图片。
- 移除了看上去语义差不多的图片。
虽然开发者们自信满满地向大家介绍了这些数据过滤技术,但根据社区用户们的反馈,可能正是因为色情过滤器过分严格,导致 SD3 经常会生成奇怪的人体。
由于在训练 LDM 时,自编码器和文本编码器是不变的,因此可以提前处理好所有训练数据的图像编码和文本编码,这是一项非常基础的工程技巧。
9. 用 QK 归一化提升训练稳定度
按照之前高分辨率文生图模型的训练方法,SD3 会先在 25 6 2 {256^2} 2562 的图片上训练,再在高分辨率图片上微调。然而,开发者发现,开始微调后,混合精度训练常常会训崩。根据之前工作的经验,这是由于注意力输入的熵会不受控制地增长。解决方法也很简单,只要在做注意力计算之前对 Q , K Q, K Q,K 做一次归一化就行,具体做计算的位置可以参考上文模块图中的 “RMSNorm”。不过,开发者也承认,这个技巧并不是一个长久之策,得具体问题具体分析。看来这种 DiT 模型在大规模训练时还是会碰到许多训练不稳定的问题,且这些问题没有一个通用解。
10. 哪种扩散模型训练目标最适合文生图任务?
最后我们来看论文的实验结果部分。首先,为了寻找最好的扩散模型/流匹配模型,开发者开展了一场声势浩大的实验。实验涉及 61 种训练公式,其中的可变项有:
- 对于普通扩散模型,考虑 ϵ − ϵ- ϵ− 或 v − p r e d i c t i o n v-prediction v−prediction,考虑线性或 cosine 噪声调度。
- 对于整流,考虑不同的噪声调度。
- 对于 EDM,考虑不同的噪声调度,且尽可能与整流的调度机制相近以保证可比较。
在训练时,除了训练目标公式可变外,优化算法、模型架构、数据集、采样器都不可变。所有模型在 ImageNet 和 CC12M 数据集上训练,在 COCO-2014 验证集上评估 FID 和 CLIP Score。根据评估结果,可以选出每个模型的最优停止训练的步数。基于每种目标下的最优模型,开发者对模型进行最后的排名。由于在最终评估时,仍有采样步数、是否使用 EMA 模型等可变采样配置,开发者在所有 24 种采样配置下评估了所有模型,并用一种算法来综合所有采样配置的结果,得到一个所有模型的最终排名。最终的排名结果如下面的表 1 所示。训练集上的一些指标如表 2 所示。
根据实验结果,我们可以得到一些直观的结论:整流领先于扩散模型。但惊人的是,较新推出的 EDM 竟然没有战胜早期的 LDM (“eps/linear”)。
一般来说,大家做图像生成会用一个统一的指标,比如 ImageNet 上的 FID。这篇论文相当于是新提出了一种昂贵的评价方法。这种评价方法是否合理,是否能得到公认还犹未可知。另外,想说明一个生成模型的拟合能力不错,用 ImageNet 上的 FID 指标就足够有说服力了,大家不会对一个简单的生成模型有太多要求。然而,对于大型文生图模型,大家更关心的是模型的生成效果,而 FID 和 CLIP Score 并不能直接反映文生图模型的质量。因此,光凭这份实验结果,我们并不能说整流一定比之前的扩散模型要好。
11. 参数扩增实验结果
现在多数生成模型都会做参数扩增实验,即验证模型表现随参数量增长而增长,确保模型在资源足够的情况下可以被训练成「大模型」。SD3 也做了类似的实验。开发者用参数 d d d 来控制 MM-DiT 的大小,Transformer 块的个数为 d d d,且所有特征的通道数与 d d d 成正比。开发者在 25 6 2 {256^2} 2562 的数据上训练了所有模型 500k 步,每 50k 步在 CoCo 数据集上统计验证误差。最终所有评估指标如下图所示。可以说,所有指标都表明,模型的表现的确随参数量增长而增长。更多结果可以参见论文。
总结
通过本周对Stable Diffusion 3(SD3)论文的学习,对整流模型(Rectified Flow)和Diffusion Transformer(DiT)有了更深入的理解。整流模型用直线方式定义噪声到数据的生成路径,简化了传统扩散模型复杂的路线设计,同时天然支持少步数生成和image2image任务。这种设计思想拓宽了生成模型的技术视野,展现了整流模型在高效生成上的潜力。在网络架构上,SD3 的多模态 DiT(MM-DiT)通过拼接注意力机制实现了文本和图像的深度融合,同时通过比例可变的位置编码处理了不同分辨率图像生成的问题。这些设计不仅在理论上具有创新性,在实践中也展现了较好的性能。尤其是提升自编码器通道数这一点,直观体现了隐空间质量对生成效果的重要性。论文的实验部分让我对生成模型的对比评估有了新的认识。虽然实验结果表明整流模型在多个任务上优于传统扩散模型,但也暴露出评价标准上的局限性,例如FID和CLIP分数无法完全反映文生图模型的实际生成质量。此外,整流模型虽然在理论上支持少步数生成,但SD3 的默认配置仍然用了28步,说明理论优势在实际应用中可能受到限制。最后,论文的参数扩增实验进一步验证了生成模型参数量与性能的正相关性。这一部分让我深刻体会到,大模型的构建不仅需要强大的硬件支持,还需要从模型设计到数据处理的全方位优化。通过这篇论文,我对整流流匹配模型和高分辨率生成的设计思路有了更清晰的理解,也为后续在生成模型领域的研究积累了重要的参考经验。