DDIM扩散模型的加速采样(去噪)算法 Denoising Diffusion Implicit Models
DDIM扩散模型的加速采样算法 Denoising Diffusion Implicit Models
DDIM:发表于2021年ICLR,作者来自斯坦福大学。在使用DDPM进行目标检测的时候就结合使用了DDIM
因为时间和精力的原因,学习一下DDIM的主要原因是为了理解一下DiffusionDet的一个采样过程的原理,因此简单解释其中包含的核心原理,对论文的部分不进行详细的说明。
问题提出的背景
在之前学习DDPM的时候,自己在写文章推导DDPM的公式的时候也是主要推导了三个公式
- 加噪过程(推导出了一次加噪的公式)可以理解为是跳步加噪的过程
- 去噪过程(推导出了xt到xt-1的去噪过程公式)依次的通过神经网络进行逆推经过多次的去噪去还原x0也就是原始的图像了
- 损失函数过程:推导出最小化两个分布的KL散度。
对于下面的公式我们推导出了去噪的一个结论。
P ( x t − 1 ∣ x t ) P\left(x_{t-1} \mid x_{t}\right) P(xt−1∣xt)
x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) + σ t z x_{t-1}=\frac{1}{\sqrt{\alpha_{t}}}\left(x_{t}-\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \epsilon_{\theta}\left(x_{t}, t\right)\right)+\sigma_{t} z xt−1=αt1(xt−1−αˉt1−αtϵθ(xt,t))+σtz
很容易就可以想到是否也可以和加噪过程一样进行跳步的去噪推导呢?也就是直接由xt推导出xt-2甚至直接推导出来x0.
同样的我们可以得到:
P ( x t − 2 ∣ x t − 1 ) P\left(x_{t-2} \mid x_{t-1}\right) P(xt−2∣xt−1)
x t − 2 = 1 α t − 1 ( x t − 1 − 1 − α t − 1 1 − α ˉ t − 1 ϵ θ ( x t − 1 , t − 1 ) ) + σ t − 1 z x_{t-2}=\frac{1}{\sqrt{\alpha_{t-1}}}\left(x_{t-1}-\frac{1-\alpha_{t-1}}{\sqrt{1-\bar{\alpha}_{t-1}}} \epsilon_{\theta}\left(x_{t-1}, t-1\right)\right)+\sigma_{t-1} z xt−2=αt−11(xt−1−1−αˉt−11−αt−1ϵθ(xt−1,t−1))+σt−1z
尝试将xt-1进行带入:
x t − 2 = 1 α t − 1 ( 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) + σ t z − 1 − α t − 1 1 − α ˉ t − 1 ϵ θ ( x t − 1 , t − 1 ) ) + σ t − 1 z x_{t-2}=\frac{1}{\sqrt{\alpha_{t-1}}}\left(\frac{1}{\sqrt{\alpha_{t}}}\left(x_{t}-\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}} \epsilon_{\theta}\left(x_{t}, t\right)\right)+\sigma_{t} z-\frac{1-\alpha_{t-1}}{\sqrt{1-\bar{\alpha}_{t-1}}} \epsilon_{\theta}\left(x_{t-1}, t-1\right)\right)+\sigma_{t-1} z xt−2=αt−11(αt1(xt−1−αˉt1−αtϵθ(xt,t))+σtz−1−αˉt−11−αt−1ϵθ(xt−1,t−1))+σt−1z
由于跳过了求出 x t − 1 的步骤,那我们就没有 x t − 1 ,也就不能求出 ϵ θ ( x t − 1 , t − 1 ) \text { 由于跳过了求出 } x_{t-1} \text { 的步骤,那我们就没有 } x_{t-1} \text { ,也就不能求出 } \epsilon_{\theta}\left(x_{t-1}, t-1\right) 由于跳过了求出 xt−1 的步骤,那我们就没有 xt−1 ,也就不能求出 ϵθ(xt−1,t−1)
由此得出了论文的一个背景信息,如何想出一种算法来进行跳步的推导呢?
DDIM背后的数学原理
我们在ddpm的公式推导中,其实是使用了马尔科夫假设的:即当前时刻的状态只与前一个时刻的状态是有关系的。
我们的加噪和去噪的过程就可以表示为下面的形式了。
q ( x t ∣ x t − 1 , x t − 2 , ⋯ , x 0 ) = q ( x t ∣ x t − 1 ) P ( x t − 1 ∣ x t , x t + 1 , ⋯ , x T ) = P ( x t − 1 ∣ x t ) \begin{array}{l} q\left(x_{t} \mid x_{t-1}, x_{t-2}, \cdots, x_{0}\right)=q\left(x_{t} \mid x_{t-1}\right) \\ P\left(x_{t-1} \mid x_{t}, x_{t+1}, \cdots, x_{T}\right)=P\left(x_{t-1} \mid x_{t}\right) \end{array} q(xt∣xt−1,xt−2,⋯,x0)=q(xt∣xt−1)P(xt−1∣xt,xt+1,⋯,xT)=P(xt−1∣xt)
q ( x 1 : T ∣ x 0 ) = ∏ t = 1 T q ( x t ∣ x t − 1 ) P ( x 0 : T ) = P ( x T ) ∏ t = 1 T P ( x t − 1 ∣ x t ) \begin{array}{l} q\left(x_{1: T} \mid x_{0}\right)=\prod_{t=1}^{T} q\left(x_{t} \mid x_{t-1}\right) \\ P\left(x_{0: T}\right)=P\left(x_{T}\right) \prod_{t=1}^{T} P\left(x_{t-1} \mid x_{t}\right) \end{array} q(x1:T∣x0)=∏t=1Tq(xt∣xt−1)P(x0:T)=P(xT)∏t=1TP(xt−1∣xt)
对这些分解,我们可以理解成,因为有了马尔可夫假设这个"规则”,才可以将扩散和逆扩散链分解成这样那换句话说,是否存在某一种"规则”,可以将扩散链分解成这样(冒号等于表示将右边赋值给左边)
q ( x 1 : T ∣ x 0 ) : = q ( x T ∣ x 0 ) ∏ t = 2 T q ( x t − 1 ∣ x t , x 0 ) q\left(x_{1: T} \mid x_{0}\right):=q\left(x_{T} \mid x_{0}\right) \prod_{t=2}^{T} q\left(x_{t-1} \mid x_{t}, x_{0}\right) q(x1:T∣x0):=q(xT∣x0)t=2∏Tq(xt−1∣xt,x0)
跳步去噪(采样)的构造
- 为了简单起见,我们先假设现在只有9个时刻,t=0对应原始图像
- 对除了0的所有时刻, 我记为S ∈ {1,2,3, 4,5,6, 7,8}, 并且A ∈ {2,5,8}、B ∈ {1,3,4,6, 7}
- 也就是说,S包含了所有时刻,A和B是S的子集,分别包含了部分时刻,并且A+B=S
- 因此,我们假设有一个规则,可以将扩散链分解成这样:
q ( x 1 : 8 ∣ x 0 ) = q ( x 8 ∣ x 0 ) ∏ i ∈ A q ( x i − 1 ∣ x i , x 0 ) ∏ j ∈ B q ( x j ∣ x 0 ) q\left(x_{1: 8} \mid x_{0}\right)=q\left(x_{8} \mid x_{0}\right) \prod_{i \in A} q\left(x_{i-1} \mid x_{i}, x_{0}\right) \prod_{j \in B} q\left(x_{j} \mid x_{0}\right) q(x1:8∣x0)=q(x8∣x0)i∈A∏q(xi−1∣xi,x0)j∈B∏q(xj∣x0)
我们将上面的举例进行一个拓展的操作,就可以得出下面的结论
q ( x 1 : T ∣ x 0 ) = q ( x T ∣ x 0 ) ∏ i ∈ A q ( x i − 1 ∣ x i , x 0 ) ∏ j ∈ B q ( x j ∣ x 0 ) q\left(x_{1: T} \mid x_{0}\right)=q\left(x_{T} \mid x_{0}\right) \prod_{i \in A} q\left(x_{i-1} \mid x_{i}, x_{0}\right) \prod_{j \in B} q\left(x_{j} \mid x_{0}\right) q(x1:T∣x0)=q(xT∣x0)i∈A∏q(xi−1∣xi,x0)j∈B∏q(xj∣x0)
其中A不再代表是{2,5,8},而是代表一大串跳步的序列。B也相应的代表其补足
根据上面的说明我们的加噪过程其实也是可以进行分解的。
P ( x 0 : T ) = P ( x T ) ∏ i ∈ A P ( x i − 1 ∣ x i ) ∏ j ∈ B P ( x 0 ∣ x j ) P\left(x_{0: T}\right)=P\left(x_{T}\right) \prod_{i \in A} P\left(x_{i-1} \mid x_{i}\right) \prod_{j \in B} P\left(x_{0} \mid x_{j}\right) P(x0:T)=P(xT)i∈A∏P(xi−1∣xi)j∈B∏P(x0∣xj)
在变分下界之前,DDIM的所有步骤都和DDPM的一样。那么我就从变分下界开始推了
过程推导
我们就从求解损失函数求期望的那一步开始来进行推导和说明。将之前的分解带入来进行化简和说明。
log P ( x 0 ) ≥ ∫ log P ( x 0 : T ) q ( x 1 : T ∣ x 0 ) q ( x 1 : T ∣ x 0 ) d x 1 : T = E q [ log P ( x 0 : T ) q ( x 1 : T ∣ x 0 ) ] = E q [ log P ( x T ) ∏ i ∈ A P ( x i − 1 ∣ x i ) ∏ j ∈ B P ( x 0 ∣ x j ) q ( x T ∣ x 0 ) ∏ i ∈ A q ( x i − 1 ∣ x i , x 0 ) ∏ j ∈ B q ( x j ∣ x 0 ) ] \begin{aligned} \log P\left(x_{0}\right) & \geq \int \log \frac{P\left(x_{0: T}\right)}{q\left(x_{1: T} \mid x_{0}\right)} q\left(x_{1: T} \mid x_{0}\right) d x_{1: T} \\ & =\mathbb{E}_{q}\left[\log \frac{P\left(x_{0: T}\right)}{q\left(x_{1: T} \mid x_{0}\right)}\right] \\ & =\mathbb{E}_{q}\left[\log \frac{P\left(x_{T}\right) \prod_{i \in A} P\left(x_{i-1} \mid x_{i}\right) \prod_{j \in B} P\left(x_{0} \mid x_{j}\right)}{q\left(x_{T} \mid x_{0}\right) \prod_{i \in A} q\left(x_{i-1} \mid x_{i}, x_{0}\right) \prod_{j \in B} q\left(x_{j} \mid x_{0}\right)}\right] \end{aligned} logP(x0)≥∫logq(x1:T∣x0)P(x0:T)q(x1:T∣x0)dx1:T=Eq[logq(x1:T∣x0)P(x0:T)]=Eq[logq(xT∣x0)∏i∈Aq(xi−1∣xi,x0)∏j∈Bq(xj∣x0)P(xT)∏i∈AP(xi−1∣xi)∏j∈BP(x0∣xj)]
= E q [ log P ( x T ) q ( x T ∣ x 0 ) + log ∏ i ∈ A P ( x i − 1 ∣ x i ) q ( x i − 1 ∣ x i , x 0 ) + log ∏ j ∈ B P ( x 0 ∣ x j ) q ( x j ∣ x 0 ) ] = E q [ log P ( x T ) q ( x T ∣ x 0 ) + ∑ i ∈ A log P ( x i − 1 ∣ x i ) q ( x i − 1 ∣ x i , x 0 ) + ∑ j ∈ B log P ( x 0 ∣ x j ) q ( x j ∣ x 0 ) ] = E q ( x T ∣ x 0 ) [ log P ( x T ) q ( x T ∣ x 0 ) ] + ∑ i ∈ A E q ( x i − 1 , x i ∣ x 0 ) [ log P ( x i − 1 ∣ x i ) q ( x i − 1 ∣ x i , x 0 ) ] + ∑ j ∈ B E q ( x j ∣ x 0 ) [ log P ( x 0 ∣ x j ) q ( x j ∣ x 0 ) ] = − K L ( q ( x T ∣ x 0 ) ∥ P ( x T ) ) − ∑ i ∈ A E q ( x i ∣ x 0 ) [ K L ( q ( x i − 1 ∣ x i , x 0 ) ∣ ∣ P ( x i − 1 ∣ x i ) ) ] − ∑ j ∈ B K L ( q ( x j ∣ x 0 ) ∥ P ( x 0 ∣ x j ) ) \begin{array}{l} =\mathbb{E}_{q}\left[\log \frac{P\left(x_{T}\right)}{q\left(x_{T} \mid x_{0}\right)}+\log \prod_{i \in A} \frac{P\left(x_{i-1} \mid x_{i}\right)}{q\left(x_{i-1} \mid x_{i}, x_{0}\right)}+\log \prod_{j \in B} \frac{P\left(x_{0} \mid x_{j}\right)}{q\left(x_{j} \mid x_{0}\right)}\right] \\ =\mathbb{E}_{q}\left[\log \frac{P\left(x_{T}\right)}{q\left(x_{T} \mid x_{0}\right)}+\sum_{i \in A} \log \frac{P\left(x_{i-1} \mid x_{i}\right)}{q\left(x_{i-1} \mid x_{i}, x_{0}\right)}+\sum_{j \in B} \log \frac{P\left(x_{0} \mid x_{j}\right)}{q\left(x_{j} \mid x_{0}\right)}\right] \\ =\mathbb{E}_{q\left(x_{T} \mid x_{0}\right)}\left[\log \frac{P\left(x_{T}\right)}{q\left(x_{T} \mid x_{0}\right)}\right]+\sum_{i \in A} \mathbb{E}_{q\left(x_{i-1}, x_{i} \mid x_{0}\right)}\left[\log \frac{P\left(x_{i-1} \mid x_{i}\right)}{q\left(x_{i-1} \mid x_{i}, x_{0}\right)}\right]+\sum_{j \in B} \mathbb{E}_{q\left(x_{j} \mid x_{0}\right)}\left[\log \frac{P\left(x_{0} \mid x_{j}\right)}{q\left(x_{j} \mid x_{0}\right)}\right] \\ =-K L\left(q\left(x_{T} \mid x_{0}\right) \| P\left(x_{T}\right)\right)-\sum_{i \in A} \mathbb{E}_{\mathbf{q}\left(x_{i} \mid x_{0}\right)}\left[K L\left(q\left(x_{i-1} \mid x_{i}, x_{0}\right)| | P\left(x_{i-1} \mid x_{i}\right)\right)\right]-\sum_{j \in B} K L\left(q\left(x_{j} \mid x_{0}\right) \| P\left(x_{0} \mid x_{j}\right)\right) \end{array} =Eq[logq(xT∣x0)P(xT)+log∏i∈Aq(xi−1∣xi,x0)P(xi−1∣xi)+log∏j∈Bq(xj∣x0)P(x0∣xj)]=Eq[logq(xT∣x0)P(xT)+∑i∈Alogq(xi−1∣xi,x0)P(xi−1∣xi)+∑j∈Blogq(xj∣x0)P(x0∣xj)]=Eq(xT∣x0)[logq(xT∣x0)P(xT)]+∑i∈AEq(xi−1,xi∣x0)[logq(xi−1∣xi,x0)P(xi−1∣xi)]+∑j∈BEq(xj∣x0)[logq(xj∣x0)P(x0∣xj)]=−KL(q(xT∣x0)∥P(xT))−∑i∈AEq(xi∣x0)[KL(q(xi−1∣xi,x0)∣∣P(xi−1∣xi))]−∑j∈BKL(q(xj∣x0)∥P(x0∣xj))
最后就转化为了三个KL散度的表达形式了。由于第一项是可以直接计算出来的因此我们来看第二项和第三项的结果。
max ( − ∑ i ∈ A E q ( x i ∣ x 0 ) [ K L ( q ( x i − 1 ∣ x i , x 0 ) ∥ P ( x i − 1 ∣ x i ) ) ] − ∑ j ∈ B K L ( q ( x j ∣ x 0 ) ∥ P ( x 0 ∣ x j ) ) ) \max \left(-\sum_{i \in A} \mathbb{E}_{\mathbf{q}_{\left(x_{i} \mid x_{0}\right)}}\left[K L\left(q\left(x_{i-1} \mid x_{i}, x_{0}\right) \| P\left(x_{i-1} \mid x_{i}\right)\right)\right]-\sum_{j \in B} K L\left(q\left(x_{j} \mid x_{0}\right) \| P\left(x_{0} \mid x_{j}\right)\right)\right) max −i∈A∑Eq(xi∣x0)[KL(q(xi−1∣xi,x0)∥P(xi−1∣xi))]−j∈B∑KL(q(xj∣x0)∥P(x0∣xj))
使用贝叶斯公式来进行展开运算就可以得到:
q ( x i − 1 ∣ x i , x 0 ) = q ( x i ∣ x i − 1 , x 0 ) q ( x i − 1 ∣ x 0 ) q ( x i ∣ x 0 ) = q ( x i ∣ x i − 1 ) q ( x i − 1 ∣ x 0 ) q ( x i ∣ x 0 ) q\left(x_{i-1} \mid x_{i}, x_{0}\right)=\frac{q\left(x_{i} \mid x_{i-1}, x_{0}\right) q\left(x_{i-1} \mid x_{0}\right)}{q\left(x_{i} \mid x_{0}\right)}=\frac{q\left(x_{i} \mid x_{i-1}\right) q\left(x_{i-1} \mid x_{0}\right)}{q\left(x_{i} \mid x_{0}\right)} q(xi−1∣xi,x0)=q(xi∣x0)q(xi∣xi−1,x0)q(xi−1∣x0)=q(xi∣x0)q(xi∣xi−1)q(xi−1∣x0)
只要求出这种离散形式的加噪表达式就证明了其合理性。我们自己假设
q ( x i − 1 ∣ x i , x 0 ) = N ( x i − 1 ∣ k i x i + λ i x 0 , σ i 2 I ) q\left(x_{i-1} \mid x_{i}, x_{0}\right)=N\left(x_{i-1} \mid k_{i} x_{i}+\lambda_{i} x_{0}, \sigma_{i}^{2} I\right) q(xi−1∣xi,x0)=N(xi−1∣kixi+λix0,σi2I)
将中间的过程进行省略我们给出最后求得得一个结果:
q ( x i − 1 ∣ x 0 ) ∼ N ( x i − 1 ∣ ( k i α ˉ i + λ i ) x 0 , ( k i 2 ( 1 − α ˉ i ) + σ i 2 ) I ) q ( x i − 1 ∣ x 0 ) ∼ N ( x i − 1 ∣ α ˉ i − 1 x 0 , ( 1 − α ˉ i − 1 ) I ) \begin{array}{l} q\left(x_{i-1} \mid x_{0}\right) \sim N\left(x_{i-1} \mid\left(k_{i} \sqrt{\bar{\alpha}_{i}}+\lambda_{i}\right) x_{0},\left(k_{i}^{2}\left(1-\bar{\alpha}_{i}\right)+\sigma_{i}^{2}\right) I\right) \\ q\left(x_{i-1} \mid x_{0}\right) \sim N\left(x_{i-1} \mid \sqrt{\bar{\alpha}_{i-1}} x_{0},\left(1-\bar{\alpha}_{i-1}\right) I\right) \end{array} q(xi−1∣x0)∼N(xi−1∣(kiαˉi+λi)x0,(ki2(1−αˉi)+σi2)I)q(xi−1∣x0)∼N(xi−1∣αˉi−1x0,(1−αˉi−1)I)
q ( x i − 1 ∣ x i , x 0 ) ∼ N ( x i − 1 ∣ α ˉ i − 1 x i − 1 − α ˉ i ϵ i α ˉ i + 1 − α ˉ i − 1 − σ i 2 ϵ i , σ i 2 I ) q\left(x_{i-1} \mid x_{i}, x_{0}\right) \sim N\left(x_{i-1} \left\lvert\, \sqrt{\bar{\alpha}_{i-1}} \frac{x_{i}-\sqrt{1-\bar{\alpha}_{i}} \epsilon_{i}}{\sqrt{\bar{\alpha}_{i}}}+\sqrt{1-\bar{\alpha}_{i-1}-\sigma_{i}^{2}} \epsilon_{i}\right., \sigma_{i}^{2} I\right) q(xi−1∣xi,x0)∼N(xi−1 αˉi−1αˉixi−1−αˉiϵi+1−αˉi−1−σi2ϵi,σi2I)
有了上面的概率分布,我们就可以采样了:
x s = α ˉ s x 0 + 1 − α ˉ s − σ 2 x k − α ˉ k x 0 1 − α ˉ k + σ ε x_{s}=\sqrt{\bar{\alpha}_{s}} x_{0}+\sqrt{1-\bar{\alpha}_{s}-\sigma^{2}} \frac{x_{k}-\sqrt{\bar{\alpha}_{k}} x_{0}}{\sqrt{1-\bar{\alpha}_{k}}}+\sigma \varepsilon xs=αˉsx0+1−αˉs−σ21−αˉkxk−αˉkx0+σε
最后我们给出网上得一个总结来总结一下他所提出得这个算法: