当前位置: 首页 > article >正文

【论文阅读】Improved Denoising Diffusion Probabilistic Models

Improved Denoising Diffusion Probabilistic Models

文章目录

引用: Nichol A Q, Dhariwal P. Improved denoising diffusion probabilistic models[C]//International conference on machine learning. PMLR, 2021: 8162-8171.

论文链接: https://arxiv.org/abs/2102.09672

代码链接: https://github.com/openai/improved-diffusion

概述

去噪扩散概率模型 (DDPM) 是一类生成模型,最近已被证明可以产生出色的样本。实验表明,通过一些简单的修改,DDPM还可以在保持高样品质量的同时实现竞争性的对数似然。为了更紧密地优化变分下界 (VLB),我们使用简单的重新参数化和混合学习目标来学习逆向过程方差,该目标将 VLB 与 Ho 等人[1]的简化目标相结合,允许采样前向传递减少一个数量级,样本质量差异可以忽略不计,这对于这些模型的实际部署非常重要。使用混合目标,模型获得了比直接优化对数似然获得的对数似然更好的对数似然,并发现后一个目标在训练过程中具有更多的梯度噪声。与混合目标相比,一个简单的重要性采样技术可以减少这种噪声,并能够获得更好的对数似然。此外,论文还使用精确度和召回率来比较 DDPM 和 GAN 对目标分布的覆盖程度。最后,我们表明,这些模型的样本质量和可能性可以随着模型容量和训练计算而平滑扩展,使其易于扩展。

Improving the Log-likelihood

虽然Ho等人[1]发现DDPM可以根据FID[2]和Inception Score[3]生成高保真样本,但他们无法通过这些模型实现竞争对数可能性。对数似然是生成建模中广泛使用的指标,人们普遍认为,优化对数似然会迫使生成模型捕获数据分布的所有模式。此外,最近的工作[4]表明,对数似然的微小改进可以对样本质量和学习的特征表示产生巨大影响。因此,重要的是要探讨为什么 DDPM 似乎在这个指标上表现不佳,因为这可能表明一个根本性的缺点,例如模式覆盖率差。

为了研究不同修改的影响,在ImageNet 64×64和CIFAR-10数据集上训练具有固定超参数的固定模型架构。虽然 CIFAR-10 在此类模型中的应用更多,但论文选择研究 ImageNet 64 × 64,因为它在多样性和分辨率之间提供了良好的权衡,能够快速训练模型而不必担心过度拟合。此外,ImageNet 64×64 已在生成建模的背景下进行了广泛研究,能够将 DDPM 直接与许多其他生成模型进行比较。

Ho等人[1]的设置(在设置 σ t 2 = β t σ^2_t = β_t σt2=βt T = 1000 T = 1000 T=1000 的同时优化 L s i m p l e L_{simple} Lsimple )在 200K 训练迭代后,在 ImageNet 64 × 64 64 × 64 64×64 上实现了 3.99 3.99 3.99 b i t s / d i m bits/dim bits/dim) 的对数似然。论文在早期的实验中发现,可以通过将 T T T 1000 1000 1000 增加到 4000 4000 4000 来提高对数似然;通过此更改,对数似然提高到 3.77 3.77 3.77

Learning ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t)

在这里插入图片描述

Ho等人[1]将 ∑ θ ( x t , t ) = σ t 2 I \sum_{\theta}(x_{t}, t) = \sigma_{t}^{2}I θ(xt,t)=σt2I,其中 σ t σ_t σt 不是学习的。奇怪的是,他们发现将 σ t 2 σ^2_t σt2 固定到 β t β_t βt 产生的样品质量与将其固定到 β ~ t \tilde { \beta } _ { t } β~t 大致相同。考虑到 β t β_t βt β ~ t \tilde { \beta } _ { t } β~t 代表两个相反的极端,有理由问为什么这种选择不会影响样本。图 1 给出了一个线索,**它表明 β t β_t βt β ~ t \tilde { \beta } _ { t } β~t 几乎相等(除了接近 t = 0 t = 0 t=0),即模型正在处理难以察觉的细节。此外,随着扩散步骤数量的增加, β t β_t βt和β ̃t似乎在更多的扩散过程中彼此靠近。这表明,在无限扩散步骤的极限下, σ t σ_t σt的选择对样品质量可能完全无关紧要。换句话说,当添加更多的扩散步骤时,模型平均值 μ θ ( x t , t ) \mu _ { \theta } ( x _ { t } , t ) μθ(xt,t) ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t)更能决定分布。虽然上述论点表明,为了样本质量,固定 σ t σ_t σt 是一个合理的选择,但它并没有说明对数似然性。事实上,图2显示,扩散过程的前几步对变分下限的贡献最大。因此,似乎可以通过使用更好的 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t) 选择来提高对数似然。为了实现这一目标,必须学习 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t),而不会遇到 Ho 等人遇到的不稳定性。

由于图 1 显示 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t)的理想范围非常小,因此神经网络很难直接预测 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t),即使在对数域中也是如此。相反,我们发现最好将方差参数化为在log域 β t β_t βt β ~ t \tilde { \beta } _ { t } β~t之间的插值。具体而言,模型输出一个向量 v v v,每个维度包含一个分量,将此输出转换为方差,如下所示:

∑ θ ( x t , t ) = e x p ( v log ⁡ β t + ( 1 − v ) log ⁡ β ~ t ) \sum _ { \theta } ( x _ { t } , t ) = e x p ( v \log \beta _ { t } + ( 1 - v ) \log \tilde { \beta } _ { t } ) θ(xt,t)=exp(vlogβt+(1v)logβ~t)

没有对 v v v 施加任何约束,理论上允许模型预测插值范围之外的方差。由于 Lsimple 不依赖于 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t),因此定义了一个新的混合目标:

L h y b r i d = L s i m p l e + λ L v l b L _ { h y b r i d } = L _ { s i m p l e } + \lambda L _ { v l b } Lhybrid=Lsimple+λLvlb

对于实验,设置 λ = 0.001 λ = 0.001 λ=0.001 以防止 L v l b L_{vlb} Lvlb 压倒 L s i m p l e L_{simple} Lsimple。按照同样的推理思路,还对 L v l b L_{vlb} Lvlb项的 μ θ ( x t , t ) \mu _ { \theta } ( x _ { t } , t ) μθ(xt,t)输出应用了停止梯度。这样,$L_{vlb} $可以引导 ∑ θ ( x t , t ) \sum_{\theta}(x_{t}, t) θ(xt,t),而 L s i m p l e L_{simple} Lsimple 仍然是影响 μ θ ( x t , t ) \mu _ { \theta } ( x _ { t } , t ) μθ(xt,t)的主要来源。

在这里插入图片描述

Improving the Noise Schedule

虽然Ho等人中使用的线性噪声调度对于高分辨率图像效果良好,但对于分辨率为64×64和32×32的图像来说,它是次优的。特别地,前向噪声处理的末尾噪声太大,因此对样本质量没有太大贡献。这可以在图3中直观地看到。这种影响的结果在图4中进行了研究,当跳过高达20%的反向扩散过程时,用线性时间表训练的模型不会变得更糟(通过FID测量)。为了解决这个问题,根据 α t ˉ \bar { \alpha _ { t } } αtˉ构建了一个不同的噪声表:

α t ˉ = f ( t ) f ( 0 ) , f ( t ) = cos ⁡ ( t / T + s 1 + s ⋅ π 2 ) 2 \bar { \alpha _ { t } } = \frac { f ( t ) } { f ( 0 ) } , f ( t ) = \cos \left( \frac { t / T + s } { 1 + s } \cdot \frac { \pi } { 2 } \right) ^ { 2 } αtˉ=f(0)f(t),f(t)=cos(1+st/T+s2π)2

β t = 1 − α ‾ t α ‾ t − 1 \beta _ { t } = 1 - \frac { \overline { \alpha } _ { t } } { \overline { \alpha } _ { t - 1 } } βt=1αt1αt

在实践中,将 β t \beta_t βt 裁剪为不大于 0.999,以防止在扩散过程结束时接近 $t = T $的奇点。

在这里插入图片描述

提出的余弦时间表被设计为在过程中具有 α t ˉ \bar { \alpha _ { t } } αtˉ的线性下降,同时在$ t = 0 $和 t = T t = T t=T 的极端附近变化很小,以防止噪声水平的突然变化。图 5 显示了两个计划的 α α α进展情况。可以看到,线性时间表以更快的速度趋向于零,破坏信息的速度比必要的要快得多。使用较小的偏移量 s s s 来防止 β t β_t βt 在$ t = 0 附近太小,因为在过程开始时有少量的噪声会使网络难以足够准确地预测。 ∗ ∗ 特别是,选择了 附近太小,因为在过程开始时有少量的噪声会使网络难以足够准确地预测。**特别是,选择了 附近太小,因为在过程开始时有少量的噪声会使网络难以足够准确地预测。特别是,选择了 s ,使得 ,使得 ,使得\sqrt { \beta _ { 0 } }$略小于像素箱大小 1 / 127.5 1/127.5 1/127.5,因此 s = 0.008 s = 0.008 s=0.008。我们特别选择使用 c o s 2 cos^2 cos2,因为它是一个具有我们正在寻找的形状的通用数学函数。这种选择是任意的,我们预计许多其他具有类似形状的函数也可以使用。**

Reducing Gradient Noise

在这里插入图片描述

在这里插入图片描述

我们希望通过直接优化 L v l b L_{vlb} Lvlb 而不是优化 L h y b r i d L_{hybrid} Lhybrid 来实现最佳的对数似然。然而, L v l b L_{vlb} Lvlb在实践中实际上很难优化,至少在多样化的 ImageNet 64×64 数据集上是这样。图 6 显示了 $L_{vlb} $和 L h y b r i d L{hybrid} Lhybrid 的学习曲线。两条曲线都是嘈杂的,但在训练时间相同的情况下,混合目标显然在训练集上实现了更好的对数似然。通过评估使用两个目标训练的模型的梯度噪声标度证实了 L v l b L_{vlb} Lvlb 的梯度比 L h y b r i d L_{hybrid} Lhybrid 的梯度大得多,如图7所示。因此,我们寻找一种方法来减少 L v l b L_{vlb} Lvlb 的方差,以便直接优化对数似然性。注意到 L v l b L_{vlb} Lvlb的不同项具有很大差异的幅度(图 2),假设采样$ t $在 $L_{vlb} $中均匀地产生不必要的噪声。为了解决这个问题,采用了重要性抽样:

L v l b = E t ∼ p t [ L t p t ] , w h e r e p t ∝ E [ L t 2 ] a n d ∑ p t = 1 L_{vlb} = E_{ t \sim p_{t} } \left[ \frac { L_{t} } { p_{t} } \right] , where p_{t} \propto \sqrt { E \left[ L_{t} ^ {2} \right] } and \sum p_{t} = 1 Lvlb=Etpt[ptLt],whereptE[Lt2] andpt=1

由于 E [ L t 2 ] E \left[ L _ { t } ^ { 2 } \right] E[Lt2] 事先是未知的,并且可能在整个训练过程中发生变化,因此我们维护每个损失项的前 10 个值的历史记录,并在训练期间动态更新。在训练开始时,均匀地采样 t t t,直到为每个 $t ∈ [0, T −1] $抽取 10 个样本。有了这个重要性抽样目标,就能够通过优化 L v l b L_{vlb} Lvlb 来实现最佳的对数似然。如图 6 所示,即 L v l b L_{vlb} Lvlb(重采样)曲线。该图还显示,重要性采样物镜的噪声比原始的均匀采样要小得多。可以发现,在直接优化噪声较小的L_{{hybrid}时,重要性采样技术没有帮助。

Improving Sampling Speed

在这里插入图片描述

为了减少从 T T T K K K 的采样步骤数,使用$ K$ 个介于 1 1 1 T T T(含)之间的均匀分布的实数,然后将每个结果数字四舍五入到最接近的整数。在图 8 中,评估了使用 4000 扩散步骤,使用 25、50、100、200、400、1000 和 4000 个采样步骤训练的$ L_{hybrid}$ 模型和 L s i m p l e L_{simple} Lsimple 模型的 FID。.我们既针对训练有素的检查点,也针对培训中途的检查点。对于 CIFAR-10,使用了 200K 和 500K 的训练迭代,对于 ImageNet 64,使用了 500K 和 1500K 的训练迭代。可以发现,当使用较少的采样步骤时,具有固定sigmas的 L s i m p l e L_{simple} Lsimple 模型在样本质量方面受到的影响要大得多,而学习sigmas的 L h y b r i d L_{hybrid} Lhybrid模型保持了较高的样本质量。使用此模型,100 个采样步骤足以为完全训练的模型实现近乎最佳的 FID。

Scaling Model Size

在这里插入图片描述

为了衡量性能如何通过训练计算进行扩展,我们在 ImageNet 64 × 64 上训练了四个不同的模型,并使用 L h y b r i d L_{hybrid} Lhybrid 目标。为了改变模型容量,在所有层上应用深度乘法器,使得第一层有 64、96、128 或 192 个通道。请注意,之前的实验在第一层中使用了 128 个通道。由于每一层的深度都会影响初始权重的规模,因此将每个模型的Adam学习率按 1 / c h a n n e l m u l t i p l i e r 1 / \sqrt{channel multiplier} 1/channelmultiplier 缩放,因此128通道模型的学习率为0.0001。图 10 显示了 FID 和 NLL 相对于理论训练计算的改进情况。FID 曲线在对数-对数图上看起来近似线性,表明 FID 根据幂律(绘制为黑色虚线)进行缩放。NLL曲线不能完全拟合幂律,这表明验证NLL的扩展方式不如FID。这可能是由多种因素引起的,例如 1) 这种类型的扩散模型出乎意料的高不可约损失,或 2) 模型过度拟合到训练分布。还注意到,这些模型通常无法实现最佳对数似然,因为它们是使用 L h y b r i d L_{hybrid} Lhybrid 而不是直接使用 L v l b L_{vlb} Lvlb 进行训练的,以保持良好的对数似然性和样本质量。

实验

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

参考文献

[1] Ho, J., Jain, A., and Abbeel, P. Denoising diffusion probabilistic models, 2020.
[2] Heusel, M., Ramsauer, H., Unterthiner, T., Nessler, B., and Hochreiter, S. Gans trained by a two time-scale update rule converge to a local nash equilibrium. Advances in Neural Information Processing Systems 30 (NIPS 2017), 2017.
[3] Salimans, T., Goodfellow, I., Zaremba, W., Cheung, V., Radford, A., and Chen, X. Improved techniques for training gans, 2016.
[4] Kaplan, J., McCandlish, S., Henighan, T., Brown, T. B., Chess, B., Child, R., Gray, S., Radford, A., Wu, J., and Amodei, D. Scaling laws for neural language models, 2020.


http://www.kler.cn/a/274329.html

相关文章:

  • Redisson锁简单使用
  • 精通 Numpy 数组:详解数据类型查看、转换与索引要点
  • 如何测量分辨率
  • 如何使用 WebAssembly 扩展后端应用
  • electron-vite【实战系列教程】
  • [网络安全]XSS之Cookie外带攻击姿势详析
  • mysql逗号分隔字段拆成行简述
  • Redis的安装和部署教程(Windows环境)
  • 全球变暖(蓝桥杯,acwing每日一题)
  • 【DL经典回顾】激活函数大汇总(二十五)(GEGLU附代码和详细公式)
  • 金蝶云星空——插件dll重新发布报错:鏃犳硶鏄剧ず椤甸潰锛屽洜涓哄彂鐢熷唴閮ㄦ湇鍔″櫒閿欒銆�
  • tesseract ocr 安装/调用/训练
  • 使用Java JDBC连接数据库
  • c语言指针(二)
  • 概率统计在AI中的作用
  • Java项目利用Redisson实现真正生产可用高并发秒杀功能 支持分布式高并发秒杀
  • 在线教育平台帮助教培机构打造线上
  • 代码随想录算法训练营第二十八天 | 93.复原IP地址 78.子集 90.子集II
  • IText5填充PDF表单使用自定义字体中文生效而英文和数字不生效?
  • Lua中文语言编程源码-第五节,更改lcorolib.c协程库函数, 使Lua加载中文库关键词(与所有的基础库相关)
  • 构建Helm chart和chart使用管道与函数简介
  • (008)Unity StateMachineBehaviour的坑
  • 自动驾驶决策 - 规划 - 控制 (持续更新!!!)
  • 移除元素(leetcode)
  • 人外周血单核细胞来源树突状细胞(MoDC)的制备(一)
  • 下拉树级带搜索功能