当前位置: 首页 > article >正文

(ICLR-2025)你只采样一次:通过自协作扩散 GAN 驯服一步文本到图像合成

你只采样一次:通过自协作扩散 GAN 驯服一步文本到图像合成

paper是HKUST发表在ICLR 2025的工作

paper title:You Only Sample Once: Taming One-Step Text-To-Image Synthesis by Self-Cooperative Diffusion GANs

Code:待开源

ABSTRACT

近年来,一些研究尝试结合扩散模型(Diffusion Models, DMs)和生成对抗网络(Generative Adversarial Networks, GANs),以降低扩散模型中迭代去噪推理的计算成本。然而,这一方向的现有方法通常存在训练不稳定、模式崩溃或一步生成的学习效率较低等问题。为了解决这些问题,我们提出 YOSO,这是一种全新的生成模型,专为快速、可扩展且高保真度的一步图像生成设计,同时具有高训练稳定性和模式覆盖能力。具体而言,我们通过去噪生成器自身平滑对抗散度,从而实现自协作学习(self-cooperative learning)。实验结果表明,我们的方法可以作为从零开始训练的一步生成模型,并且能够取得具有竞争力的性能。此外,我们将 YOSO 扩展至基于预训练模型的一步文本到图像生成,并结合了多个高效的训练技术,包括:- 潜在感知损失(latent perceptual loss)和潜在判别器(latent discriminator),用于更高效地训练潜在扩散模型;- 信息性先验初始化(IPI),增强模型初始阶段的稳定性;- 快速适应阶段(quick adaptation stage),用于修正有缺陷的噪声调度器。实验结果表明,YOSO 在一步生成任务上达到了当前最先进的性能,即使采用低秩自适应(LoRA)微调方法,也能保持较好的生成效果。尤其是在 YOSO-PixArt-α 变体上,我们证明该模型能够在 512 分辨率 进行一步训练,并且无需额外显式训练,即可适配到 1024 分辨率。此外,该模型的微调仅需约 10 A800 GPU 天

1 INTRODUCTION

扩散模型(Diffusion Models,DMs)(Sohl-Dickstein et al, 2015; Ho et al, 2020; Song et al, 2021)近年来作为一类强大的生成模型崭露头角,在多个生成建模任务中取得了最先进的结果,例如文本到图像(Rombach et al, 2022; Xu et al, 2023c; Chen et al, 2024; Feng et al, 2023)、文本到视频(Blattmann et al, 2023; Hong et al, 2022)、图像编辑(Hertz et al, 2022; Brooks et al, 2023; Meng et al, 2022)以及受控生成(Zhang et al, 2023; Mou et al, 2023)。然而,DMs 的生成过程依赖于迭代去噪,导致生成速度较慢。此外,大规模 DMs 的高计算需求进一步增加了其实际应用的门槛,限制了其更广泛的采用。

从 DMs 进行采样可以视为求解概率流常微分方程(PF-ODE)(Song et al, 2021)。一些先前的研究(Song et al, 2020; Lu et al, 2022a;b; Bao et al, 2022)致力于开发先进的 ODE 求解器,以减少采样步骤。然而,即使采用这些方法,仍需要 20+ 步 才能实现高质量生成。另一种研究方向是从预训练的 PF-ODE 进行蒸馏(Song et al, 2023; Liu et al, 2023; Luo et al, 2023a;b;c; Salimans & Ho, 2022),其目标是在一步内预测 PF-ODE 求解器的多步解。现有方法(Luo et al, 2023a;b)能够在 4+ 步 内生成合理质量的样本。然而,在一步内生成高质量样本 仍然具有挑战性。

相比之下,生成对抗网络(GANs)(Goodfellow et al, 2014; Radford et al, 2016)天然支持一步生成,并具有快速采样的优势。然而,GANs 在大规模数据集 上的训练面临挑战(Sauer et al, 2022; Kang et al, 2023),导致其生成质量通常不如 DMs(Sauer et al, 2023b; Kang et al, 2023)。在本研究中,我们提出了一种新颖的方法,将扩散过程和 GANs 结合。成功实现这一目标的关键在于平滑对抗散度(adversarial divergence),从而稳定训练,同时保持高效的一步学习能力

先前的研究(Xiao et al, 2022; Xu et al, 2023b;c; Sauer et al, 2023c)已经探索了多种将扩散模型与 GANs 结合的变体。然而,现有方法要么直接对抗真实数据进行对抗散度计算,但缺乏平滑策略,导致训练不稳定和模式崩溃;要么通过添加噪声来平滑对抗散度,从而稳定训练,但牺牲了一步生成的学习效率。为了兼顾两者的优势,我们提出了一种新方法,即通过去噪生成器本身来平滑对抗散度

具体而言,我们将基于较少污染样本的一步去噪生成 视为真实分布(ground truth),而将基于更多污染样本的一步去噪生成 视为学生分布(student distribution),并在二者之间执行对抗散度计算。这一方法不仅能够自然地缩小目标分布与学生分布之间的距离,以稳定训练,还能够在干净样本上形成高效的一步学习机制

这一学习过程可被视为一种自协作学习(self-cooperative learning)(Xie et al, 2018),即生成器通过学习自身输出进行优化。这种创新性的设计实现了稳定训练,并有效提升了一步生成能力。因此,我们将该模型命名为 YOSO(“You Only Sample Once”),即**“你只采样一次”**。

此外,我们将 YOSO 扩展至基于预训练模型的一步文本到图像生成,并引入了多种高效的训练技术,包括:

  • 潜在感知损失(latent perceptual loss)潜在判别器(latent discriminator),以提高潜在扩散模型(latent DMs)的训练效率;
  • 信息性先验初始化(IPI, Informative Prior Initialization),用于增强模型初始化阶段的稳定性;
  • 快速适应阶段(quick adaptation stage),用于修正噪声调度器的缺陷。

得益于这些高效的设计,我们能够快速且高效地微调现有的预训练文本到图像扩散模型(例如 Stable Diffusion(Rombach et al., 2022)和 PixArt-α(Chen et al., 2024)),以实现高质量的一步生成(详见图 1)。此外,我们率先在一步文本到图像生成任务中支持低秩适应(LoRA, Low-Rank Adaptation)(Hu et al., 2022)微调,以提升训练效率,并最终达到了当前最先进的生成性能

我们的工作提出了多个重要贡献:

• 我们引入了 YOSO,一种新颖的生成模型,可以通过一步推理 生成高质量图像,同时具有稳定的训练过程良好的模式覆盖

• 我们进一步扩展了 YOSO,结合多种系统化且高效的训练技术,使其能够在低资源条件下微调预训练的文本到图像扩散模型(DMs),实现一步文本到图像生成,微调仅需 约 10 A800 计算天

• 我们进行了大规模实验,验证了 YOSO 的有效性,包括从零开始的图像生成文本到图像微调,以及与现有的图像个性化和可控生成模块的兼容性

图1

图 1:YOSO 在不同配置下的一步生成图像(底部)。该模型通过我们的算法在 512 分辨率 上微调 PixArt-α(Chen et al., 2024)进行训练。左下角的图像由 YOSO 生成,其适配至 1024 分辨率,基于 公式 (7),且未进行额外的显式训练

2 BACKGROUND

扩散模型(Diffusion models)
扩散模型(Diffusion Models, DMs)(Sohl-Dickstein et al., 2015; Ho et al., 2020)定义了一个前向过程,该过程通过在 T T T 个步骤中添加噪声 β t \beta_t βt,逐步将样本从数据分布转换为高斯分布:
q ( x t ∣ x t − 1 ) ≜ N ( x t ; 1 − β t x t − 1 , β t I ) . q(x_t | x_{t-1}) \triangleq \mathcal{N}(x_t; \sqrt{1 - \beta_t} x_{t-1}, \beta_t I). q(xtxt1)N(xt;1βt xt1,βtI).
受噪声污染的样本可以直接通过以下公式获得:
x t = α ˉ t x 0 + 1 − α ˉ t ϵ , x_t = \bar{\alpha}_t x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon, xt=αˉtx0+1αˉt ϵ,
其中:
α ˉ t = ∏ s = 1 t ( 1 − β s ) , \bar{\alpha}_t = \prod_{s=1}^{t} (1 - \beta_s), αˉt=s=1t(1βs),
ϵ ∼ N ( 0 , I ) \epsilon \sim \mathcal{N}(0, I) ϵN(0,I)

参数化的反向扩散过程被定义为逐步去噪:
p θ ( x t − 1 ∣ x t ) ≜ N ( x t − 1 ; μ θ ( x t , t ) , σ 2 I ) . p_\theta (x_{t-1} | x_t) \triangleq \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \sigma^2 I). pθ(xt1xt)N(xt1;μθ(xt,t),σ2I).
该模型可以通过最小化负 ELBO 进行训练(Ho et al., 2020; Kingma et al., 2021):
L = E t , q ( x 0 ) q ( x t ∣ x 0 ) K L ( q ( x t − 1 ∣ x t , x 0 ) ∣ ∣ p θ ( x t − 1 ∣ x t ) ) . \mathcal{L} = \mathbb{E}_{t, q(x_0)} q(x_t | x_0) KL(q(x_{t-1} | x_t, x_0) || p_\theta (x_{t-1} | x_t)). L=Et,q(x0)q(xtx0)KL(q(xt1xt,x0)∣∣pθ(xt1xt)).
其中, q ( x t − 1 ∣ x t , x 0 ) q(x_{t-1} | x_t, x_0) q(xt1xt,x0) 是在 Ho et al. (2020) 中推导出的高斯后验分布。

扩散模型的一个关键假设是去噪步长足够小,即时间步 t t t t − 1 t-1 t1 之间的去噪操作近似为高斯分布。这一假设确保了 q ( x t − 1 ∣ x t ) q(x_{t-1} | x_t) q(xt1xt) 近似服从高斯分布,从而使得 p θ ( x t − 1 ∣ x t ) p_\theta (x_{t-1} | x_t) pθ(xt1xt) 的建模更具有效性。


扩散-GAN 混合模型(Diffusion-GAN hybrids)
扩散模型的一个问题在于,当去噪步长较大时,真实的后验分布 q ( x t − 1 ∣ x t ) q(x_{t-1} | x_t) q(xt1xt) 并不严格遵循高斯分布。因此,为了支持较大的去噪步长,Diffusion GANs(Xiao et al., 2022)提出最小化模型分布 p θ ( x t − 1 ∣ x t ) p_\theta (x_{t-1} | x_t) pθ(xt1xt) 与真实后验 q ( x t − 1 ∣ x t ) q(x_{t-1} | x_t) q(xt1xt) 之间的对抗散度(adversarial divergence)
min ⁡ θ E q ( x t ) [ D a d v ( q ( x t − 1 ∣ x t ) ∣ ∣ p θ ( x t − 1 ′ ∣ x t ) ) ] . \min_\theta \mathbb{E}_{q(x_t)} [D_{adv} (q(x_{t-1} | x_t) || p_\theta (x'_{t-1} | x_t))]. θminEq(xt)[Dadv(q(xt1xt)∣∣pθ(xt1xt))].
其中, p θ ( x t − 1 ′ ∣ x t ) p_\theta (x'_{t-1} | x_t) pθ(xt1xt) 由 GAN 生成器建模:
p θ ( x t − 1 ′ ∣ x t ) ≜ ∫ p θ ( x 0 ∣ x t ) q ( x t − 1 ∣ x t , x 0 ) d x 0 . p_\theta (x'_{t-1} | x_t) \triangleq \int p_\theta (x_0 | x_t) q(x_{t-1} | x_t, x_0) dx_0. pθ(xt1xt)pθ(x0xt)q(xt1xt,x0)dx0.
通过引入 GAN 生成器,Diffusion-GAN 允许更大的去噪步长(如 4 步),从而加速生成过程。

3 METHOD: SELF-COOPERATIVE DIFFUSION GANS

Diffusion-GAN 混合模型(Xiao et al., 2022; Xu et al., 2023b;c)的一个关键问题在于,它们通过匹配生成器分布 p θ ( x t − 1 ∣ x t ) ≜ E p θ ( x 0 ∣ x t ) q ( x t − 1 ∣ x t , x 0 ) p_\theta (x_{t-1} | x_t) \triangleq \mathbb{E}_{p_\theta (x_0 | x_t)} q(x_{t-1} | x_t, x_0) pθ(xt1xt)Epθ(x0xt)q(xt1xt,x0) 与受污染的数据分布进行训练。然而,这种方法仅间接学习 p θ ( x 0 ∣ x t ) p_\theta (x_0 | x_t) pθ(x0xt) p θ ( x 0 ) = ∫ q ( x t ) p θ ( x 0 ∣ x t ) d x t p_\theta (x_0) = \int q(x_t) p_\theta (x_0 | x_t) dx_t pθ(x0)=q(xt)pθ(x0xt)dxt,而这些分布是用于一步生成的关键分布。由于这种间接匹配方式,使得整个学习过程的效率下降。

3.1 OUR DESIGN


为了使一步生成的学习更有效,我们提出直接基于干净数据构建学习目标。首先,我们构造一个干净数据的序列分布如下:
p θ ( t ) ( x 0 ) = ∫ q ( x t ) p θ ( x 0 ∣ x t ) d x t , 0 < t ≤ T ; p θ ( 0 ) ( x 0 ) ≜ q ( x 0 ) , p_\theta^{(t)}(x_0) = \int q(x_t) p_\theta(x_0 | x_t) dx_t, \quad 0 < t \leq T; \quad p_\theta^{(0)}(x_0) \triangleq q(x_0), pθ(t)(x0)=q(xt)pθ(x0xt)dxt,0<tT;pθ(0)(x0)q(x0),
其中 q ( x 0 ) q(x_0) q(x0) 是数据分布, p θ ( x 0 ∣ x t ) ≜ N ( G θ ( x t , t ) , σ 2 I ) p_\theta (x_0 | x_t) \triangleq \mathcal{N}(G_\theta(x_t, t), \sigma^2 I) pθ(x0xt)N(Gθ(xt,t),σ2I) G θ G_\theta Gθ 是去噪生成器。需要注意的是, G θ ( x t , t ) G_\theta(x_t, t) Gθ(xt,t) 是我们的去噪生成器,它用于预测干净样本。如果网络 ϵ θ \epsilon_\theta ϵθ 被参数化为预测噪声,则:
G θ ( x t , t ) ≜ x t − 1 − α ˉ t ϵ θ ( x t , t ) α ˉ t . G_\theta(x_t, t) \triangleq \frac{x_t - \sqrt{1 - \bar{\alpha}_t} \epsilon_\theta(x_t, t)}{\bar{\alpha}_t}. Gθ(xt,t)αˉtxt1αˉt ϵθ(xt,t).

在此构造的分布基础上,我们可以定义优化目标如下:
E t [ D adv ( q ( x ) ∣ ∣ p θ ( t ) ( x ) ) + λ ⋅ K L ( q ( x 0 , x t ) ∣ ∣ p θ ( x 0 , x t ) ) ] \mathbb{E}_t [D_{\text{adv}}(q(x) || p_\theta^{(t)}(x)) + \lambda \cdot KL(q(x_0, x_t) || p_\theta (x_0, x_t))] Et[Dadv(q(x)∣∣pθ(t)(x))+λKL(q(x0,xt)∣∣pθ(x0,xt))]
= E t [ D adv ( q ( x ) ∣ ∣ p θ ( t ) ( x ) ) ] + λ t ⋅ K L ( q ( x 0 ) q ( x t ∣ x 0 ) ∣ ∣ q ( x t ) p θ ( x 0 ∣ x t ) ) , = \mathbb{E}_t [D_{\text{adv}}(q(x) || p_\theta^{(t)}(x))] + \lambda_t \cdot KL(q(x_0) q(x_t | x_0) || q(x_t) p_\theta(x_0 | x_t)), =Et[Dadv(q(x)∣∣pθ(t)(x))]+λtKL(q(x0)q(xtx0)∣∣q(xt)pθ(x0xt)),
其中 q ( x 0 , x t ) ≜ q ( x 0 ) q ( x t ∣ x 0 ) q(x_0, x_t) \triangleq q(x_0) q(x_t | x_0) q(x0,xt)q(x0)q(xtx0) p θ ( x 0 , x t ) ≜ q ( x t ) p θ ( x 0 ∣ x t ) p_\theta (x_0, x_t) \triangleq q(x_t) p_\theta (x_0 | x_t) pθ(x0,xt)q(xt)pθ(x0xt)。该优化目标结合了对抗散度(adversarial divergence)和 KL 散度。具体而言,对抗散度专注于整体分布匹配,确保生成质量,而 KL 散度则关注逐点级别的匹配,确保模式覆盖

然而,直接在干净数据分布上学习对抗散度较为困难,这类似于 GAN 训练中遇到的挑战。为了应对这些问题,先前的扩散 GAN 方法(Xiao et al., 2022; Xu et al., 2023b)通常在受污染的数据分布上训练对抗散度。然而,正如前文分析的,这种方法难以直接匹配 p θ ( x 0 ) p_\theta(x_0) pθ(x0),从而削弱了一步生成的有效性。此外,这种方法还迫使判别器适应不同程度的噪声,导致其能力受限。

回顾 p θ ( t ) ( x ) p_\theta^{(t)}(x) pθ(t)(x) 的定义:
p θ ( t ) ( x ) = ∫ q ( x t ) p θ ( x ∣ x t ) d x t . p_\theta^{(t)}(x) = \int q(x_t) p_\theta(x | x_t) dx_t. pθ(t)(x)=q(xt)pθ(xxt)dxt.
该分布的质量主要受两个关键因素影响:

  1. 可训练生成器 G θ G_\theta Gθ 的能力
  2. x t x_t xt 提供的信息

因此,在生成器 G θ G_\theta Gθ 固定的情况下,如果我们增加 x t x_t xt 中的信息,理论上可以获得更优的分布。换句话说, p θ ( t k ) ( x ) p_\theta^{(t_k)}(x) pθ(tk)(x) 很可能优于 p θ ( t ) ( x ) p_\theta^{(t)}(x) pθ(t)(x),其中:
t k = max ⁡ { t − k , 0 } < t . t_k = \max\{t - k, 0\} < t. tk=max{tk,0}<t.
受到协作方法(cooperative approach)的启发(Xie et al., 2018; 2021; 2022; Hill et al., 2022),该方法使用基于 MCMC 修正的模型分布来学习生成器,我们建议使用 p θ ( t k ) ( x ) p_\theta^{(t_k)}(x) pθ(tk)(x) 作为学习 p θ ( t ) ( x ) p_\theta^{(t)}(x) pθ(t)(x) 的真实分布(ground truth),从而构造以下训练目标:
$$
\min_\theta \mathcal{L}\theta \triangleq \mathbb{E}t \mathbb{E}{q(x) q(x_t | x)} \lambda_t ||G\theta(x_t, t) - x||_2^2

  • \mathbb{E}t [D{\text{adv}}(p_\theta^{(t_k)} (\text{sg}(x)) || p_\theta^{(t)} (x))],
    $$
    其中 sg [ ⋅ ] \text{sg}[\cdot] sg[] 表示停止梯度算子(stop-gradient operator),第二项被称为协作对抗损失(cooperative adversarial loss)。该训练目标可以被视为一种自协作学习方法(self-cooperative approach),因为**“修正后”的样本 p θ ( t k ) ( x ) p_\theta^{(t_k)}(x) pθ(tk)(x) p θ ( t ) ( x ) p_\theta^{(t)}(x) pθ(t)(x) 由相同的生成网络 G θ G_\theta Gθ 生成**。

需要注意的是,我们仅在对抗散度(adversarial divergence)中用 p θ ( t k ) ( x ) p_\theta^{(t_k)}(x) pθ(tk)(x) 取代数据分布,以平滑学习目标。近期研究(Luo et al., 2023d)发现,使用真实数据和修正数据的混合进行训练,有助于更有效地学习生成器。

我们将在下文简要验证协作对抗散度的理论合理性

命题 1协作对抗损失(cooperative adversarial loss)的最优解收敛于 p θ ( T ) ( x ) = p d ( x ) p_\theta^{(T)}(x) = p_d(x) pθ(T)(x)=pd(x)

该命题表明,当网络的能力足够强时,所提出的协作对抗损失能够恢复真实数据分布,从而证明了所提出优化目标的理论合理性。详细证明见附录 C

在上述分布匹配目标中,我们采用非饱和 GAN 目标(non-saturating GAN objective) 来最小化边际分布(marginal distribution)的对抗散度,并使用 L 2 L_2 L2 损失优化点匹配(point matching)。由此,我们可以得到一个可行的训练目标:
$$
\min_\theta \max_\phi \mathbb{E}t [\mathbb{E}{p_\theta^{(t_k)}(x)} \log D_\phi (\text{sg}(x), t) - \mathbb{E}{p\theta^{(t)}(x)} \log D_\phi (x, t)]

  • \lambda_t \mathbb{E}{q(x) q(x_t | x)} ||G\theta(x_t, t) - x||_2^2,
    $$
    其中, D ϕ D_\phi Dϕ 是判别器网络(discriminator)。我们发现,自协作方法(self-cooperative approach)与一致性训练(Consistency Training, Song et al., 2023)密切相关。然而,一致性训练(Consistency Training)将 x t x_t xt 视为 ODE 近似解,并执行点对点匹配(point-to-point matching)。相比之下,我们的方法直接在边际分布级别(marginal distribution level) 匹配 p θ ( t ) ( x ) p_\theta^{(t)}(x) pθ(t)(x) p θ ( t k ) ( x ) p_\theta^{(t_k)}(x) pθ(tk)(x),从而避免 ODE 近似误差的影响。

为了进一步保证所提出模型的模式覆盖能力(mode cover),我们可以将一致性损失(consistency loss)作为正则化项添加到优化目标中,从而构造以下损失函数:
min ⁡ θ max ⁡ ϕ E t { E p θ ( t k ) ( x ) [ log ⁡ D ϕ ( x , t ) − E p θ ( t ) ( x ) log ⁡ D ϕ ( x , t ) ] \min_\theta \max_\phi \mathbb{E}_t \{ \mathbb{E}_{p_\theta^{(t_k)}(x)} [\log D_\phi (x, t) - \mathbb{E}_{p_\theta^{(t)}(x)} \log D_\phi (x, t)] θminϕmaxEt{Epθ(tk)(x)[logDϕ(x,t)Epθ(t)(x)logDϕ(x,t)]
$$

  • \mathbb{E}{q(x) q(x_t, x{t_k} | x_t, x)} [\lambda(t) ||G_\theta (x_t, t) - x||2^2 + \lambda_t^{\text{con}} ||G\theta (x_t, t) - \text{sg}(G_\theta (x_{t_k}, t_k))||_2^2 ] },
    $$
    其中, λ t \lambda_t λt λ ( t ) \lambda(t) λ(t) 是预定义的超参数(pre-defined hyper-parameters)。

4 TRY IT ON CIFAR-10 BEFORE SCALING UP FOR SAVING MONEY!

在本节中,我们评估拟议的YOSO在CIFAR-10上的性能(Yu等,2015),以验证其在从头开始和微调环境的培训下的有效性。

4.1 TRAINING STRATEGIES


在开始培训之前,我们引入了一些有效的培训策略,以遵循驯服Yoso。

脱钩的调度程序。我们发现,用于执行一致性损失和对抗性损失的最佳调度程序并不相同。这是由于合作对抗损失不涉及时间步长的近似错误,从而实现了实质性跳过以最大程度地发挥其疗效。

相比之下,一致性损失(consistency loss)容易受到时间步跳跃(timestep skips)导致的近似误差影响,因此需要采用更保守的跳跃策略 来保持其有效性。因此,为了更好地发挥每个损失项的能力,我们提出使用解耦调度器(decoupled schedulers) 来构建最终的训练目标:
min ⁡ θ max ⁡ ϕ E t { E p θ ( t k ) ( x ) [ log ⁡ D ϕ ( x , t ) − E p θ ( t ) ( x ) log ⁡ D ϕ ( x , t ) ] \min_\theta \max_\phi \mathbb{E}_t \{ \mathbb{E}_{p_\theta^{(t_k)}(x)} [\log D_\phi (x, t) - \mathbb{E}_{p_\theta^{(t)}(x)} \log D_\phi (x, t)] θminϕmaxEt{Epθ(tk)(x)[logDϕ(x,t)Epθ(t)(x)logDϕ(x,t)]
$$

  • \mathbb{E}{q(x) q(x_t, x{t_m} | x_t, x)} [\lambda(t) ||G_\theta (x_t, t) - x||2^2 + \lambda_t^{\text{con}} ||G\theta (x_t, t) - \text{sg}(G_\theta (x_{t_m}, t_m))||_2^2 ] },
    $$
    其中 t k = max ⁡ ( t − k , 0 ) t_k = \max(t - k, 0) tk=max(tk,0) t m = max ⁡ ( t − m , 0 ) t_m = \max(t - m, 0) tm=max(tm,0),并且在实验中设定 k = 250 k = 250 k=250 m = 25 m = 25 m=25

退火策略(Annealing strategy)
由于我们的目标是获得一个强大的一步去噪生成模型,但 KL 损失和一致性损失会在点级别执行匹配(point-level matching),这可能会在模型容量不足的情况下影响性能。因此,我们建议在训练过程中逐渐将这两个损失的权重降低至零

具体来说,我们定义:
λ = ( 1 − ⌊ n / K ⌋ K − 1 ) λ ′ \lambda = \left(1 - \frac{\lfloor n / K \rfloor}{K - 1} \right) \lambda' λ=(1K1n/K)λ
其中, λ ′ \lambda' λ 为初始权重, K K K 为退火次数, n n n 是当前训练迭代次数, N N N 是总训练迭代次数。该公式确保损失权重在 K K K 次退火后降至零,以适应训练进程。

5 TOWARDS ONE-STEP TEXT-TO-IMAGE SYNTHESIS

由于从零开始训练文本到图像(text-to-image)模型的成本极高,我们建议使用预训练的文本到图像扩散模型(DMs) 作为初始化,并结合自协作扩散 GAN(Self-Cooperative Diffusion GANs) 进行优化。在本节中,我们介绍了一些系统性设计,用于开发基于预训练 DMs 的一步文本到图像生成模型

5.1 USING PRE-TRAINED MODELS FOR TRAINING


潜在感知损失(Latent Perceptual Loss)
先前的研究(Hou et al., 2017; Hoshen et al., 2019; Song et al., 2023)已证实感知损失(perceptual loss)在多个领域中的有效性。近期的研究(Liu et al., 2023; Song et al., 2023)发现 LPIPS 损失(Zhang et al., 2018)对于获得高质量的少步扩散模型(DMs) 至关重要。然而,LPIPS 损失的一个显著缺点是,它是在数据空间(data space) 计算的,这会带来较高的计算开销。

相比之下,流行的 Stable Diffusion(SD)潜在空间(latent space) 运行,从而减少了计算需求。因此,在潜在 DMs 训练中使用 LPIPS 损失成本较高,不仅需要计算数据空间的 LPIPS 损失,还额外增加了解码操作。考虑到预训练的 SD 可以作为有效的特征提取器(Xu et al., 2023a),我们建议利用预训练 SD 来执行潜在感知损失

然而,SD 采用 UNet 结构,其最终层预测的 ϵ \epsilon ϵ 与数据维度相同。因此,我们提出使用 UNet 的瓶颈层(bottleneck layer) 进行计算:
d ( z θ , z ) = ∣ ∣ HalfUNet ( z θ , c , t = 0 ) − HalfUNet ( z , c , t = 0 ) ∣ ∣ 2 2 , d(\mathbf{z}_\theta, \mathbf{z}) = ||\text{HalfUNet}(\mathbf{z}_\theta, c, t = 0) - \text{HalfUNet}(\mathbf{z}, c, t = 0)||_2^2, d(zθ,z)=∣∣HalfUNet(zθ,c,t=0)HalfUNet(z,c,t=0)22,
其中 z \mathbf{z} z 是 VAE 编码后的潜变量(latent encoded images), c c c 是文本特征(text feature)。
值得注意的是,通过 SD 计算潜在感知损失的好处不仅在于计算效率的提升,还在于整合了文本特征,这对于文本到图像任务至关重要


潜在判别器(Latent Discriminator)
在大规模数据集上训练 文本到图像(text-to-image) 任务的 GAN 面临严重挑战。具体而言,与无条件生成(unconditional generation)不同,文本到图像任务的判别器(discriminator)不仅需要评估图像质量,还需要对齐文本信息。这一挑战在训练的早期阶段尤为明显。

为了解决该问题,先前的纯 GAN 方法(Kang et al., 2023)提出了复杂的学习目标,但训练成本较高。而研究表明,GAN 训练可以受益于使用预训练网络作为判别器。如上所述,预训练的 SD 已经学习到具有代表性的特征,因此,我们建议使用预训练的 SD 构建潜在判别器(latent discriminator)

与潜在感知损失类似,我们仅使用 UNet 的一半(Half UNet) 作为判别器,并附加一个简单的预测头(predict head)。所提出策略的优势有以下两点:

  1. 利用信息丰富的预训练网络进行初始化
  2. 判别器基于潜在空间(latent space),计算效率更高

与先前的研究(Sauer et al., 2023c)不同,后者的判别器在数据空间 进行计算,并需要解码潜在变量,再从解码器反向传播,计算成本极高。而采用 潜在判别器(Latent Discriminator),可以显著降低计算成本,同时实现稳定的训练过程,并加快收敛速度

5.2 FIXING THE NOISE SCHEDULER


扩散模型(DMs)的一个常见问题是,最终的受污染样本并非纯噪声。例如,Stable Diffusion(SD) 采用的噪声调度器(noise scheduler)使得最终时间步的受污染样本计算如下:
x T = 0.068265 ⋅ x 0 + 0.99767 ⋅ ϵ , \mathbf{x}_T = 0.068265 \cdot \mathbf{x}_0 + 0.99767 \cdot \epsilon, xT=0.068265x0+0.99767ϵ,
其中终端信噪比(SNR) 计算为:
α ˉ T 1 − α ˉ T = 0.004682 , \frac{\bar{\alpha}_T}{1 - \bar{\alpha}_T} = 0.004682, 1αˉTαˉT=0.004682,
这在训练和推理 之间造成了分布偏差(distribution gap)

先前的研究(Lin et al., 2024a)仅观察到,这种偏差使 DMs 无法生成纯黑色或纯白色图像。然而,我们发现该问题在一步生成(one-step generation) 中更为严重。如 图 2 所示,如果直接从标准高斯分布(standard Gaussian) 采样噪声,生成结果会出现显著的伪影(artifacts)。其可能的原因在于:

  • 多步生成(multi-step generation) 中,该分布偏差可以在采样过程中逐步修正;
  • 一步生成(one-step generation) 中,该分布偏差会直接影响最终输出。

为了解决该问题,我们提出了 两种简单但有效的解决方案

信息性先验初始化(Informative Prior Initialization, IPI)
非零终端信噪比(SNR)问题类似于 VAE(变分自编码器)中的先验空洞问题(prior hole issue)(Klushyn et al., 2019; Bauer & Mnih, 2019; Kingma et al., 2016)。因此,我们可以采用信息性先验(informative prior) 而非非信息性先验(non-informative prior),以有效解决该问题。

为简化处理,我们采用可学习的高斯分布 N ( μ , σ 2 I ) \mathcal{N}(\mu, \sigma^2 I) N(μ,σ2I),其最优公式如下:
ϵ ′ ′ = α ˉ T ⋅ ( E x x + Std ( x ) × ϵ ′ ) + 1 − α ˉ T ⋅ ϵ , \epsilon'' = \bar{\alpha}_T \cdot (\mathbb{E}_x x + \text{Std}(x) \times \epsilon') + \sqrt{1 - \bar{\alpha}_T} \cdot \epsilon, ϵ′′=αˉT(Exx+Std(x)×ϵ)+1αˉT ϵ,
其中 E x x \mathbb{E}_x x Exx Std ( x ) \text{Std}(x) Std(x) 可通过有限样本高效估计,而 ϵ ′ \epsilon' ϵ 服从标准高斯分布(standard Gaussian distribution)。

图 2 所示,一步生成中的伪影(artifacts)在应用 IPI 后被立即消除。值得注意的是,该方法仅需最小的调整 即可获得良好的效果,从而为 LoRA 微调(LoRA fine-tuning) 实现一步文本到图像生成提供了可能性。

图2

图2:Yosolora的样品,具有不同初始化的一步推断。


快速适应 V-预测与零终端 SNR(Quick Adaption to V-prediction and Zero Terminal SNR)
当终端 SNR 过低时,IPI 可能会导致数值不稳定。如 图 3 所示,我们对 PixArt-α(Chen et al., 2024)进行微调,其终端 SNR 仅为 4 e − 5 4e-5 4e5,导致 ϵ \epsilon ϵ-预测失败,使得一步生成无法正常进行

根据 Lin et al. (2024a),我们建议切换至 V-预测(v-prediction)(Salimans & Ho, 2022)并设定零终端 SNR。然而,我们发现直接转换收敛速度较慢(详见附录 F)。在大规模文本到图像任务 中,计算资源有限,这种缓慢的收敛速度是不可接受的。

为了解决此问题,我们提出快速适应阶段(quick adapt-stage)
适应阶段-I(Adapt-stage-I) 切换至 V-预测:
L ( θ ) = λ t ∣ ∣ v θ ( x t , t ) − v ϕ ( x t , t ) ∣ ∣ 2 2 , L(\theta) = \lambda_t || v_\theta (x_t, t) - v_\phi (x_t, t) ||_2^2, L(θ)=λt∣∣vθ(xt,t)vϕ(xt,t)22,
其中:
v ϕ ( x t , t ) = α ˉ t ϵ ϕ ( x t , t ) − 1 − α ˉ t x ϕ t , x ϕ t = x t − 1 − α ˉ t ϵ ϕ ( x t , t ) α ˉ t . v_\phi (x_t, t) = \bar{\alpha}_t \epsilon_\phi (x_t, t) - \sqrt{1 - \bar{\alpha}_t} x_\phi^t, \quad x_\phi^t = \frac{x_t - \sqrt{1 - \bar{\alpha}_t} \epsilon_\phi (x_t, t)}{\bar{\alpha}_t}. vϕ(xt,t)=αˉtϵϕ(xt,t)1αˉt xϕt,xϕt=αˉtxt1αˉt ϵϕ(xt,t).
其中, ϵ ϕ ( ⋅ , ⋅ ) \epsilon_\phi(\cdot, \cdot) ϵϕ(,) 表示冻结的预训练模型(frozen pre-trained model) v θ ( ⋅ , ⋅ ) v_\theta(\cdot, \cdot) vθ(,) 表示目标 V-预测模型(desired v-prediction model)。

图3

图3:预测失败。

适应阶段-II(Adapt-stage-II) 切换至零终端 SNR(zero terminal SNR):
L ( θ ) = λ t ∣ ∣ v θ ( x t , t ) − v ϕ ( x t ′ , t ) ∣ ∣ 2 2 . L(\theta) = \lambda_t || v_\theta (x_t, t) - v_\phi (x_t', t) ||_2^2. L(θ)=λt∣∣vθ(xt,t)vϕ(xt,t)22.
在该阶段,我们仅更改调度器,使其在学生模型(student model)上采用零终端 SNR。这样可以避免 ϵ \epsilon ϵ-预测( ϵ \epsilon ϵ-prediction) 在零终端 SNR 下的数值不稳定问题。值得注意的是,零终端 SNR 调度器在每个时间步的 SNR 均低于原始调度器,从而形成有效的蒸馏目标(distillation objective)

我们观察到,该适应阶段收敛速度极快,通常仅需要 1000 次迭代 即可初始化 YOSO。这一过程使得 V-预测和零终端 SNR 能够快速适配到预训练的 ϵ \epsilon ϵ-预测 DMs,从而全面解决噪声调度器中的非零终端 SNR 问题


http://www.kler.cn/a/546947.html

相关文章:

  • HTML的入门
  • windows平台上 oracle简单操作手册
  • 【二叉树学习7】
  • Eclipse:关闭多余的工具条
  • Docker compose 以及镜像使用
  • Sprinig源码解析
  • [LeetCode]day21 15.三数之和
  • Machine Learning:Optimization
  • H5自适应响应式代理记账与财政咨询服务类PbootCMS网站模板 – HTML5财务会计类网站源码下载
  • HCIA项目实践---OSPF的基本配置
  • 在本地校验密码或弱口令 (windows)
  • DeepSeek免费部署到WPS或Office
  • Linux 内核 IPoIB 驱动中 sysfs 属性冲突问题的分析与解决
  • LAWS是典型的人机环境系统
  • 【第4章:循环神经网络(RNN)与长短时记忆网络(LSTM)— 4.6 RNN与LSTM的变体与发展趋势】
  • Unity使用iTextSharp导出PDF-04图形
  • 修改OnlyOffice编辑器默认字体
  • 小米 R3G 路由器刷机教程(Pandavan)
  • 算法练习——哈希表
  • QML使用ChartView绘制箱线图