论文笔记(六十三)Understanding Diffusion Models: A Unified Perspective(四)
Understanding Diffusion Models: A Unified Perspective(四)
- 文章概括
- 学习扩散噪声参数(Learning Diffusion Noise Parameters)
- 三种等效的解释(Three Equivalent Interpretations)
文章概括
引用:
@article{luo2022understanding,
title={Understanding diffusion models: A unified perspective},
author={Luo, Calvin},
journal={arXiv preprint arXiv:2208.11970},
year={2022}
}
Luo, C., 2022. Understanding diffusion models: A unified perspective. arXiv preprint arXiv:2208.11970.
原文: https://arxiv.org/abs/2208.11970
代码、数据和视频:https://arxiv.org/abs/2208.11970
系列文章:
请在
《
《
《文章
》
》
》 专栏中查找
学习扩散噪声参数(Learning Diffusion Noise Parameters)
让我们探讨如何联合学习VDM的噪声参数。一种潜在的方法是使用具有参数 η \eta η的神经网络 α ^ η ( t ) \hat{\alpha}_\eta(t) α^η(t)对 α t \alpha_t αt进行建模。然而,这种方法效率较低,因为在每个时间步 t t t上都必须多次执行推断以计算 α ˉ t \bar{\alpha}_t αˉt。虽然缓存可以缓解这一计算成本,但我们也可以推导出另一种学习扩散噪声参数的方法。通过将公式85中的方差公式代入公式99中推导出的每时间步目标,我们可以化简为:
1. α ˉ t \bar{\alpha}_t αˉt 的定义
扩散模型中, α ˉ t \bar{\alpha}_t αˉt 的定义为: α ˉ t = ∏ i = 1 t α i , \bar{\alpha}_t = \prod_{i=1}^t \alpha_i, αˉt=i=1∏tαi, 即所有时间步 α i \alpha_i αi 的累积乘积。
2. α ˉ t − 1 ( 1 − α t ) \bar{\alpha}_{t-1}(1 - \alpha_t) αˉt−1(1−αt) 的展开
从 α ˉ t \bar{\alpha}_t αˉt 的定义出发,可以得到: α ˉ t − 1 = ∏ i = 1 t − 1 α i . \bar{\alpha}_{t-1} = \prod_{i=1}^{t-1} \alpha_i. αˉt−1=i=1∏t−1αi.
因此, α ˉ t \bar{\alpha}_t αˉt 可以写成: α ˉ t = α ˉ t − 1 ⋅ α t . \bar{\alpha}_t = \bar{\alpha}_{t-1} \cdot \alpha_t. αˉt=αˉt−1⋅αt.
于是,有: 1 − α ˉ t = 1 − ( α ˉ t − 1 ⋅ α t ) . 1 - \bar{\alpha}_t = 1 - (\bar{\alpha}_{t-1} \cdot \alpha_t). 1−αˉt=1−(αˉt−1⋅αt).
3. 化简 α ˉ t − 1 ( 1 − α t ) \bar{\alpha}_{t-1}(1 - \alpha_t) αˉt−1(1−αt)
现在回到公式 (103) 中的 α ˉ t − 1 ( 1 − α t ) \bar{\alpha}_{t-1}(1 - \alpha_t) αˉt−1(1−αt)。将其展开为: α ˉ t − 1 ( 1 − α t ) = α ˉ t − 1 − α ˉ t − 1 ⋅ α t . \bar{\alpha}_{t-1}(1 - \alpha_t) = \bar{\alpha}_{t-1} - \bar{\alpha}_{t-1} \cdot \alpha_t. αˉt−1(1−αt)=αˉt−1−αˉt−1⋅αt.
结合 α ˉ t = α ˉ t − 1 ⋅ α t \bar{\alpha}_t = \bar{\alpha}_{t-1} \cdot \alpha_t αˉt=αˉt−1⋅αt,我们发现: α ˉ t − 1 − α ˉ t − 1 ⋅ α t = α ˉ t − 1 − α ˉ t . \bar{\alpha}_{t-1} - \bar{\alpha}_{t-1} \cdot \alpha_t = \bar{\alpha}_{t-1} - \bar{\alpha}_t. αˉt−1−αˉt−1⋅αt=αˉt−1−αˉt.
回忆公式70, q ( x t ∣ x 0 ) q(x_t|x_0) q(xt∣x0)是形式为 N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) \mathcal{N}(x_t;\sqrt{\bar{\alpha}_t}x_0,(1 - \bar{\alpha}_t)\mathbf{I}) N(xt;αˉtx0,(1−αˉt)I)的高斯分布。然后,根据信噪比(SNR)的定义 S N R = μ 2 σ 2 SNR = \frac{\mu^2}{\sigma^2} SNR=σ2μ2,我们可以将每个时间步 t t t的信噪比写为:
SNR ( t ) = α ˉ t 1 − α ˉ t (109) \text{SNR}(t) = \frac{\bar{\alpha}_t}{1 - \bar{\alpha}_t} \tag{109} SNR(t)=1−αˉtαˉt(109)
然后,我们推导的公式108(以及公式99)可以简化为:
1 2 σ q 2 ( t ) α ˉ t − 1 ( 1 − α t ) 2 ( 1 − α ˉ t ) 2 [ ∥ x ^ θ ( x t , t ) − x 0 ∥ 2 2 ] = 1 2 ( SNR ( t − 1 ) − SNR ( t ) ) [ ∥ x ^ θ ( x t , t ) − x 0 ∥ 2 2 ] (110) \frac{1}{2\sigma_q^2(t)} \frac{\bar{\alpha}_{t-1}(1-\alpha_t)^2}{(1-\bar{\alpha}_t)^2} \Big[\left\| \hat{x}_\theta(x_t, t) - x_0 \right\|_2^2 \Big] = \frac{1}{2} \left( \text{SNR}(t-1) - \text{SNR}(t) \right) \Big[\left\| \hat{x}_\theta(x_t, t) - x_0 \right\|_2^2 \Big] \tag{110} 2σq2(t)1(1−αˉt)2αˉt−1(1−αt)2[∥x^θ(xt,t)−x0∥22]=21(SNR(t−1)−SNR(t))[∥x^θ(xt,t)−x0∥22](110)
正如其名称所暗示的,信噪比(SNR)表示原始信号与存在的噪声量之间的比率;更高的SNR表示更多的信号,而较低的SNR表示更多的噪声。在扩散模型中,我们要求SNR随着时间步 t t t的增加单调递减;这形式化地体现了扰动输入 x t x_t xt随时间逐渐变得越来越嘈杂,直到在 t = T t = T t=T时与标准高斯分布相同。
根据公式110中目标的简化,我们可以直接使用神经网络对每个时间步的SNR进行参数化,并与扩散模型一起联合学习。由于SNR必须随着时间单调递减,我们可以将其表示为:
SNR ( t ) = exp ( − ω η ( t ) ) (111) \text{SNR}(t) = \exp\left(-\omega_{\eta}(t)\right) \tag{111} SNR(t)=exp(−ωη(t))(111)
其中, ω η ( t ) \omega_{\eta}(t) ωη(t)由一个单调递增的神经网络建模,参数为 η \eta η。对 ω η ( t ) \omega_{\eta}(t) ωη(t)取负值会导致结果成为单调递减函数,而指数函数强制结果为正值。注意,式(100)中的目标函数现在也必须对 η \eta η(参数)进行优化。通过将式(111)中的SNR参数化与式(109)中的SNR定义相结合,我们还可以显式推导出 α ˉ t \bar{\alpha}_t αˉt的优雅形式,以及 1 − α ˉ t 1 - \bar{\alpha}_t 1−αˉt的表达式:
α ˉ t 1 − α ˉ t = exp ( − ω η ( t ) ) (112) \frac{\bar{\alpha}_t}{1 - \bar{\alpha}_t} = \exp(-\omega_{\eta}(t)) \tag{112} 1−αˉtαˉt=exp(−ωη(t))(112) ∴ α ˉ t = sigmoid ( − ω η ( t ) ) (113) \therefore \bar{\alpha}_t = \text{sigmoid}(-\omega_{\eta}(t)) \tag{113} ∴αˉt=sigmoid(−ωη(t))(113) ∴ 1 − α ˉ t = sigmoid ( ω η ( t ) ) (114) \therefore 1 - \bar{\alpha}_t = \text{sigmoid}(\omega_{\eta}(t)) \tag{114} ∴1−αˉt=sigmoid(ωη(t))(114)
这些项在多种计算中是必需的;例如,在优化过程中,它们用于通过重参数化技巧从输入 x 0 x_0 x0创建任意噪声的 x t x_t xt,这一过程在公式(69)中推导得出。
1. SNR 的参数化
为了强制 SNR 随时间单调递减,我们将其表示为一个指数函数(总是正值且单调递减): SNR ( t ) = exp ( − ω η ( t ) ) . (111) \text{SNR}(t) = \exp(-\omega_\eta(t)). \tag{111} SNR(t)=exp(−ωη(t)).(111)
解释公式 (111):
ω η ( t ) \omega_\eta(t) ωη(t) 是一个单调递增函数:
- 这是由一个神经网络建模的(参数为 η \eta η)。
- 时间 t t t 增大时, ω η ( t ) \omega_\eta(t) ωη(t) 增大。
负号作用:
- 取负号 − ω η ( t ) -\omega_\eta(t) −ωη(t) 将单调递增函数变为单调递减。
- 指数函数 exp ( ⋅ ) \exp(\cdot) exp(⋅) 强制结果为正值。
直观意义:
- 时间 t = 0 t=0 t=0 时, ω η ( 0 ) ≈ 0 \omega_\eta(0) \approx 0 ωη(0)≈0,因此 SNR ( 0 ) ≈ 1 \text{SNR}(0) \approx 1 SNR(0)≈1,表示干净信号。
- 时间 t = T t=T t=T 时, ω η ( T ) → ∞ \omega_\eta(T) \to \infty ωη(T)→∞,因此 SNR ( T ) → 0 \text{SNR}(T) \to 0 SNR(T)→0,表示纯噪声。
2. 推导 α ˉ t \bar{\alpha}_t αˉt 和 1 − α ˉ t 1 - \bar{\alpha}_t 1−αˉt
根据 SNR 的定义: SNR ( t ) = α ˉ t 1 − α ˉ t . \text{SNR}(t) = \frac{\bar{\alpha}_t}{1 - \bar{\alpha}_t}. SNR(t)=1−αˉtαˉt.
2.1 表达式重写 将 SNR ( t ) = exp ( − ω η ( t ) ) \text{SNR}(t) = \exp(-\omega_\eta(t)) SNR(t)=exp(−ωη(t)) 代入: α ˉ t 1 − α ˉ t = exp ( − ω η ( t ) ) . (112) \frac{\bar{\alpha}_t}{1 - \bar{\alpha}_t} = \exp(-\omega_\eta(t)). \tag{112} 1−αˉtαˉt=exp(−ωη(t)).(112)
这是公式 (112)。
2.2 求解 α ˉ t \bar{\alpha}_t αˉt
通过代数化简: α ˉ t 1 − α ˉ t = exp ( − ω η ( t ) ) , \frac{\bar{\alpha}_t}{1 - \bar{\alpha}_t} = \exp(-\omega_\eta(t)), 1−αˉtαˉt=exp(−ωη(t)), 两边乘以 1 − α ˉ t 1 - \bar{\alpha}_t 1−αˉt: α ˉ t = exp ( − ω η ( t ) ) ( 1 − α ˉ t ) . \bar{\alpha}_t = \exp(-\omega_\eta(t)) (1 - \bar{\alpha}_t). αˉt=exp(−ωη(t))(1−αˉt). 展开后: α ˉ t + exp ( − ω η ( t ) ) α ˉ t = exp ( − ω η ( t ) ) . \bar{\alpha}_t + \exp(-\omega_\eta(t)) \bar{\alpha}_t = \exp(-\omega_\eta(t)). αˉt+exp(−ωη(t))αˉt=exp(−ωη(t)). α ˉ t ( 1 + exp ( − ω η ( t ) ) ) = exp ( − ω η ( t ) ) . \bar{\alpha}_t (1 + \exp(-\omega_\eta(t))) = \exp(-\omega_\eta(t)). αˉt(1+exp(−ωη(t)))=exp(−ωη(t)). 因此: α ˉ t = exp ( − ω η ( t ) ) 1 + exp ( − ω η ( t ) ) . \bar{\alpha}_t = \frac{\exp(-\omega_\eta(t))}{1 + \exp(-\omega_\eta(t))}. αˉt=1+exp(−ωη(t))exp(−ωη(t)).
这正是逻辑函数(Sigmoid 函数)的定义: α ˉ t = sigmoid ( − ω η ( t ) ) . (113) \bar{\alpha}_t = \text{sigmoid}(-\omega_\eta(t)). \tag{113} αˉt=sigmoid(−ωη(t)).(113)
2.3 求解 1 − α ˉ t 1 - \bar{\alpha}_t 1−αˉt
从逻辑函数的性质得知: 1 − sigmoid ( x ) = sigmoid ( − x ) . 1 - \text{sigmoid}(x) = \text{sigmoid}(-x). 1−sigmoid(x)=sigmoid(−x).
因此: 1 − α ˉ t = 1 − sigmoid ( − ω η ( t ) ) = sigmoid ( ω η ( t ) ) . (114) 1 - \bar{\alpha}_t = 1 - \text{sigmoid}(-\omega_\eta(t)) = \text{sigmoid}(\omega_\eta(t)). \tag{114} 1−αˉt=1−sigmoid(−ωη(t))=sigmoid(ωη(t)).(114)
3. 优化过程中的作用
公式 (113) 和 (114) 的意义
- α ˉ t = sigmoid ( − ω η ( t ) ) \bar{\alpha}_t = \text{sigmoid}(-\omega_\eta(t)) αˉt=sigmoid(−ωη(t)):表示在时间 t t t 时信号的强度比例。
- 1 − α ˉ t = sigmoid ( ω η ( t ) ) 1 - \bar{\alpha}_t = \text{sigmoid}(\omega_\eta(t)) 1−αˉt=sigmoid(ωη(t)):表示在时间 t t t 时噪声的强度比例。
在优化过程中:
- 我们用神经网络建模 ω η ( t ) \omega_\eta(t) ωη(t),确保信噪比随时间单调递减。
- 参数化的 α ˉ t \bar{\alpha}_t αˉt 和 1 − α ˉ t 1 - \bar{\alpha}_t 1−αˉt 用于生成带噪数据 x t x_t xt,并通过重参数化技巧优化模型。
4. 示例:将信号和噪声混合
假设 x 0 x_0 x0 是干净数据,噪声为 ϵ \epsilon ϵ,带噪数据 x t x_t xt 的生成公式为: x t = α ˉ t x 0 + 1 − α ˉ t ϵ . x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon. xt=αˉtx0+1−αˉtϵ.
用公式 (113) 和 (114) 计算:
- 当 t → 0 t \to 0 t→0: α ˉ t → 1 \bar{\alpha}_t \to 1 αˉt→1, 1 − α ˉ t → 0 1 - \bar{\alpha}_t \to 0 1−αˉt→0,则 x t ≈ x 0 x_t \approx x_0 xt≈x0(干净数据)。
- 当 t → T t \to T t→T: α ˉ t → 0 \bar{\alpha}_t \to 0 αˉt→0, 1 − α ˉ t → 1 1 - \bar{\alpha}_t \to 1 1−αˉt→1,则 x t ≈ ϵ x_t \approx \epsilon xt≈ϵ(纯噪声)。
5. 总结
SNR 的引入:
- 为了确保信噪比单调递减,定义 SNR ( t ) = exp ( − ω η ( t ) ) \text{SNR}(t) = \exp(-\omega_\eta(t)) SNR(t)=exp(−ωη(t)),其中 ω η ( t ) \omega_\eta(t) ωη(t) 是单调递增函数。
SNR 与 α ˉ t \bar{\alpha}_t αˉt 的关系:
- 信号强度 α ˉ t = sigmoid ( − ω η ( t ) ) \bar{\alpha}_t = \text{sigmoid}(-\omega_\eta(t)) αˉt=sigmoid(−ωη(t))。
- 噪声强度 1 − α ˉ t = sigmoid ( ω η ( t ) ) 1 - \bar{\alpha}_t = \text{sigmoid}(\omega_\eta(t)) 1−αˉt=sigmoid(ωη(t))。
实际意义:
- 通过神经网络建模 ω η ( t ) \omega_\eta(t) ωη(t),我们可以控制信号和噪声的比例,进而优化生成模型。
三种等效的解释(Three Equivalent Interpretations)
正如我们之前证明的,变分扩散模型(VDM)可以通过学习一个神经网络来预测从任意噪声版本 x t x_t xt及其时间索引 t t t恢复原始自然图像 x 0 x_0 x0来进行训练。然而, x 0 x_0 x0还有另外两种等效的参数化形式,这导致了对VDM的另外两种解释。
首先,我们可以利用重参数化技巧。在我们推导 q ( x t ∣ x 0 ) q(x_t|x_0) q(xt∣x0)形式的过程中,可以重新排列公式(69)以得到:
x 0 = x t − 1 − α ˉ t ϵ 0 α ˉ t (115) x_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t} \epsilon_0}{\sqrt{\bar{\alpha}_t}} \tag{115} x0=αˉtxt−1−αˉtϵ0(115)
将此代入我们之前推导出的真实去噪转移均值 μ q ( x t , x 0 ) \mu_q(x_t, x_0) μq(xt,x0),我们可以重新推导如下:
公式(116)来自公式(93)
将此代入我们之前推导出的真实去噪转移均值 μ q ( x t , x 0 ) \mu_q(x_t, x_0) μq(xt,x0),我们可以重新推导如下:
μ θ ( x t , t ) = 1 α t x t − 1 − α t 1 − α ˉ t α t ϵ ^ 0 ( x t , t ) (125) \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t} \sqrt{\alpha_t}} \hat{\epsilon}_0(x_t, t) \tag{125} μθ(xt,t)=αt1xt−1−αˉtαt1−αtϵ^0(xt,t)(125)
相应的优化问题变为:
这里, ϵ ^ θ ( x t , t ) \hat{\epsilon}_\theta(x_t, t) ϵ^θ(xt,t)是一个神经网络,用于学习预测源噪声 ϵ 0 ∼ N ( ϵ ; 0 , I ) \epsilon_0 \sim \mathcal{N}(\epsilon; 0, I) ϵ0∼N(ϵ;0,I),该噪声决定了从 x 0 x_0 x0生成 x t x_t xt。因此,我们已经证明,通过预测原始图像 x 0 x_0 x0来学习VDM等同于学习预测噪声;然而,从经验上看,一些研究发现,预测噪声会带来更好的性能[5, 7]。
为了推导出变分扩散模型的第三种常见解释,我们求助于Tweedie公式[8]。用简单的语言来说,Tweedie公式指出,给定从指数族分布中抽取的样本,该分布的真实均值可以通过样本的最大似然估计(即经验均值)加上一些涉及估计得分的修正项来估计。在仅有一个观测样本的情况下,经验均值就是样本本身。该公式通常用于减轻样本偏差;如果观测样本都偏向底层分布的一端,则负得分会变大,从而将样本的朴素最大似然估计修正为更接近真实均值的值。
数学上,对于一个高斯变量 z ∼ N ( z ; μ z , Σ z ) z \sim \mathcal{N}(z; \mu_z, \Sigma_z) z∼N(z;μz,Σz),Tweedie公式表示:
E [ μ z ∣ z ] = z + Σ z ∇ z log p ( z ) \mathbb{E}[\mu_z|z] = z + \Sigma_z \nabla_z \log p(z) E[μz∣z]=z+Σz∇zlogp(z)
在这种情况下,我们将其应用于预测已知样本的条件下 x t x_t xt的真实后验均值。从公式70中我们知道:
q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) q(x_t | x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t) \mathbf{I}) q(xt∣x0)=N(xt;αˉtx0,(1−αˉt)I)
然后,根据Tweedie公式,我们可以得到:
E [ μ x t ∣ x t ] = x t + ( 1 − α ˉ t ) ∇ x t log p ( x t ) (131) \mathbb{E}[\mu_{x_t} | x_t] = x_t + (1 - \bar{\alpha}_t) \nabla_{x_t} \log p(x_t) \tag{131} E[μxt∣xt]=xt+(1−αˉt)∇xtlogp(xt)(131)
整体目标和背景
我们要解决的问题是 优化变分扩散模型(VDM)。核心目标是:
- 学习一个模型 μ θ ( x t , t ) \mu_\theta(x_t, t) μθ(xt,t),使其能够逼近数据真实分布的均值 μ q ( x t , x 0 ) \mu_q(x_t, x_0) μq(xt,x0)。
- 为此,我们引入了 Tweedie公式,推导了 μ q ( x t , x 0 ) \mu_q(x_t, x_0) μq(xt,x0) 的新形式。
最终优化目标是最小化一个 KL 散度(衡量两个分布之间的差异)。
1. Tweedie公式是什么?
1.1 定义:
Tweedie公式是用于估计 真实均值 的一个工具。公式为: E [ μ z ∣ z ] = z + Σ z ∇ z log p ( z ) . \mathbb{E}[\mu_z | z] = z + \Sigma_z \nabla_z \log p(z). E[μz∣z]=z+Σz∇zlogp(z). 这里:
- z z z 是观察到的样本。
- μ z \mu_z μz 是真实均值(需要估计)。
- Σ z ∇ z log p ( z ) \Sigma_z \nabla_z \log p(z) Σz∇zlogp(z) 是修正项,用于纠正样本偏差。
- ∇ z log p ( z ) \nabla_z \log p(z) ∇zlogp(z):分布的对数梯度(也叫 得分函数,score function)。
1.2 作用 Tweedie公式的作用是:
- 在你观察到 z z z 时,估计真实的均值 μ z \mu_z μz。
- 它修正了直接使用样本均值的偏差。
1.3 直观例子 假设 z ∼ N ( μ , σ 2 ) z \sim \mathcal{N}(\mu, \sigma^2) z∼N(μ,σ2):
- 如果你观测到 z = 1.5 z = 1.5 z=1.5,真实均值可能不是 z z z 自身,而是 z + 修正项 z + \text{修正项} z+修正项。
- 修正项由样本的偏差方向决定。
2. 将 Tweedie公式应用于 q ( x t ∣ x 0 ) q(x_t|x_0) q(xt∣x0)
2.1 q ( x t ∣ x 0 ) q(x_t|x_0) q(xt∣x0) 的形式 在扩散模型中: q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) , q(x_t|x_0) = \mathcal{N}(x_t; \sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t) \mathbf{I}), q(xt∣x0)=N(xt;αˉtx0,(1−αˉt)I),
- 均值 μ q ( x t ∣ x 0 ) = α ˉ t x 0 \mu_q(x_t|x_0) = \sqrt{\bar{\alpha}_t} x_0 μq(xt∣x0)=αˉtx0。
- 协方差 Σ q = ( 1 − α ˉ t ) I \Sigma_q = (1 - \bar{\alpha}_t) \mathbf{I} Σq=(1−αˉt)I。
2.2 应用 Tweedie公式 根据 Tweedie公式: E [ μ x t ∣ x t ] = x t + Σ q ∇ x t log p ( x t ) . \mathbb{E}[\mu_{x_t}|x_t] = x_t + \Sigma_q \nabla_{x_t} \log p(x_t). E[μxt∣xt]=xt+Σq∇xtlogp(xt).
代入协方差 Σ q = ( 1 − α ˉ t ) I \Sigma_q = (1 - \bar{\alpha}_t) \mathbf{I} Σq=(1−αˉt)I: E [ μ x t ∣ x t ] = x t + ( 1 − α ˉ t ) ∇ x t log p ( x t ) . (131) \mathbb{E}[\mu_{x_t}|x_t] = x_t + (1 - \bar{\alpha}_t) \nabla_{x_t} \log p(x_t). \tag{131} E[μxt∣xt]=xt+(1−αˉt)∇xtlogp(xt).(131)
我们将 ∇ x t log p ( x t ) \nabla_{x_t}\log p(x_t) ∇xtlogp(xt)简写为 ∇ log p ( x t ) \nabla\log p(x_t) ∇logp(xt)以简化符号表示。根据Tweedie公式, x t x_t xt生成的真实均值的最佳估计 μ x t = α ˉ t x 0 \mu_{x_t}=\sqrt{\bar{\alpha}_t}x_0 μxt=αˉtx0定义为:
α ˉ t x 0 = x t + ( 1 − α ˉ t ) ∇ log p ( x t ) (132) \sqrt{\bar{\alpha}_t}x_0 = x_t + (1 - \bar{\alpha}_t)\nabla \log p(x_t) \tag{132} αˉtx0=xt+(1−αˉt)∇logp(xt)(132) ∴ x 0 = x t + ( 1 − α ˉ t ) ∇ log p ( x t ) α ˉ t (133) \therefore x_0 = \frac{x_t + (1 - \bar{\alpha}_t)\nabla \log p(x_t)}{\sqrt{\bar{\alpha}_t}} \tag{133} ∴x0=αˉtxt+(1−αˉt)∇logp(xt)(133)
3. 推导 x 0 x_0 x0 的表达式
3.1 利用均值关系 从 q ( x t ∣ x 0 ) q(x_t|x_0) q(xt∣x0) 的定义,知道均值为: μ q ( x t ∣ x 0 ) = α ˉ t x 0 . \mu_q(x_t|x_0) = \sqrt{\bar{\alpha}_t} x_0. μq(xt∣x0)=αˉtx0.
3.2 结合 Tweedie公式 根据公式 (131): α ˉ t x 0 = x t + ( 1 − α ˉ t ) ∇ x t log p ( x t ) . (132) \sqrt{\bar{\alpha}_t} x_0 = x_t + (1 - \bar{\alpha}_t) \nabla_{x_t} \log p(x_t). \tag{132} αˉtx0=xt+(1−αˉt)∇xtlogp(xt).(132)
3.3 解出 x 0 x_0 x0 整理公式,解出 x 0 x_0 x0: x 0 = x t + ( 1 − α ˉ t ) ∇ log p ( x t ) α ˉ t . (133) x_0 = \frac{x_t + (1 - \bar{\alpha}_t) \nabla \log p(x_t)}{\sqrt{\bar{\alpha}_t}}. \tag{133} x0=αˉtxt+(1−αˉt)∇logp(xt).(133)
然后,我们可以再次将公式(133)代入真实的去噪转移均值 μ q ( x t , x 0 ) \mu_q(x_t,x_0) μq(xt,x0)中,并推导出一种新的形式:
μ q ( x t , x 0 ) = α t ( 1 − α ˉ t − 1 ) x t + α ˉ t − 1 ( 1 − α t ) x 0 1 − α ˉ t ( 134 ) = α t ( 1 − α ˉ t − 1 ) x t + α ˉ t − 1 ( 1 − α t ) x t + ( 1 − α ˉ t ) ∇ log p ( x t ) α ˉ t 1 − α ˉ t ( 135 ) = α t ( 1 − α ˉ t − 1 ) x t + ( 1 − α t ) x t + ( 1 − α ˉ t ) ∇ log p ( x t ) α t 1 − α ˉ t ( 136 ) = α t ( 1 − α ˉ t − 1 ) x t 1 − α ˉ t + ( 1 − α t ) x t ( 1 − α ˉ t ) α t + ( 1 − α t ) ( 1 − α ˉ t ) ∇ log p ( x t ) ( 1 − α ˉ t ) α t ( 137 ) = ( α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t + ( 1 − α t ) ( 1 − α ˉ t ) α t ) x t + ( 1 − α t ) ∇ log p ( x t ) α t ( 138 ) = ( α t ( 1 − α ˉ t − 1 ) ( 1 − α ˉ t ) α t + ( 1 − α t ) ( 1 − α ˉ t ) α t ) x t + ( 1 − α t ) ∇ log p ( x t ) α t ( 139 ) = α t − α ˉ t + 1 − α t ( 1 − α ˉ t ) α t x t + ( 1 − α t ) ∇ log p ( x t ) α t ( 140 ) = 1 − α ˉ t ( 1 − α ˉ t ) α t x t + ( 1 − α t ) ∇ log p ( x t ) α t ( 141 ) = 1 α t x t + ( 1 − α t ) ∇ log p ( x t ) α t ( 142 ) \begin{aligned} \mu_q(x_t, x_0) &= \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})x_t + \sqrt{\bar{\alpha}_{t-1}}(1 - \alpha_t)x_0}{1 - \bar{\alpha}_t} & \quad & (134) \\ &= \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})x_t + \sqrt{\bar{\alpha}_{t-1}}(1 - \alpha_t)\frac{x_t + (1 - \bar{\alpha}_t)\nabla \log p(x_t)}{\sqrt{\bar{\alpha}_t}}}{1 - \bar{\alpha}_t} & \quad & (135) \\ &= \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})x_t+(1 - \alpha_t)\frac{x_t + (1 - \bar{\alpha}_t)\nabla \log p(x_t)}{\sqrt{{\alpha}_t}}}{1 - \bar{\alpha}_t} & \quad & (136) \\ &= \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})x_t}{1 - \bar{\alpha}_t} + \frac{(1 - \alpha_t)x_t}{(1 - \bar{\alpha}_t)\sqrt{\alpha_t}} + \frac{(1 - \alpha_t)(1 - \bar{\alpha}_t)\nabla \log p(x_t)}{(1 - \bar{\alpha}_t)\sqrt{\alpha_t}} & \quad & (137) \\ &= \Bigg(\frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} + \frac{(1 - \alpha_t)}{(1 - \bar{\alpha}_t)\sqrt{\alpha_t}} \Bigg)x_t+\frac{(1 - \alpha_t)\nabla \log p(x_t)}{\sqrt{\alpha_t}} & \quad & (138) \\ &= \Bigg(\frac{{\alpha_t}(1 - \bar{\alpha}_{t-1})}{(1 - \bar{\alpha}_t)\sqrt{\alpha_t}} + \frac{(1 - \alpha_t)}{(1 - \bar{\alpha}_t)\sqrt{\alpha_t}} \Bigg)x_t +\frac{(1 - \alpha_t)\nabla \log p(x_t)}{\sqrt{\alpha_t}} & \quad & (139) \\ &= \frac{{\alpha_t} - \bar{\alpha}_{t}+1-\alpha_t}{(1 - \bar{\alpha}_t)\sqrt{\alpha_t}} x_t +\frac{(1 - \alpha_t)\nabla \log p(x_t)}{\sqrt{\alpha_t}} & \quad & (140) \\ &= \frac{1 - \bar{\alpha}_{t}}{(1 - \bar{\alpha}_t)\sqrt{\alpha_t}} x_t + \frac{(1 - \alpha_t)\nabla \log p(x_t)}{\sqrt{\alpha_t}} & \quad & (141) \\ &= \frac{1}{\sqrt{\alpha_t}} x_t +\frac{(1 - \alpha_t)\nabla \log p(x_t)}{\sqrt{\alpha_t}} & \quad & (142) \end{aligned} μq(xt,x0)=1−αˉtαt(1−αˉt−1)xt+αˉt−1(1−αt)x0=1−αˉtαt(1−αˉt−1)xt+αˉt−1(1−αt)αˉtxt+(1−αˉt)∇logp(xt)=1−αˉtαt(1−αˉt−1)xt+(1−αt)αtxt+(1−αˉt)∇logp(xt)=1−αˉtαt(1−αˉt−1)xt+(1−αˉt)αt(1−αt)xt+(1−αˉt)αt(1−αt)(1−αˉt)∇logp(xt)=(1−αˉtαt(1−αˉt−1)+(1−αˉt)αt(1−αt))xt+αt(1−αt)∇logp(xt)=((1−αˉt)αtαt(1−αˉt−1)+(1−αˉt)αt(1−αt))xt+αt(1−αt)∇logp(xt)=(1−αˉt)αtαt−αˉt+1−αtxt+αt(1−αt)∇logp(xt)=(1−αˉt)αt1−αˉtxt+αt(1−αt)∇logp(xt)=αt1xt+αt(1−αt)∇logp(xt)(134)(135)(136)(137)(138)(139)(140)(141)(142)
4. 替换 μ q ( x t , x 0 ) \mu_q(x_t, x_0) μq(xt,x0)
4.1 原始公式 真实去噪均值公式为: μ q ( x t , x 0 ) = α t ( 1 − α ˉ t − 1 ) x t + α ˉ t − 1 ( 1 − α t ) x 0 1 − α ˉ t . (134) \mu_q(x_t, x_0) = \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1}) x_t + \sqrt{\bar{\alpha}_{t-1}}(1 - \alpha_t) x_0}{1 - \bar{\alpha}_t}. \tag{134} μq(xt,x0)=1−αˉtαt(1−αˉt−1)xt+αˉt−1(1−αt)x0.(134)
来自公式(93)
4.2 用 x 0 x_0 x0 替换 将 x 0 = x t + ( 1 − α ˉ t ) ∇ log p ( x t ) α ˉ t x_0 = \frac{x_t + (1 - \bar{\alpha}_t) \nabla \log p(x_t)}{\sqrt{\bar{\alpha}_t}} x0=αˉtxt+(1−αˉt)∇logp(xt) 代入公式 (134): μ q ( x t , x 0 ) = α t ( 1 − α ˉ t − 1 ) x t + α ˉ t − 1 ( 1 − α t ) x t + ( 1 − α ˉ t ) ∇ log p ( x t ) α ˉ t 1 − α ˉ t . \mu_q(x_t, x_0) = \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1}) x_t + \sqrt{\bar{\alpha}_{t-1}}(1 - \alpha_t) \frac{x_t + (1 - \bar{\alpha}_t) \nabla \log p(x_t)}{\sqrt{\bar{\alpha}_t}}}{1 - \bar{\alpha}_t}. μq(xt,x0)=1−αˉtαt(1−αˉt−1)xt+αˉt−1(1−αt)αˉtxt+(1−αˉt)∇logp(xt).
5. 化简公式
以下是化简的关键步骤:
- 展开分子:将 x t x_t xt 和梯度项分开。
- 提取系数:将 x t x_t xt 和 ∇ log p ( x t ) \nabla \log p(x_t) ∇logp(xt) 的系数分别写成一个整体。
最终得到的结果为: μ q ( x t , x 0 ) = 1 α t x t + 1 − α t α t ∇ log p ( x t ) . (142) \mu_q(x_t, x_0) = \frac{1}{\sqrt{\alpha_t}} x_t + \frac{1 - \alpha_t}{\sqrt{\alpha_t}} \nabla \log p(x_t). \tag{142} μq(xt,x0)=αt1xt+αt1−αt∇logp(xt).(142)
化简过程还是很简单的
因此,我们也可以将近似去噪转换的均值 μ θ ( x t , t ) \mu_\theta(x_t,t) μθ(xt,t)设置为:
μ θ ( x t , t ) = 1 α t x t + 1 − α t α t s θ ( x t , t ) (143) \mu_\theta(x_t,t) = \frac{1}{\sqrt{\alpha_t}}x_t + \frac{1-\alpha_t}{\sqrt{\alpha_t}}s_\theta(x_t,t) \tag{143} μθ(xt,t)=αt1xt+αt1−αtsθ(xt,t)(143)
对应的优化问题变为:
arg
min
θ
D
K
L
(
q
(
x
t
−
1
∣
x
t
,
x
0
)
∥
p
θ
(
x
t
−
1
∣
x
t
)
)
\underset{\theta}{\arg\min} \, D_{KL}(q(x_{t-1}|x_t, x_0) \, \| \, p_\theta(x_{t-1}|x_t)) \\
θargminDKL(q(xt−1∣xt,x0)∥pθ(xt−1∣xt))
=
arg
min
θ
D
K
L
(
N
(
x
t
−
1
;
μ
q
,
Σ
q
(
t
)
)
∥
N
(
x
t
−
1
;
μ
θ
,
Σ
q
(
t
)
)
)
(
144
)
=
arg
min
θ
1
2
σ
q
2
(
t
)
[
∥
1
α
t
x
t
+
1
−
α
t
α
t
s
θ
(
x
t
,
t
)
−
1
α
t
x
t
−
1
−
α
t
α
t
∇
log
p
(
x
t
)
∥
2
2
]
(
145
)
=
arg
min
θ
1
2
σ
q
2
(
t
)
[
∥
1
−
α
t
α
t
s
θ
(
x
t
,
t
)
−
1
−
α
t
α
t
∇
log
p
(
x
t
)
∥
2
2
]
(
146
)
=
arg
min
θ
1
2
σ
q
2
(
t
)
[
∥
(
1
−
α
t
)
α
t
(
s
θ
(
x
t
,
t
)
−
∇
log
p
(
x
t
)
)
∥
2
2
]
(
147
)
=
arg
min
θ
1
2
σ
q
2
(
t
)
(
1
−
α
t
)
2
α
t
[
∥
s
θ
(
x
t
,
t
)
−
∇
log
p
(
x
t
)
∥
2
2
]
(
148
)
\begin{aligned} &= \underset{\theta}{\arg\min} \, D_{KL}(\mathcal{N}(x_{t-1}; \mu_q, \Sigma_q(t)) \, \| \, \mathcal{N}(x_{t-1}; \mu_\theta, \Sigma_q(t))) & \quad & (144) \\ &= \underset{\theta}{\arg\min} \, \frac{1}{2\sigma_q^2(t)} \left[ \left\| \frac{1}{\sqrt{\alpha_t}}x_t + \frac{1-\alpha_t}{\sqrt{\alpha_t}}s_\theta(x_t, t) - \frac{1}{\sqrt{\alpha_t}}x_t - \frac{1-\alpha_t}{\sqrt{\alpha_t}}\nabla \log p(x_t) \right\|_2^2 \right] & \quad & (145) \\ &= \underset{\theta}{\arg\min} \, \frac{1}{2\sigma_q^2(t)} \left[ \left\| \frac{1-\alpha_t}{\sqrt{\alpha_t}} s_\theta(x_t, t) - \frac{1-\alpha_t}{\sqrt{\alpha_t}} \nabla \log p(x_t) \right\|_2^2 \right] & \quad & (146) \\ &= \underset{\theta}{\arg\min} \, \frac{1}{2\sigma_q^2(t)} \left[ \left\|\frac{(1-\alpha_t)}{\alpha_t} (s_\theta(x_t, t) - \nabla \log p(x_t) )\right\|_2^2 \right] & \quad & (147) \\ &= \underset{\theta}{\arg\min} \, \frac{1}{2\sigma_q^2(t)}\frac{(1-\alpha_t)^2}{\alpha_t } \left[ \left\| s_\theta(x_t, t) - \nabla \log p(x_t) \right\|_2^2 \right] & \quad & (148) \end{aligned}
=θargminDKL(N(xt−1;μq,Σq(t))∥N(xt−1;μθ,Σq(t)))=θargmin2σq2(t)1[
αt1xt+αt1−αtsθ(xt,t)−αt1xt−αt1−αt∇logp(xt)
22]=θargmin2σq2(t)1[
αt1−αtsθ(xt,t)−αt1−αt∇logp(xt)
22]=θargmin2σq2(t)1[
αt(1−αt)(sθ(xt,t)−∇logp(xt))
22]=θargmin2σq2(t)1αt(1−αt)2[∥sθ(xt,t)−∇logp(xt)∥22](144)(145)(146)(147)(148)
在这里, s θ ( x t , t ) s_\theta(x_t, t) sθ(xt,t) 是一个神经网络,用于学习预测分数函数 ∇ x t log p ( x t ) \nabla_{x_t} \log p(x_t) ∇xtlogp(xt),即数据空间中任意噪声水平 t t t 下的 x t x_t xt 的梯度。
敏锐的读者会注意到,分数函数 ∇ log p ( x t ) \nabla \log p(x_t) ∇logp(xt) 的形式与源噪声 ϵ 0 \epsilon_0 ϵ0 极为相似。这可以通过将 Tweedie 公式(方程 133)与重参数化技巧(方程 115)结合起来显式地证明出来。
x 0 = x t + ( 1 − α ˉ t ) ∇ log p ( x t ) α ˉ t = x t − 1 − α ˉ t ϵ 0 α ˉ t ( 149 ) ∴ ( 1 − α ˉ t ) ∇ log p ( x t ) = − 1 − α ˉ t ϵ 0 ( 150 ) ∇ log p ( x t ) = − 1 1 − α ˉ t ϵ 0 ( 151 ) \begin{aligned} x_0 = \frac{x_t + (1 - \bar{\alpha}_t)\nabla \log p(x_t)}{\sqrt{\bar{\alpha}_t}} & = \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\epsilon_0}{\sqrt{\bar{\alpha}_t}} \quad \quad (149) \\ \therefore (1 - \bar{\alpha}_t)\nabla \log p(x_t) &= -\sqrt{1 - \bar{\alpha}_t}\epsilon_0 \quad \quad (150) \\ \nabla \log p(x_t) &= -\frac{1}{\sqrt{1 - \bar{\alpha}_t}}\epsilon_0 \quad \quad (151) \end{aligned} x0=αˉtxt+(1−αˉt)∇logp(xt)∴(1−αˉt)∇logp(xt)∇logp(xt)=αˉtxt−1−αˉtϵ0(149)=−1−αˉtϵ0(150)=−1−αˉt1ϵ0(151)
6. 优化目标的推导
我们用神经网络 μ θ ( x t , t ) \mu_\theta(x_t, t) μθ(xt,t) 近似 μ q ( x t , x 0 ) \mu_q(x_t, x_0) μq(xt,x0): μ θ ( x t , t ) = 1 α t x t + 1 − α t α t s θ ( x t , t ) , \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} x_t + \frac{1 - \alpha_t}{\sqrt{\alpha_t}} s_\theta(x_t, t), μθ(xt,t)=αt1xt+αt1−αtsθ(xt,t), 其中 s θ ( x t , t ) s_\theta(x_t, t) sθ(xt,t) 是神经网络的输出。
优化目标是: arg min θ D K L ( q ( x t − 1 ∣ x t , x 0 ) ∥ p θ ( x t − 1 ∣ x t ) ) . \arg \min_\theta D_{KL}(q(x_{t-1}|x_t, x_0) \| p_\theta(x_{t-1}|x_t)). argθminDKL(q(xt−1∣xt,x0)∥pθ(xt−1∣xt)).
KL 散度的计算结果为: arg min θ 1 2 σ q 2 ( t ) ( 1 − α t ) 2 α t ∥ s θ ( x t , t ) − ∇ log p ( x t ) ∥ 2 2 . \arg \min_\theta \frac{1}{2\sigma_q^2(t)} \frac{(1 - \alpha_t)^2}{\alpha_t} \| s_\theta(x_t, t) - \nabla \log p(x_t) \|_2^2. argθmin2σq2(t)1αt(1−αt)2∥sθ(xt,t)−∇logp(xt)∥22.
简化目标 arg min θ ∥ s θ ( x t , t ) − ∇ log p ( x t ) ∥ 2 2 . \arg \min_\theta \| s_\theta(x_t, t) - \nabla \log p(x_t) \|_2^2. argθmin∥sθ(xt,t)−∇logp(xt)∥22.
7. 分数函数和噪声的关系
通过 Tweedie公式和扩散模型的关系,可以推导分数函数和噪声之间的关系: ∇ log p ( x t ) = − 1 1 − α ˉ t ϵ 0 . (151) \nabla \log p(x_t) = -\frac{1}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_0. \tag{151} ∇logp(xt)=−1−αˉt1ϵ0.(151)
这表明,学习分数函数 ∇ log p ( x t ) \nabla \log p(x_t) ∇logp(xt) 等价于学习噪声 ϵ 0 \epsilon_0 ϵ0 的负值。
8. 直观意义总结
- Tweedie公式:修正样本 x t x_t xt 偏差,估计其真实均值。
- 优化目标:学习神经网络 s θ ( x t , t ) s_\theta(x_t, t) sθ(xt,t),以近似分数函数或噪声。
- 最终模型:通过随机采样时间步 t t t,以可扩展方式优化神经网络。
事实证明,这两个项之间仅相差一个随时间缩放的常数因子!得分函数衡量了如何在数据空间中移动以最大化对数概率;直观上,由于源噪声被添加到自然图像中以对其进行破坏,沿着其相反方向移动可以“去噪”图像,并且这是增加后续对数概率的最佳更新方式。我们的数学证明证实了这一直觉;我们已经明确证明了,学习建模得分函数等价于建模源噪声的负值(乘以一个缩放因子)。
因此,我们得出了三种优化VDM的等效目标:训练一个神经网络预测原始图像 x 0 x_0 x0、源噪声 ϵ 0 \epsilon_0 ϵ0或任意噪声级别下图像的得分函数 ∇ log p ( x t ) \nabla \log p(x_t) ∇logp(xt)。通过随机采样时间步 t t t并最小化预测值与真实目标的范数,可以以可扩展的方式训练VDM。