结合代码详细讲解DDPM的训练和采样过程
本篇文章结合代码讲解Denoising Diffusion Probabilistic Models(DDPM),首先我们先不关注推导过程,而是结合代码来看一下训练和推理过程是如何实现的,推导过程会在别的文章中讲解;首先我们来看一下论文中的算法描述。DDPM分为扩散过程和反向扩散过程,也就是训练过程和采样过程;
代码来自https://github.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-
1. 训练(扩散)过程
首先我们来逐个看一下训练过程中的所有符号的含义:
x 0 x_0 x0是真实图像;
t 是扩散的步数,取值范围从1到T;
ϵ \epsilon ϵ是从标准正态分布中采样的噪声;
ϵ θ \epsilon_\theta ϵθ是模型,用于预测噪声,其输入是 x t x_t xt和 t;
x t x_t xt的表达式如下:
x
t
x_t
xt由
x
0
x_0
x0加噪获得,其中
α
t
‾
\overline{\alpha_{t}}
αt是常数
因此训练过程总结成一句话就是,向真实图像
x
0
x_0
x0中加噪,获得加噪后的图像
x
t
x_t
xt;然后将
x
t
x_t
xt和t输入到网络中,得到预测的噪声,通过使得网络预测的噪声和真实加入的噪声更接近,完成网络的训练。
从另一个角度,我们也可以这么理解:向
x
0
x_0
x0中加噪的过程,可以理解成是编码的过程,加噪之后获取到了图像的中间表示
x
t
x_t
xt;而预测噪声的过程则是从
x
t
x_t
xt解码的过程,只是并没有选择直接解码出
x
0
x_0
x0,而是解码出加入的噪声,也就是残差。
下面来看一下代码,跟上面讲解的过程是一一对应的,首先在初始化函数中我们需要准备好每个时刻t所需要的常数量 α t ‾ \sqrt{\overline{\alpha_{t}}} αt和 1 − α t ‾ \sqrt{1-\overline{\alpha_{t}}} 1−αt。这些参数最原始来源于一个超参数 β t \beta_t βt,这个参数为加入噪声的方差。他们的关系如下:
所以很容易理解代码中的sqrt_alphas_bar就是
α
t
‾
\sqrt{\overline{\alpha_{t}}}
αt,sqrt_one_minus_alphas_bar 就是
1
−
α
t
‾
\sqrt{1-\overline{\alpha_{t}}}
1−αt。
接着在forward函数中,首先从[0,T]中随机选取一个时刻t,然后从标准正态分布中采样一个噪声,shape和
x
0
x_0
x0一致,接着获取
x
t
x_t
xt:
x_t = (
extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
然后将然后将 x t x_t xt和t输入到网络中,得到预测的噪声:
self.model(x_t, t)
计算Loss函数:
loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
训练过程的完整代码:
class GaussianDiffusionTrainer(nn.Module):
def __init__(self, model, beta_1, beta_T, T):
super().__init__()
self.model = model
self.T = T
self.register_buffer(
'betas', torch.linspace(beta_1, beta_T, T).double())
alphas = 1. - self.betas
alphas_bar = torch.cumprod(alphas, dim=0)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
'sqrt_alphas_bar', torch.sqrt(alphas_bar))
self.register_buffer(
'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))
# 每次forward时,给每个样本随机取一个t,并采样一个高斯噪声,然后根据t从sqrt_alphas_bar和sqrt_one_minus_alphas_bar中取出对应的系数,然后根据x_0和采样的高斯噪声生成x_t。然后将x_t和t输入到噪声预测网络中,得到预测的噪声。预测出的噪声输入到网络中,计算loss,从而实现model的训练。
def forward(self, x_0):
"""
Algorithm 1.
"""
t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device) # 给batch中每个样本取一个t,取值范围是[0, 1000]
noise = torch.randn_like(x_0) # 采样高斯噪声,shape与x_0一致
x_t = (
extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
return loss
2. 推理(反向)过程
首先我们来明确一下,反向过程的目标是什么。反向过程的目标是逐步从一张噪声图像
x
T
x_T
xT中恢复出一张图像,表示成
p
θ
(
x
t
−
1
∣
x
t
)
p_{\theta}(x_{t-1}|x_t)
pθ(xt−1∣xt),我们没法推导出
p
(
x
t
−
1
∣
x
t
)
p(x_{t-1}|x_t)
p(xt−1∣xt),但是
p
(
x
t
−
1
∣
x
t
,
x
0
)
p(x_{t-1}|x_t, x_0)
p(xt−1∣xt,x0)是可以用贝叶斯公式推导出来的,其也是一个高斯分布,并且可以把
x
0
x_0
x0化简掉。最终
p
θ
(
x
t
−
1
∣
x
t
)
p_{\theta}(x_{t-1}|x_t)
pθ(xt−1∣xt)分布的均值为:
方差为
β
t
\beta_t
βt。
因此我们可以从
p
θ
(
x
t
−
1
∣
x
t
)
p_{\theta}(x_{t-1}|x_t)
pθ(xt−1∣xt)分布中采样出一个
x
t
−
1
x_{t-1}
xt−1:
这种采样方式叫做重参数技巧,如果不了解可以看如下介绍:
注意:是标准差与标准正态分布相乘,而不是方差;
因为DDPM的方差固定为 β t \beta_t βt,所以反向过程的重点就是学习出这个分布的方差,从上面的表达式可以看出分布的均值与 x t x_t xt和当前时刻加入的噪声 ϵ t \epsilon_t ϵt有关,而我们的模型可以完成对 ϵ t \epsilon_t ϵt的预测,只要将 x t x_t xt和 t 输入进去模型中即可。代码中描述的过程与此一一对应。
注意代码中存在三个噪声,其中eps是模型预测出来的,其和分布的均值计算相关;forward函数中的noise也是噪声,但是它是从标准正态分布中采样的,用于从 p θ ( x t − 1 ∣ x t ) p_{\theta}(x_{t-1}|x_t) pθ(xt−1∣xt)采样;forward函数中的 x T x_T xT是整个反向过程的输入,也是从标准正态分布中采样的。
# 反向过程是从纯噪声x_T开始逐步去噪以生成样本,此过程也是一个高斯分布,均值和x_t以及预测出的噪声相关,方差在ddpm中没有进行学习,直接使用的是后验分布q(x_t-1|x_t,x_0)的方差。
class GaussianDiffusionSampler(nn.Module):
def __init__(self, model, beta_1, beta_T, T):
super().__init__()
self.model = model
self.T = T
self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
alphas = 1. - self.betas
alphas_bar = torch.cumprod(alphas, dim=0)
alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]
self.register_buffer('coeff1', torch.sqrt(1. / alphas))
self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))
self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))
def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
assert x_t.shape == eps.shape
return (
extract(self.coeff1, t, x_t.shape) * x_t -
extract(self.coeff2, t, x_t.shape) * eps
)
def p_mean_variance(self, x_t, t):
# below: only log_variance is used in the KL computations
var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
var = extract(var, t, x_t.shape)
eps = self.model(x_t, t)
xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)
return xt_prev_mean, var
def forward(self, x_T):
"""
Algorithm 2.
"""
x_t = x_T # 输入是一个标准正态分布噪声
# 从T到1进行reverse过程
for time_step in reversed(range(self.T)):
print(time_step)
t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
mean, var= self.p_mean_variance(x_t=x_t, t=t)
# no noise when t == 0
if time_step > 0:
noise = torch.randn_like(x_t)
else:
noise = 0
x_t = mean + torch.sqrt(var) * noise # 从q(x_t-1|x_t)中采样
assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
x_0 = x_t
return torch.clip(x_0, -1, 1)