论文阅读笔记:Denoising Diffusion Probabilistic Models (2)
接论文阅读笔记:Denoising Diffusion Probabilistic Models (1)
3、论文推理过程
扩散模型的流程如下图所示,可以看出 $q(x_{0:T})$ 为正向加噪过程,$p_{\theta}(x_{0:T})$ 为逆向去噪过程。可以看出,逆向去噪的末端得到的图上还散布着一些噪点。
3.1、一些词的理解
$q(x_0)$:以MNIST数据集为例,$x_0$ 表示MNIST数据集中的图像,而 $q(x_0)$ 就表示MNIST数据集中图像的分布情况。
$q(x_T)$:$x_T$ 为正向加噪过程的终点图像,其分布满足 $q(x_T)\sim N(\sqrt{\bar{\alpha}_T} \cdot x_{0},\ 1-\bar{\alpha}_T)$。
$p(x_T)$:$x_T$ 是逆向去噪过程的起点,其对应的分布 $p(x_T)$ 为标准正态分布,即 $p(x_T)\sim N(0,1)$。
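下面用一段NumPy小实验直观验证这一点(示意代码:一维"像素",beta 采用论文中的线性调度)。当 $T$ 足够大时,$\sqrt{\bar{\alpha}_T}\to 0$、$1-\bar{\alpha}_T\to 1$,即 $q(x_T)$ 趋近于标准正态分布:

```python
import numpy as np

# DDPM论文中的线性beta调度:T=1000,beta从1e-4线性增长到0.02
T = 1000
betas = np.linspace(1e-4, 0.02, T)
abar = np.cumprod(1.0 - betas)           # \bar{\alpha}_t

# 闭式采样 q(x_T|x_0):x_T = sqrt(abar_T)*x_0 + sqrt(1-abar_T)*eps
x0 = 0.8                                 # 假设的一维"像素"取值
eps = np.random.randn(100_000)
xT = np.sqrt(abar[-1]) * x0 + np.sqrt(1.0 - abar[-1]) * eps

print(np.sqrt(abar[-1]))                 # 一个非常接近0的数,x_0的信息几乎被抹掉
print(xT.mean(), xT.std())               # 均值≈0、标准差≈1,即近似N(0,1)
```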
3.2、推理过程
正向加噪过程满足马尔可夫性质,因此有公式(1)。
\begin{equation}
\begin{split}
q(x_{0:T})&=q(x_0)\cdot \prod_{t=1}^{T}{q(x_t|x_{t-1})} \\
&=q(x_0)\cdot q(x_1|x_0)\cdot q(x_2|x_1)\cdots q(x_T|x_{T-1}) \\
q(x_{1:T}|x_0)&=q(x_1|x_0)\cdot q(x_2|x_1)\cdots q(x_T|x_{T-1})
\end{split}
\end{equation}
逆向去噪过程如公式(2)。
\begin{equation}
\begin{split}
p_{\theta}(x_{0:T})&=p(x_T)\cdot \prod_{t=1}^{T}{p_{\theta}(x_{t-1}|x_{t})} \\
&=p(x_T)\cdot p_{\theta}(x_{T-1}|x_T)\cdot p_{\theta}(x_{T-2}|x_{T-1})\cdots p_{\theta}(x_{0}|x_{1}).
\end{split}
\end{equation}
逆向去噪的目标是使得其终点与正向加噪的起点相同,也就是使得 $p_\theta(x_0)$ 最大,即逆向去噪过程的终点为 $x_0$ 的概率最大。
\begin{equation}
\begin{split}
p_{\theta}(x_0)&=\int p_{\theta}(x_0,x_1)\,dx_{1} \quad (\text{联合分布概率公式})\\
&=\int p_{\theta}(x_1)\cdot p_{\theta}(x_0|x_1)\,dx_1 \quad (\text{贝叶斯概率公式}) \\
&=\int \Big(\int p_{\theta}(x_1,x_2)\,dx_2 \Big) \cdot p_{\theta}(x_0|x_1)\,dx_1 \quad (\text{积分套积分})\\
&=\iint p_{\theta}(x_2)\cdot p_{\theta}(x_1|x_2) \cdot p_{\theta}(x_0|x_1)\,dx_1 dx_2 \quad (\text{改写为二重积分})\\
&\quad\vdots \\
&= \int\cdots\int p(x_T)\cdot p_{\theta}(x_{T-1}|x_{T})\cdot p_{\theta}(x_{T-2}|x_{T-1})\cdots p_{\theta}(x_0|x_1) \,dx_1 dx_2 \cdots dx_T \\
&= \int p_{\theta}(x_{0:T})\,dx_{1:T} \quad (T\text{重积分,其实可以直接一步写到这里}) \\
&= \int dx_{1:T} \cdot p_{\theta}(x_{0:T}) \cdot \frac{q(x_{1:T} | x_0)}{q(x_{1:T}|x_0)} \\
&= \int dx_{1:T} \cdot q(x_{1:T} | x_0) \cdot \frac{ p_{\theta}(x_{0:T}) }{q(x_{1:T}|x_0)} \\
&= \int dx_{1:T} \cdot q(x_{1:T} | x_0) \cdot \frac{p(x_T)\cdot p_{\theta}(x_{T-1}|x_T)\cdot p_{\theta}(x_{T-2}|x_{T-1})\cdots p_{\theta}(x_{0}|x_{1})}{q(x_1|x_0)\cdot q(x_2|x_1)\cdots q(x_T|x_{T-1})} \\
&= \int dx_{1:T} \cdot q(x_{1:T}| x_0) \cdot p(x_T)\cdot \prod_{t=1}^{T} \frac{ p_{\theta}(x_{t-1}|x_t)}{q(x_t|x_{t-1})} \\
&= E_{x_{1:T} \sim q(x_{1:T} | x_0)}\bigg[ p(x_T)\cdot \prod_{t=1}^{T} \frac{ p_{\theta}(x_{t-1}|x_t)}{q(x_t|x_{t-1})}\bigg] \quad (\text{改写为期望的形式})
\end{split}
\end{equation}
因此公式(3)中的参数 $\theta$ 应满足公式(4):
\begin{equation}
\theta=\underset{\theta}{\arg\max}\; p_{\theta}(x_0).
\end{equation}
公式(4)只对数据集中的一张图片进行求解,然而数据集中通常有成千上万张图像。假设数据集中有 $N$ 张图像,此时的目标是求得一组参数 $\theta$,使得整个数据集上的期望对数似然取得最大值,即公式(5):
\begin{equation}
\theta=\underset{\theta}{\arg\max}\; E_{x_0 \sim q(x_0)}\Big[\log p_{\theta}(x_0)\Big]
\end{equation}
值得注意的是,$q(x_0)$ 表示数据集中每张图片被采样出来的概率。
为了防止边缘效应,在本文中令 $p(x_1|x_0)=q(x_1|x_0)$。下面对单张图像的负对数似然 $L$ 进行推导,得到公式(6)。
\begin{equation}
\begin{split}
L&:=-\log\Big[p(x_0)\Big] \\
&= -\log \Big[ E_{x_{1:T} \sim q(x_{1:T} | x_0)}\, p(x_T)\cdot \prod_{t=1}^{T} \frac{ p(x_{t-1}|x_t)}{q(x_t|x_{t-1})}\Big] \\
& \leq -E_{x_{1:T} \sim q(x_{1:T} | x_0)} \bigg( \log \Big[p(x_T)\cdot \prod_{t=1}^{T} \frac{ p(x_{t-1}|x_t)}{q(x_t|x_{t-1})}\Big]\bigg) \quad (\text{Jensen不等式})\\
&= -E_{x_{1:T} \sim q(x_{1:T} | x_0)} \bigg( \log \big[p(x_T)\big]+\sum_{t=1}^{T} \log \Big[ \frac{ p(x_{t-1}|x_t)}{q(x_t|x_{t-1})}\Big]\bigg)\\
&= -E_{x_{1:T} \sim q(x_{1:T} | x_0)} \bigg( \log \big[p(x_T)\big]+ \log\Big[\frac{ p(x_{0}|x_1)}{q(x_1|x_{0})} \Big]+\sum_{t=2}^{T} \log \Big[ \frac{ p(x_{t-1}|x_t)}{q(x_t|x_{t-1})}\Big] \bigg)\\
&= -E_{x_{1:T} \sim q(x_{1:T} | x_0)} \bigg( \log \big[p(x_T)\big]+ \log\Big[\frac{ p(x_{0}|x_1)}{\underbrace{p(x_1|x_{0})}_{p(x_1|x_0)=q(x_1|x_0)}} \Big]+\sum_{t=2}^{T} \log \Big[\underbrace{ \frac{ p(x_{t-1}|x_t)}{q(x_t|x_{t-1},x_0)}}_{\text{马尔可夫性:}q(x_t|x_{t-1})=q(x_t|x_{t-1},x_0)}\Big] \bigg)\\
&= -E_{x_{1:T} \sim q(x_{1:T} | x_0)} \Bigg( \log \big[p(x_T)\big]+ \log\Big[\frac{ p(x_{0}|x_1)}{p(x_1|x_{0})} \Big]+\sum_{t=2}^{T} \log \Big[\underbrace{ \frac{ p(x_{t-1}|x_t)}{q(x_t,x_{t-1},x_0)} \cdot q(x_{t-1}, x_0) \cdot \frac{q(x_0)}{q(x_0)}\cdot \frac{q(x_t,x_0)}{q(x_t,x_0)}}_{ q(x_t|x_{t-1},x_0)=\frac{q(x_t,x_{t-1},x_0)}{q(x_{t-1},x_0)}}\Big] \Bigg)\\
&= -E_{x_{1:T} \sim q(x_{1:T} | x_0)} \Bigg( \log \big[p(x_T)\big]+ \log\Big[\frac{ p(x_{0}|x_1)}{p(x_1|x_{0})} \Big]+\sum_{t=2}^{T} \log \Big[\underbrace{ \frac{ p(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)} \cdot \frac{q(x_{t-1}, x_0) }{q(x_0)}\cdot \frac{ q(x_0)}{q(x_t,x_0)}}_{q(x_t,x_{t-1},x_0)= q(x_t,x_0) \cdot q(x_{t-1}|x_t,x_0)}\Big] \Bigg)\\
&= -E_{x_{1:T} \sim q(x_{1:T} | x_0)} \Bigg( \log \big[p(x_T)\big]+ \log\Big[\frac{ p(x_{0}|x_1)}{p(x_1|x_{0})} \Big]+\sum_{t=2}^{T} \log \Big[\underbrace{ \frac{ p(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)} \cdot \frac{q(x_{t-1}| x_0) }{q(x_{t}|x_0)}}_{q(x_{t-1},x_0)=q(x_0) \cdot q(x_{t-1}|x_0);\ q(x_{t},x_0)=q(x_0) \cdot q(x_{t}|x_0)}\Big] \Bigg)\\
&= -E_{x_{1:T} \sim q(x_{1:T} | x_0)} \Bigg( \log \big[p(x_T)\big]+ \log\Big[\frac{ p(x_{0}|x_1)}{p(x_1|x_{0})} \Big]+\sum_{t=2}^{T} \log \Big[\frac{ p(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)} \Big] + \sum_{t=2}^{T} \log \Big[\frac{q(x_{t-1}| x_0) }{q(x_{t}|x_0)}\Big] \Bigg)\\
&= -E_{x_{1:T} \sim q(x_{1:T} | x_0)} \Bigg( \log \big[p(x_T)\big]+ \log\Big[\frac{ p(x_{0}|x_1)}{p(x_1|x_{0})} \Big]+\sum_{t=2}^{T} \log \Big[\frac{ p(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)} \Big] + \log \Big[\frac{q(x_{1}| x_0) }{q(x_{2}|x_0)} \cdot \frac{q(x_{2}| x_0) }{q(x_{3}|x_0)}\cdots \frac{q(x_{T-1}| x_0) }{q(x_{T}|x_0)}\Big] \Bigg)\\
&= -E_{x_{1:T} \sim q(x_{1:T} | x_0)} \Bigg( \log \big[p(x_T)\big]+ \log\Big[\frac{ p(x_{0}|x_1)}{p(x_1|x_{0})} \Big]+\sum_{t=2}^{T} \log \Big[\frac{ p(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)} \Big] + \log \Big[\frac{q(x_{1}| x_0) }{q(x_{T}|x_0)}\Big] \Bigg)\\
&= -E_{x_{1:T} \sim q(x_{1:T} | x_0)} \Bigg(\underbrace{\log \Big[\frac{p(x_T)}{q(x_{T}|x_0)}\Big]+\sum_{t=2}^{T} \log \Big[\frac{ p(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)} \Big] + \log \Big[p(x_{0}|x_1)\Big] }_{\text{利用 } p(x_1|x_0)=q(x_1|x_0)\text{ 约去:}\log[p(x_T)]+\log\big[\frac{p(x_0|x_1)}{p(x_1|x_0)}\big]+\log\big[\frac{q(x_1|x_0)}{q(x_T|x_0)}\big]=\log\big[\frac{p(x_T)\cdot p(x_0|x_1)}{q(x_T|x_0)}\big]}\Bigg)\\
&= -E_{x_{1:T} \sim q(x_{1:T} | x_0)} \Bigg(\log \Big[ \frac{p(x_T)}{q(x_{T}|x_0)}\Big]\Bigg)-E_{x_{1:T} \sim q(x_{1:T} | x_0)} \Bigg(\sum_{t=2}^{T} \log \Big[\frac{ p(x_{t-1}|x_t)}{q(x_{t-1}|x_t,x_0)} \Big]\Bigg) - E_{x_{1:T} \sim q(x_{1:T} | x_0)} \Bigg( \log \Big[p(x_{0}|x_1)\Big] \Bigg)\\
&= E_{x_{1:T} \sim q(x_{1:T} | x_0)} \Bigg(\log \Big[ \frac{q(x_{T}|x_0)}{ p(x_T)}\Big]\Bigg)+E_{x_{1:T} \sim q(x_{1:T} | x_0)} \Bigg(\sum_{t=2}^{T} \log \Big[\frac{q(x_{t-1}|x_t,x_0)}{ p(x_{t-1}|x_t)} \Big]\Bigg) - E_{x_{1:T} \sim q(x_{1:T} | x_0)} \Bigg( \log \Big[p(x_{0}|x_1)\Big] \Bigg)\\
&= \underbrace{E_{x_{1:T} \sim q(x_{1:T} | x_0)} \Bigg(\log \Big[ \frac{q(x_{T}|x_0)}{ p(x_T)}\Big]\Bigg)}_{L_1}+\underbrace{E_{x_{1:T} \sim q(x_{1:T} | x_0)} \Bigg(\sum_{t=2}^{T} \log \Big[\frac{q(x_{t-1}|x_t,x_0)}{ p(x_{t-1}|x_t)} \Big]\Bigg)}_{L_2} - \underbrace{\log \Big[p(x_{0}|x_1)\Big]}_{L_3:\text{常数么?}}
\end{split}
\end{equation}
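推导第三行的不等号用到了Jensen不等式:由于 $-\log$ 是凸函数,有 $-\log E[X]\le -E[\log X]$。下面用一段NumPy小代码做个数值验证(随机变量 $X$ 的分布是随意假设的):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(0.1, 2.0, size=1_000_000)   # 任意的正随机变量X

lhs = -np.log(x.mean())                     # -log E[X]
rhs = -np.log(x).mean()                     # -E[log X]
print(lhs, rhs, lhs <= rhs)                 # 恒有 lhs <= rhs
```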
可以看出,$L$ 总共分为3项。首先考虑第一项 $L_1$。
\begin{equation}
\begin{split}
L_1&=E_{x_{1:T} \sim q(x_{1:T} | x_0)} \Bigg(\log \Big[ \frac{q(x_{T}|x_0)}{ p(x_T)}\Big]\Bigg) \\
&=\int dx_{1:T} \cdot q(x_{1:T}| x_0) \cdot \log \Big[ \frac{q(x_{T}|x_0)}{ p(x_T)}\Big] \\
&=\int dx_{1:T} \cdot \frac{q(x_{1:T}| x_0)}{q(x_T|x_0)} \cdot q(x_T|x_0) \cdot \log \Big[ \frac{q(x_{T}|x_0)}{ p(x_T)}\Big] \\
&=\int dx_{1:T} \cdot \underbrace{ q(x_{1:T-1}| x_0, x_T) }_{q(x_{1:T}| x_0)=q(x_{T}|x_0) \cdot q(x_{1:T-1}| x_0, x_T)} \cdot q(x_T|x_0) \cdot \log \Big[ \frac{q(x_{T}|x_0)}{ p(x_T)}\Big] \\
&=\int \Bigg( \underbrace{ \int q(x_{1:T-1}| x_0, x_T) \prod_{k=1}^{T-1} dx_k }_{\text{内层积分}=1} \Bigg) \cdot q(x_T|x_0) \cdot \log \Big[ \frac{q(x_{T}|x_0)}{ p(x_T)} \Big] \, dx_{T} \\
&=\int q(x_T|x_0) \cdot \log \Big[ \frac{q(x_{T}|x_0)}{ p(x_T)} \Big] \, dx_{T} \\
&=E_{x_T\sim q(x_T|x_0)} \log \Big[ \frac{q(x_{T}|x_0)}{ p(x_T)} \Big]\\
&= KL\Big(q(x_T|x_0)||p(x_T)\Big)
\end{split}
\end{equation}
可以看出,$L_1$ 是 $q(x_T|x_0)$ 和 $p(x_T)$ 的KL散度。$q(x_T|x_0)$ 是前向加噪过程的终点,是一个固定的分布;而 $p(x_T)$ 是高斯分布,这在论文《Denoising Diffusion Probabilistic Models》中 2 Background 的第四行中有说明。由两个高斯分布KL散度的推导可以直接计算出 $L_1$,也就是说 $L_1$ 是一个定值。因此,在损失函数中 $L_1$ 可以被忽略掉。
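作为佐证,下面用一维高斯KL的闭式解直接算一下 $L_1$ 的量级(示意代码,beta 调度与 $x_0$ 取值均为假设):$q(x_T|x_0)=N(\sqrt{\bar{\alpha}_T}x_0,\,1-\bar{\alpha}_T)$,$p(x_T)=N(0,1)$,两者都不含可学习参数,所以算出来是一个与 $\theta$ 无关的定值。

```python
import numpy as np

def gaussian_kl(mu1, var1, mu2, var2):
    # 一维高斯 KL(N(mu1,var1) || N(mu2,var2)) 的闭式解
    return 0.5 * (np.log(var2 / var1) + (var1 + (mu1 - mu2) ** 2) / var2 - 1.0)

T = 1000
betas = np.linspace(1e-4, 0.02, T)
abar_T = np.cumprod(1.0 - betas)[-1]

x0 = 0.8                                    # 假设的一维像素值
kl = gaussian_kl(np.sqrt(abar_T) * x0, 1.0 - abar_T, 0.0, 1.0)
print(kl)                                   # 一个很小的常数,与theta无关
```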
接着考虑第二项 $L_2$。
\begin{equation}
\begin{split}
L_2&=E_{x_{1:T} \sim q(x_{1:T} | x_0)} \Bigg(\sum_{t=2}^{T} \log \Big[\frac{q(x_{t-1}|x_t,x_0)}{ p(x_{t-1}|x_t)} \Big]\Bigg)\\
&=\sum_{t=2}^{T} E_{x_{1:T} \sim q(x_{1:T} | x_0)} \Bigg(\log \Big[\frac{q(x_{t-1}|x_t,x_0)}{ p(x_{t-1}|x_t)} \Big]\Bigg)\\
&=\sum_{t=2}^{T} \Bigg( \int dx_{1:T} \cdot q(x_{1:T}| x_0) \cdot \log \Big[\frac{q(x_{t-1}|x_t,x_0)}{ p(x_{t-1}|x_t)} \Big] \Bigg)\\
&=\sum_{t=2}^{T} \Bigg( \int dx_{1:T} \cdot \frac{ q(x_{1:T}| x_0)}{q(x_{t-1}|x_t,x_0)} \cdot q(x_{t-1}|x_t,x_0) \cdot \log \Big[\frac{q(x_{t-1}|x_t,x_0)}{ p(x_{t-1}|x_t)} \Big] \Bigg)\\
&=\sum_{t=2}^{T} \Bigg( \int dx_{1:T} \cdot \underbrace{ \frac{q(x_{0:T})}{q(x_0)}}_{q(x_{0:T})=q(x_0)\cdot q(x_{1:T}| x_0)} \cdot \underbrace{ \frac{q(x_t,x_0)}{q(x_t,x_{t-1},x_0)}}_{q(x_t,x_{t-1},x_0)=q(x_t,x_0)\cdot q(x_{t-1}|x_t,x_0)} \cdot q(x_{t-1}|x_t,x_0) \cdot \log \Big[\frac{q(x_{t-1}|x_t,x_0)}{ p(x_{t-1}|x_t)} \Big] \Bigg)\\
&=\sum_{t=2}^{T} \Bigg( \int dx_{1:T} \cdot \frac{q(x_{0:T})}{q(x_0)}\cdot \frac{q(x_t,x_0)}{q(x_{t-1},x_0)\cdot q(x_t|x_{t-1},x_0)} \cdot q(x_{t-1}|x_t,x_0) \cdot \log \Big[\frac{q(x_{t-1}|x_t,x_0)}{ p(x_{t-1}|x_t)} \Big] \Bigg)\\
&=\sum_{t=2}^{T} \Bigg( \int \bigg[ \int \frac{q(x_{0:T})}{q(x_{t-1},x_0)}\cdot \frac{q(x_t,x_0)}{q(x_0)\cdot q(x_t|x_{t-1},x_0)} \prod_{k\geq1 ,k\neq t-1} dx_k \bigg] \cdot q(x_{t-1}|x_t,x_0) \cdot \log \Big[\frac{q(x_{t-1}|x_t,x_0)}{ p(x_{t-1}|x_t)} \Big]\, dx_{t-1} \Bigg)\\
&=\sum_{t=2}^{T} \Bigg( \int \bigg[ \int \underbrace{q(x_{k:k\geq1,k\neq t-1}|x_{t-1},x_0)}_{q(x_{0:T})=q(x_{t-1},x_0)\cdot q(x_{k:k\geq1,k\neq t-1}|x_{t-1},x_0)} \cdot \underbrace{\frac{q(x_t|x_0)}{ q(x_t|x_{t-1},x_0)}}_{q(x_t,x_0)=q(x_0)\cdot q(x_t|x_0)} \prod_{k\geq1 ,k\neq t-1} dx_k \bigg] \cdot q(x_{t-1}|x_t,x_0) \cdot \log \Big[\frac{q(x_{t-1}|x_t,x_0)}{ p(x_{t-1}|x_t)} \Big]\, dx_{t-1} \Bigg)\\
&=\sum_{t=2}^{T} \Bigg( \int \bigg[\int q(x_{k:k\geq1,k\neq t-1}|x_{t-1},x_0)\cdot \underbrace{\frac{q(x_t|x_0)}{ q(x_t|x_{t-1},x_0)}}_{=1} \prod_{k\geq1 ,k\neq t-1} dx_k \bigg] \cdot q(x_{t-1}|x_t,x_0) \cdot \log \Big[\frac{q(x_{t-1}|x_t,x_0)}{ p(x_{t-1}|x_t)} \Big]\, dx_{t-1} \Bigg)\\
&=\sum_{t=2}^{T} \Bigg( \int \bigg[\underbrace{ \int q(x_{k:k\geq1,k\neq t-1}|x_{t-1},x_0) \prod_{k\geq1 ,k\neq t-1} dx_k }_{=1}\bigg] \cdot q(x_{t-1}|x_t,x_0) \cdot \log \Big[\frac{q(x_{t-1}|x_t,x_0)}{ p(x_{t-1}|x_t)} \Big]\, dx_{t-1} \Bigg)\\
&=\sum_{t=2}^{T} \Bigg( \int q(x_{t-1}|x_t,x_0) \cdot \log \Big[\frac{q(x_{t-1}|x_t,x_0)}{ p(x_{t-1}|x_t)} \Big]\, dx_{t-1} \Bigg)\\
&=\sum_{t=2}^{T} \Bigg( E_{x_{t-1}\sim q(x_{t-1}|x_t,x_0)} \log \Big[\frac{q(x_{t-1}|x_t,x_0)}{ p(x_{t-1}|x_t)} \Big] \Bigg)\\
&=\sum_{t=2}^{T}KL\bigg(q(x_{t-1}|x_t,x_0)||p(x_{t-1}|x_t) \bigg)
\end{split}
\end{equation}
最后考虑 $L_3$。事实上,论文《Deep Unsupervised Learning using Nonequilibrium Thermodynamics》中提到,为了防止边界效应,强制令 $p(x_0|x_1)=q(x_1|x_0)$,因此这一项也是个常数。
由以上分析可知,损失函数可以写为公式(9)。
\begin{equation}
\begin{split}
L&:=L_1+L_2+L_3 \\
&=KL\Big(q(x_T|x_0)||p(x_T)\Big) + \sum_{t=2}^{T}KL\bigg(q(x_{t-1}|x_t,x_0)||p(x_{t-1}|x_t) \bigg)-\log \Big[p(x_{0}|x_1)\Big]
\end{split}
\end{equation}
忽略掉常数项 $L_1$ 和 $L_3$,损失函数可以写为公式(10)。
\begin{equation}
L:=\sum_{t=2}^{T}KL\bigg(q(x_{t-1}|x_t,x_0)||p(x_{t-1}|x_t) \bigg)
\end{equation}
可以看出,损失函数 $L$ 是两个高斯分布 $q(x_{t-1}|x_t,x_0)$ 和 $p(x_{t-1}|x_t)$ 的KL散度(对 $t$ 求和)。由论文阅读笔记:Denoising Diffusion Probabilistic Models (1) 可知,$q(x_{t-1}|x_t,x_0)$ 的均值和方差分别为
\begin{equation}
\begin{split}
\sigma_1&=\sqrt{\frac{\beta_t\cdot (1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_{t}}}\\
\mu_1&=\frac{1}{\sqrt{\alpha_t}}\cdot \Big(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\cdot z_t\Big) \\
\text{或者}\quad \mu_1&=\frac{\sqrt{\alpha_t}\cdot(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\cdot x_t+\frac{\beta_t\cdot \sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_t} \cdot x_0
\end{split}
\end{equation}
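这两种 $\mu_1$ 的写法是等价的:把 $x_0=\frac{1}{\sqrt{\bar{\alpha}_t}}\big(x_t-\sqrt{1-\bar{\alpha}_t}\cdot z_t\big)$ 代入第二种写法即可推出第一种。下面做个数值上的互相验证(一维示意,各取值均为随意假设):

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
abar = np.cumprod(alphas)

t = 500                                      # 随意取的一步
x0, zt = 0.8, 0.3                            # 假设的x_0与噪声z_t
xt = np.sqrt(abar[t]) * x0 + np.sqrt(1.0 - abar[t]) * zt   # 闭式前向采样

# 写法一:用x_t和z_t表示
mu1_a = (xt - betas[t] / np.sqrt(1.0 - abar[t]) * zt) / np.sqrt(alphas[t])
# 写法二:用x_t和x_0表示
mu1_b = (np.sqrt(alphas[t]) * (1.0 - abar[t - 1]) * xt
         + betas[t] * np.sqrt(abar[t - 1]) * x0) / (1.0 - abar[t])
print(mu1_a, mu1_b)                          # 两者在数值误差内相等
```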
而 $p(x_{t-1}|x_t)$ 则由模型(深度学习模型或者其他模型)估算出其均值和方差,分别记作 $\mu_2,\sigma_2$。
因此损失函数 $L$ 可以进一步写为公式(12)。
\begin{equation}
L:=\log \Big[\frac{\sigma_2}{\sigma_1}\Big]+\frac{\sigma_1^2 +(\mu_1-\mu_2)^2}{2\sigma_2^2}-\frac{1}{2}
\end{equation}
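公式(12)就是一维高斯KL散度的闭式解,可以直接写成一个小函数来感受一下(示意代码;注意后文仓库中的 normal_kl 用的是对数方差参数化,形式略有不同但等价):

```python
import numpy as np

def kl_two_gaussians(mu1, sigma1, mu2, sigma2):
    # 公式(12):KL(N(mu1,sigma1^2) || N(mu2,sigma2^2))
    return (np.log(sigma2 / sigma1)
            + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2.0 * sigma2 ** 2)
            - 0.5)

print(kl_two_gaussians(0.0, 1.0, 0.0, 1.0))  # 同分布时KL为0
print(kl_two_gaussians(0.5, 1.0, 0.0, 1.0))  # 均值相差0.5时KL为0.125
```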
最后,结合原文代码(diffusion:https://github.com/hojonathanho/diffusion)来理解一下训练过程和推理过程。
首先是训练过程。
class GaussianDiffusion2:
"""
Contains utilities for the diffusion model.
Arguments:
- what the network predicts (x_{t-1}, x_0, or epsilon)
- which loss function (kl or unweighted MSE)
- what is the variance of p(x_{t-1}|x_t) (learned, fixed to beta, or fixed to weighted beta)
- what type of decoder, and how to weight its loss? is its variance learned too?
"""
# 模型中的一些定义
def __init__(self, *, betas, model_mean_type, model_var_type, loss_type):
self.model_mean_type = model_mean_type # xprev, xstart, eps
self.model_var_type = model_var_type # learned, fixedsmall, fixedlarge
self.loss_type = loss_type # kl, mse
assert isinstance(betas, np.ndarray)
self.betas = betas = betas.astype(np.float64) # computations here in float64 for accuracy
assert (betas > 0).all() and (betas <= 1).all()
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
alphas = 1. - betas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.alphas_cumprod_prev = np.append(1., self.alphas_cumprod[:-1])
assert self.alphas_cumprod_prev.shape == (timesteps,)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1. - self.alphas_cumprod)
self.log_one_minus_alphas_cumprod = np.log(1. - self.alphas_cumprod)
self.sqrt_recip_alphas_cumprod = np.sqrt(1. / self.alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1. / self.alphas_cumprod - 1)
# calculations for posterior q(x_{t-1} | x_t, x_0)
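# 注:posterior_variance 即公式(11)中 sigma_1 的平方;
# posterior_mean_coef1、posterior_mean_coef2 分别对应公式(11)中
# mu_1 第二种写法里 x_0 与 x_t 前面的系数(见下文 q_posterior_mean_variance)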
self.posterior_variance = betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
self.posterior_mean_coef2 = (1. - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1. - self.alphas_cumprod)
# 在模型Model类当中的方法
def train_fn(self, x, y):
B, H, W, C = x.shape
if self.randflip:
x = tf.image.random_flip_left_right(x)
assert x.shape == [B, H, W, C]
# 随机生成第t步
t = tf.random_uniform([B], 0, self.diffusion.num_timesteps, dtype=tf.int32)
# 计算第t步时对应的损失函数
losses = self.diffusion.training_losses(
denoise_fn=functools.partial(self._denoise, y=y, dropout=self.dropout), x_start=x, t=t)
assert losses.shape == t.shape == [B]
return {'loss': tf.reduce_mean(losses)}
# 根据x_start采样到第t步的带噪图像
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data (t == 0 means diffused for 1 step)
"""
if noise is None:
noise = tf.random_normal(shape=x_start.shape)
assert noise.shape == x_start.shape
return (
self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
# 计算q(x^{t-1}|x^t,x^0)分布的均值和方差
def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
assert x_start.shape == x_t.shape
posterior_mean = (
self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
assert (posterior_mean.shape[0] == posterior_variance.shape[0] == posterior_log_variance_clipped.shape[0] ==
x_start.shape[0])
return posterior_mean, posterior_variance, posterior_log_variance_clipped
# 由深度学习模型UNet估算出p(x^{t-1}|x^t)分布的方差和均值
def p_mean_variance(self, denoise_fn, *, x, t, clip_denoised: bool, return_pred_xstart: bool):
B, H, W, C = x.shape
assert t.shape == [B]
model_output = denoise_fn(x, t)
# Learned or fixed variance?
if self.model_var_type == 'learned':
assert model_output.shape == [B, H, W, C * 2]
model_output, model_log_variance = tf.split(model_output, 2, axis=-1)
model_variance = tf.exp(model_log_variance)
elif self.model_var_type in ['fixedsmall', 'fixedlarge']:
# below: only log_variance is used in the KL computations
model_variance, model_log_variance = {
# for fixedlarge, we set the initial (log-)variance like so to get a better decoder log likelihood
'fixedlarge': (self.betas, np.log(np.append(self.posterior_variance[1], self.betas[1:]))),
'fixedsmall': (self.posterior_variance, self.posterior_log_variance_clipped),
}[self.model_var_type]
model_variance = self._extract(model_variance, t, x.shape) * tf.ones(x.shape.as_list())
model_log_variance = self._extract(model_log_variance, t, x.shape) * tf.ones(x.shape.as_list())
else:
raise NotImplementedError(self.model_var_type)
# Mean parameterization
_maybe_clip = lambda x_: (tf.clip_by_value(x_, -1., 1.) if clip_denoised else x_)
if self.model_mean_type == 'xprev': # the model predicts x_{t-1}
pred_xstart = _maybe_clip(self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output))
model_mean = model_output
elif self.model_mean_type == 'xstart': # the model predicts x_0
pred_xstart = _maybe_clip(model_output)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
elif self.model_mean_type == 'eps': # the model predicts epsilon
pred_xstart = _maybe_clip(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
else:
raise NotImplementedError(self.model_mean_type)
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
if return_pred_xstart:
return model_mean, model_variance, model_log_variance, pred_xstart
else:
return model_mean, model_variance, model_log_variance
# 损失函数的计算过程
def training_losses(self, denoise_fn, x_start, t, noise=None):
assert t.shape == [x_start.shape[0]]
# 随机生成一个噪音
if noise is None:
noise = tf.random_normal(shape=x_start.shape, dtype=x_start.dtype)
assert noise.shape == x_start.shape and noise.dtype == x_start.dtype
# 将随机生成的噪音加到x_start上得到第t步的带噪图像
x_t = self.q_sample(x_start=x_start, t=t, noise=noise)
# 有两种损失函数的方法,'kl'和'mse',并且这两种方法差别并不明显。
if self.loss_type == 'kl': # the variational bound
losses = self._vb_terms_bpd(
denoise_fn=denoise_fn, x_start=x_start, x_t=x_t, t=t, clip_denoised=False, return_pred_xstart=False)
elif self.loss_type == 'mse': # unweighted MSE
assert self.model_var_type != 'learned'
target = {
'xprev': self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
'xstart': x_start,
'eps': noise
}[self.model_mean_type]
model_output = denoise_fn(x_t, t)
assert model_output.shape == target.shape == x_start.shape
losses = nn.meanflat(tf.squared_difference(target, model_output))
else:
raise NotImplementedError(self.loss_type)
assert losses.shape == t.shape
return losses
# 计算两个高斯分布的KL散度,代码中的logvar1,logvar2为方差的对数
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
KL divergence between normal distributions parameterized by mean and log-variance.
"""
return 0.5 * (-1.0 + logvar2 - logvar1 + tf.exp(logvar1 - logvar2)
+ tf.squared_difference(mean1, mean2) * tf.exp(-logvar2))
# 使用'kl'方法计算损失函数
def _vb_terms_bpd(self, denoise_fn, x_start, x_t, t, *, clip_denoised: bool, return_pred_xstart: bool):
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, x=x_t, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = nn.meanflat(kl) / np.log(2.)
decoder_nll = -utils.discretized_gaussian_log_likelihood(
x_start, means=model_mean, log_scales=0.5 * model_log_variance)
assert decoder_nll.shape == x_start.shape
decoder_nll = nn.meanflat(decoder_nll) / np.log(2.)
# At the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
assert kl.shape == decoder_nll.shape == t.shape == [x_start.shape[0]]
output = tf.where(tf.equal(t, 0), decoder_nll, kl)
return (output, pred_xstart) if return_pred_xstart else output
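把上面 'mse' 分支中 'eps' 参数化的一次训练迭代单独抽出来,可以写成如下与框架无关的NumPy示意(其中 model 是任意以 (x_t, t) 为输入、输出与 x_t 同形状的去噪网络,属于演示性假设;真实训练用的是仓库中的TensorFlow实现):

```python
import numpy as np

def training_loss_eps(model, x0, betas, rng):
    """单个样本上 'eps' 参数化的无加权MSE损失(示意)。"""
    abar = np.cumprod(1.0 - betas)
    t = rng.integers(0, len(betas))                  # 随机采样一步t
    eps = rng.standard_normal(x0.shape)              # 随机噪声
    # 闭式前向:x_t = sqrt(abar_t)*x0 + sqrt(1-abar_t)*eps,对应 q_sample
    x_t = np.sqrt(abar[t]) * x0 + np.sqrt(1.0 - abar[t]) * eps
    eps_pred = model(x_t, t)                         # 网络预测噪声
    return np.mean((eps - eps_pred) ** 2)            # 无加权MSE

# 用一个"恒零预测"的假模型跑一下,损失应接近 E[eps^2]=1
rng = np.random.default_rng(0)
dummy_model = lambda x_t, t: np.zeros_like(x_t)
loss = training_loss_eps(dummy_model, rng.standard_normal(64),
                         np.linspace(1e-4, 0.02, 1000), rng)
print(loss)
```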
接下来是推理过程。
def p_sample(self, denoise_fn, *, x, t, noise_fn, clip_denoised=True, return_pred_xstart: bool):
"""
Sample from the model
"""
# 使用深度学习模型,根据x^t和t估算出x^{t-1}分布的均值和方差
model_mean, _, model_log_variance, pred_xstart = self.p_mean_variance(
denoise_fn, x=x, t=t, clip_denoised=clip_denoised, return_pred_xstart=True)
noise = noise_fn(shape=x.shape, dtype=x.dtype)
assert noise.shape == x.shape
# no noise when t == 0
nonzero_mask = tf.reshape(1 - tf.cast(tf.equal(t, 0), tf.float32), [x.shape[0]] + [1] * (len(x.shape) - 1))
# 当t>0时,模型估算出的结果还要加上一个高斯噪音,因为要继续循环。当t=0时,循环停止,因此不需要再添加噪音了,输出最后的结果。
sample = model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise
assert sample.shape == pred_xstart.shape
return (sample, pred_xstart) if return_pred_xstart else sample
def p_sample_loop(self, denoise_fn, *, shape, noise_fn=tf.random_normal):
"""
Generate samples
"""
assert isinstance(shape, (tuple, list))
# 总的步数为T(num_timesteps),从第T-1步开始倒序循环
i_0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32)
# 随机生成一个噪声作为x^T(即从p(x^T)中采样)
img_0 = noise_fn(shape=shape, dtype=tf.float32)
# 循环T次,得到最终的图像
_, img_final = tf.while_loop(
cond=lambda i_, _: tf.greater_equal(i_, 0),
body=lambda i_, img_: [
i_ - 1,
self.p_sample(
denoise_fn=denoise_fn, x=img_, t=tf.fill([shape[0]], i_), noise_fn=noise_fn, return_pred_xstart=False)
],
loop_vars=[i_0, img_0],
shape_invariants=[i_0.shape, img_0.shape],
back_prop=False
)
assert img_final.shape == shape
return img_final
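同样,推理循环的骨架也可以用NumPy示意如下,结构与 p_sample_loop / p_sample 一一对应(这里假设采用 'eps' 参数化、方差取 fixedsmall,即公式(11)中的 $\sigma_1^2$;model 仍为假设的去噪网络):

```python
import numpy as np

def p_sample_loop_np(model, shape, betas, rng):
    """从纯噪声出发逐步去噪(示意,对应仓库中的 p_sample_loop)。"""
    alphas = 1.0 - betas
    abar = np.cumprod(alphas)
    abar_prev = np.append(1.0, abar[:-1])
    post_var = betas * (1.0 - abar_prev) / (1.0 - abar)   # 公式(11)中sigma_1^2

    x = rng.standard_normal(shape)                        # x_T ~ N(0, I)
    for t in range(len(betas) - 1, -1, -1):
        eps_pred = model(x, t)
        # 由eps预测均值:mu = (x - beta_t/sqrt(1-abar_t)*eps)/sqrt(alpha_t)
        mean = (x - betas[t] / np.sqrt(1.0 - abar[t]) * eps_pred) / np.sqrt(alphas[t])
        if t > 0:                                         # t=0时不再加噪声
            x = mean + np.sqrt(post_var[t]) * rng.standard_normal(shape)
        else:
            x = mean
    return x

rng = np.random.default_rng(0)
sample = p_sample_loop_np(lambda x, t: np.zeros_like(x), (64,),
                          np.linspace(1e-4, 0.02, 1000), rng)
print(sample.shape)
```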