当前位置: 首页 > article >正文

采样算法一:去噪扩散概率模型(DDPM)采样算法详解

参考

https://arxiv.org/pdf/2006.11239

在这里插入图片描述


一、背景知识

扩散模型(Diffusion Models) 是一种基于概率论的生成模型,通过模拟数据的扩散和逆扩散过程生成样本。核心思想分为两个阶段:

  • 前向扩散过程:逐步向数据中添加噪声,直到数据变为纯高斯噪声。
  • 反向去噪过程:学习如何逐步去除噪声,从噪声中恢复原始数据。

DDPM(Denoising Diffusion Probabilistic Models) 是一种经典的扩散模型,其采样算法是反向去噪过程的核心。


二、前向扩散过程

前向过程通过固定方差调度(Variance Schedule)逐步破坏数据,定义为马尔可夫链:

  1. 公式定义
    每一步的噪声添加公式为:
    q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t} x_{t-1}, \beta_t \mathbf{I}) q(xtxt1)=N(xt;1βt xt1,βtI)
    其中:

    • β t ∈ ( 0 , 1 ) \beta_t \in (0,1) βt(0,1) 是预先定义的噪声调度参数。
    • α t = 1 − β t \alpha_t = 1 - \beta_t αt=1βt, α ˉ t = ∏ i = 1 t α i \bar{\alpha}_t = \prod_{i=1}^t \alpha_i αˉt=i=1tαi.
  2. 任意时刻 ( x_t ) 的闭式解
    通过重参数化技巧,可直接从 x 0 x_0 x0 计算 x t x_t xt
    x t = α ˉ t x 0 + 1 − α ˉ t ϵ , ϵ ∼ N ( 0 , I ) x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon, \quad \epsilon \sim \mathcal{N}(0, \mathbf{I}) xt=αˉt x0+1αˉt ϵ,ϵN(0,I)


三、反向去噪过程

反向过程通过神经网络学习逐步去噪,定义为另一个马尔可夫链:

  1. 反向条件分布
    假设每一步服从高斯分布:
    p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , Σ θ ( x t , t ) ) p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)) pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t))

    • μ θ \mu_\theta μθ Σ θ \Sigma_\theta Σθ 由神经网络预测。
  2. 关键简化
    DDPM 固定方差为 σ t 2 = β t \sigma_t^2 = \beta_t σt2=βt,仅需预测均值 μ θ \mu_\theta μθ。通过推导可得:
    μ θ ( x t , t ) = 1 α t ( x t − β t 1 − α ˉ t ϵ θ ( x t , t ) ) \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) μθ(xt,t)=αt 1(xt1αˉt βtϵθ(xt,t))

    • 核心任务:训练神经网络 ( \epsilon_\theta ) 预测噪声 ( \epsilon )。

四、采样算法步骤

从噪声 x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, \mathbf{I}) xTN(0,I) 出发,逐步生成数据:

  1. 输入:训练好的噪声预测模型 ϵ θ \epsilon_\theta ϵθ,时间步总数 T T T,方差调度 { β t } \{\beta_t\} {βt}
  2. 初始化:采样 x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0, \mathbf{I}) xTN(0,I)
  3. 迭代去噪 t = T , T − 1 , … , 1 t = T, T-1, \dots, 1 t=T,T1,,1):
    • 预测噪声 ϵ t = ϵ θ ( x t , t ) \epsilon_t = \epsilon_\theta(x_t, t) ϵt=ϵθ(xt,t)
    • 计算均值
      μ t = 1 α t ( x t − β t 1 − α ˉ t ϵ t ) \mu_t = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_t \right) μt=αt 1(xt1αˉt βtϵt)
    • 采样前一时刻
      x t − 1 = μ t + σ t z , z ∼ N ( 0 , I ) x_{t-1} = \mu_t + \sigma_t z, \quad z \sim \mathcal{N}(0, \mathbf{I}) xt1=μt+σtz,zN(0,I)
      • t = 1 t=1 t=1时, z = 0 z=0 z=0(不添加噪声)。
  4. 输出 x 0 x_0 x0 为生成的数据。

五、数学推导关键点
  1. 反向过程均值的推导
    通过最小化变分下界(ELBO)中的 KL 散度项,可得:
    μ θ = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t x ^ 0 \mu_\theta = \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} x_t + \frac{\sqrt{\bar{\alpha}_{t-1}} \beta_t}{1 - \bar{\alpha}_t} \hat{x}_0 μθ=1αˉtαt (1αˉt1)xt+1αˉtαˉt1 βtx^0
    代入 x ^ 0 = x t − 1 − α ˉ t ϵ θ α ˉ t \hat{x}_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t} \epsilon_\theta}{\sqrt{\bar{\alpha}_t}} x^0=αˉt xt1αˉt ϵθ,化简后得到均值公式。

  2. 噪声预测的直观解释
    模型 ϵ θ \epsilon_\theta ϵθ 预测的是前向过程中添加到 x 0 x_0 x0 的噪声,通过移除该噪声可逐步恢复数据。


六、伪代码示例
def ddpm_sample(model, T, betas):
    alpha = 1 - betas
    alpha_bar = np.cumprod(alpha)
    
    x = torch.randn_like(data)  # x_T ~ N(0, I)
    for t in range(T, 0, -1):
        # 预测噪声
        epsilon = model(x, t)
        # 计算均值和方差
        mu = (x - (betas[t]/np.sqrt(1-alpha_bar[t])) * epsilon) / np.sqrt(alpha[t])
        if t > 1:
            z = torch.randn_like(x)
        else:
            z = 0
        # 更新x
        x = mu + np.sqrt(betas[t]) * z
    return x


http://www.kler.cn/a/564160.html

相关文章:

  • wpf中如何让TextBox 显示字体的颜色为白色
  • DeepSeek如何辅助学术写作的性质研究?
  • 365天之第P7周:马铃薯病害识别(VGG-16复现)
  • Java之File(文件操作)
  • LeetCode225.用队列实现栈
  • Linux 使用 sosreport 生成系统报告
  • VC++ MFC中 CTreeCtrl的自绘
  • Cesium高级开发教程之四十三:缓冲区分析#线
  • Python毕业设计选题:基于协同过滤算法的儿童图书推荐系统_django
  • 第十一章:服务器信道管理模块
  • 验证环境中为什么要用virtual interface
  • 【R包】pathlinkR转录组数据分析和可视化利器
  • 常用空间数据结构对比
  • visual studio 2022 C++ OpenCV开发环境配置(详细教程)
  • 通过AI大模型 下达指令控制物理设备实现完全自动化
  • JavaScript 深浅拷贝全面解析
  • 《模拟器过检测教程:Nox、雷电、Mumu、逍遥模拟器 Magisk、LSposed 框架安装与隐藏应用配置》
  • JAVA多商户家政同城上门服务预约服务抢单派单+自营商城系统支持小程序+APP+公众号+h5
  • 如何通过JS实现关闭网页时清空该页面在本地电脑的缓存存储?
  • C/C++易错点:函数指针与指针函数的核心区别与避坑指南