当前位置: 首页 > article >正文

CMU 10423 Generative AI:lec7、8、9(专题2:一张图理解diffusion model结构、代码实现和效果)

本文介绍diffusion model是什么(包括:模型详细的架构图、各模块原理和输入输出、训练算法解读、推理算法解读)、以及全套demo代码和效果。至于为什么要这么设计、以及公式背后的数学原理,过程推导很长很长,可见参考资料。

文章目录

  • 1 Diffusion Model概述
  • 2 一张图看懂diffusion model结构
  • 3 Diffusion Model训练和推理算法解释
    • 3.1 Diffusion Model的训练算法解释
    • 3.2 Diffusion Model的推理算法解释
  • 4 diffusion model 代码实现
    • 数据集
    • **噪声调度与预计算**
        • 2\. 重新参数化公式
        • 3\. 推理阶段公式
        • 4\. 代码实现与预计算
        • **4. 解释代码中的每个步骤**
    • 前向扩散过程(添加噪声)与反向采样过程(去除噪声)
        • 1\. 前向扩散过程:`forward_diffusion_viz`
        • 2\. 反向采样过程:`make_inference`
      • 前向扩散过程可视化效果
    • 开始训练
    • 效果展示

1 Diffusion Model概述

Diffusion Model严格意义上最早源于2015年的《Deep Unsupervised Learning using Nonequilibrium Thermodynamics》,但如下这篇论文才真正将Diffusion Model效果发扬光大,有点类似2013年的alexnet网络和1998年的lenet-5网络感觉。

全称:《Denoising Diffusion Probabilistic Models》
时间:2020年
作者人数:3人,加州伯克利大学
论文地址:https://proceedings.neurips.cc/paper_files/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf
优缺点优点:生成图像的效果非常惊艳,超越VAE、生成式对抗网络等方法,SOTA级别。
缺点:生成图像的速度非常缓慢。

一段话总结DDPM算法:

  1. 目的:训练一个生成模型,以生成符合训练集分布的随机图像。
  2. 核心思想:灵感来自非平衡热力学,扩散模型是通过添加高斯噪声来破坏训练数据(称为正向扩散过程),然后学习如何通过逐步逆转此噪声过程来恢复原始信息(称为反向扩散过程)。
  3. 训练过程:输入一张参入噪声的图像,使用U-Net架构模型,使其学会预测在特定时步 $ t $​ 时对输入图像中添加的噪声。
  4. 生成过程:从一个高斯分布的纯噪声图像开始,输入模型进行噪声预测,然后逐步去除噪声,重复这一过程 T T T 次,以获得一张清晰的全新图像。

在这里插入图片描述

至于为什么这种方式能生效,其背后的数学原理 可见参考资料。

参考资料

  • 李宏毅的《机器学习》——Diffusion Model 原理剖析 (大概1.5小时)

    • 课程网址:https://speech.ee.ntu.edu.tw/~hylee/ml/2023-spring.php
    • B站转载视频:https://www.bilibili.com/video/BV14c411J7f2
  • 《Diffusion Model原理详解及源码解析》

    • https://juejin.cn/post/7210175991837507621
  • 《扩散模型是如何工作的:从零开始的数学原理》

    • https://shao.fun/blog/w/how-diffusion-models-work.html
  • 《What are Diffusion Models?》OpenAI 安全系统团队负责人翁丽莲(Lilian Weng 出品,必是精品)
    a. https://lilianweng.github.io/posts/2021-07-11-diffusion-models/

  • 《an-introduction-to-diffusion-models-and-stable-diffusion》

    • https://blog.marvik.ai/2023/11/28/an-introduction-to-diffusion-models-and-stable-diffusion/

2 一张图看懂diffusion model结构

在这里插入图片描述

3 Diffusion Model训练和推理算法解释

3.1 Diffusion Model的训练算法解释

在这里插入图片描述

  1. 第1步:repeat (重复)
  • 这是一个循环,意味着我们会不断重复以下步骤,直到模型收敛。
  1. 第2步:从数据分布​ q ( x 0 ) q(x_0) q(x0)​中采样 x 0 x_0 x0
  • 这里的 x 0 x_0 x0​ 是一张干净的原始图像,表示我们从真实的图像数据分布中采样(即从图像数据集中随机选择一张图片)。
  • q ( x 0 ) q(x_0) q(x0)​ 表示的是真实数据的分布(训练时它也就是个数据集),目标是让模型最终生成的图片和这个分布一致。
  1. 第3步:从均匀分布中采样 ​ t ∼ Uniform ( { 1 , … , T } ) t \sim \text{Uniform}(\{1,\ldots,T\}) tUniform({1,,T})
  • 这里的 t t t​ 是从1到 T T T​ 的均匀分布中采样的一个时间步。这个时间步控制扩散过程的阶段,表示噪声的加噪程度。
  • T T T​ 是扩散过程的总时间步数,表示噪声逐步增加的步骤。
  • 随机采样时间步,这使得模型能够在任何时间步学习逆转扩散过程,从而增强其适应性。
  1. 第4步:从正态分布 ​ ϵ ∼ N ( 0 , I ) \epsilon \sim \mathcal{N}(0, I) ϵN(0,I)​ 中采样噪声 ϵ \epsilon ϵ
  • 采样一个噪声向量 ϵ \epsilon ϵ​,它是从标准正态分布中生成的。这个噪声表示在每个时间步加入的噪声大小。
  1. 第5步:梯度下降更新
  • 这里我们需要根据误差 ϵ − ϵ θ ( ⋅ ) \epsilon - \epsilon_\theta(\cdot) ϵϵθ()​ 进行梯度下降。
  • 其中 ϵ θ ( ⋅ ) \epsilon_\theta(\cdot) ϵθ()​ 是一个神经网络,负责预测在当前时间步 t t t​ 下的噪声。训练的目标是让模型的预测噪声 ϵ θ \epsilon_\theta ϵθ​ 尽可能接近真实的加噪噪声 ϵ \epsilon ϵ​。
  • 更新的目标是最小化预测噪声与真实噪声之间的差异,公式为:

∇ θ ∥ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∥ 2 \nabla_\theta \| \epsilon - \epsilon_\theta(\sqrt{\bar{\alpha}_t}x_0 + \sqrt{1 - \bar{\alpha}_t}\epsilon, t) \|^2 θϵϵθ(αˉt x0+1αˉt ϵ,t)2

其中:

  • α ˉ t \bar{\alpha}_t αˉt​ :是时间步 t t t​ 对应的参数,用来控制在第 t t t​ 步时,图像的真实信息和噪声之间的权重。 α ˉ t \bar{\alpha}_t αˉt​ 是随着时间步 t t t​ 变化的一个量。具体地,噪声调度定义了每个时间步 t t t​ 如何给图像加噪声,常见的噪声调度方法是根据线性或指数衰减来设定一系列 α t \alpha_t αt​,然后通过这些 α t \alpha_t αt​ 累积计算出 α ˉ t \bar{\alpha}_t αˉt​。一般来说,在训练开始时, α ˉ t \bar{\alpha}_t αˉt​ 的值比较大,而随着时间 t t t​ 的增加, α ˉ t \bar{\alpha}_t αˉt​ 的值逐渐减小,这意味着图像中的真实信息减少,而噪声的比例逐渐增加。换句话说, α ˉ t \bar{\alpha}_t αˉt​ 越小,图像中的噪声就越多。到时间步 T T T​ 时,图像几乎完全变成了随机噪声,原始图像信息几乎无法辨认。
  • α ˉ t x 0 \sqrt{\bar{\alpha}_t}x_0 αˉt x0​ :是根据 x 0 x_0 x0​(即干净图像)经过时间步 t t t​ 后加噪的图像。
  • 1 − α ˉ t ϵ \sqrt{1 - \bar{\alpha}_t}\epsilon 1αˉt ϵ​ :是对应噪声在 t t t​ 时间步的成分。
  • ϵ θ \epsilon_\theta ϵθ​ :即模型,需要学会预测噪声。
  1. 第6步:直到收敛(until converged)
  • 当模型收敛时,梯度下降的目标值逐渐变小,意味着模型学会了如何准确地预测噪声,从而可以有效去噪图像。
  1. 备注:当T取不同值时,输入模型的图像含噪幅度示意图如下:
    a. 在这里插入图片描述

训练总结

这个训练过程的核心思想是:通过神经网络 $ \epsilon_\theta $​ 预测扩散过程中给定时间步下的噪声。模型的目标是最小化预测噪声与真实噪声之间的差异,采用均方误差(MSE)损失。

3.2 Diffusion Model的推理算法解释

在这里插入图片描述

  1. x T x_T xT​(左上角的图像)
  • 这个图像最开始是从正态分布 N ( 0 , I ) \mathcal{N}(0, I) N(0,I)​ 中采样的纯噪声图像。推理的任务是逐步从这个纯噪声图像中一步步的去噪(总共经过T步),从而生成一个清晰的图像。
  1. 噪声预测器(Noise Predictor)
  • 即之前训练好的U-net模型。在每个时间步 t t t​,噪声预测器(即神经网络 ϵ θ \epsilon_\theta ϵθ​)基于当前噪声图像 x t x_t xt​ 和时间步 t t t​ 来预测出噪声 ϵ θ ( x t , t ) \epsilon_\theta(x_t, t) ϵθ(xt,t)​。这是恢复图像过程中的关键一步。
  1. 计算 x t − 1 x_{t-1} xt1
  • 根据 t 步的图像 x t x_{t} xt,使用公式计算出时间步 t − 1 t-1 t1 的图像 x t − 1 x_{t-1} xt1。这一过程包括两个部分:
    • 减去噪声成分 x t x_t xt 中的噪声成分由 ϵ θ ( x t , t ) \epsilon_\theta(x_t, t) ϵθ(xt,t) 给出,通过减去该噪声可以得到较少噪声的图像。
    • 添加额外噪声 ​ z z z​:加入正态分布噪声 z z z​。它保证了生成过程中的随机性,使得生成的图像具有多样性;同时,它通过逐步减小噪声的尺度,使生成过程平滑过渡到清晰的图像。如果不加这一噪声项,生成过程会变得确定性,导致生成的图像多样性不足,质量下降,甚至可能出现模式坍塌问题。李宏毅课程中提到,如果不加这个额外噪声,模型几乎完全无法生成有意义的图像,如下图:
    • 在这里插入图片描述
  1. 重复T次,最终生成图像
  • 不断重复这个过程,逐步从时间步 T T T 恢复到时间步 1,噪声逐渐减少,图像逐步清晰。
  • 整个过程表示如下:
  • 在这里插入图片描述
  1. 备注:如果超过T次采样,继续去噪下去,图像将逐渐变得模糊、失去原有结构,变成无意义的噪声图。

4 diffusion model 代码实现

代码出处(某人完成的优达学城AIGC课程练习):

  • https://github.com/amanpreetsingh459/Generative-AI/tree/main/3.%20Computer%20Vision%20and%20Generative%20AI/3.%20Diffusion%20Models
  • 备注:pytorch官网数据集链接已经失效了,需要在kaggle或他人处下载,加载数据集的代码我已经修正了。
/home/ym/AIGC_udacity/exerise/diffusion/stanford_cars/
├── cars_annos.mat
├── cars_test_annos_withlabels.mat
├── car_ims/
    ├── 00001.jpg
    ├── 00002.jpg
    └── ...

数据集

使用斯坦福汽车数据集。该数据集包含 196 类汽车,共 16,185 张图片。在这个练习中,我们不需要任何标签,也不需要测试数据集。我们还将把图像转换为 64x64,以便更快地完成练习:

import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import multiprocessing

IMG_SIZE = 64
BATCH_SIZE = 100

class ImageDataset(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        
        # 获取所有图像文件路径
        self.image_paths = [os.path.join(root, 'car_ims', fname) 
                            for fname in os.listdir(os.path.join(root, 'car_ims')) 
                            if fname.endswith(('.jpg', '.png'))]
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, index):
        img_path = self.image_paths[index]
        image = Image.open(img_path).convert('RGB')
        
        if self.transform is not None:
            image = self.transform(image)
        
        # 只返回图像
        return image

# 数据预处理
data_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5)  # 将数据归一化到 [-1, 1]
])

# 创建数据集。文件夹结构:
# /home/ym/AIGC_udacity/exerise/diffusion/stanford_cars/
# ├── cars_annos.mat
# ├── cars_test_annos_withlabels.mat
# ├── car_ims/
#     ├── 00001.jpg
#     ├── 00002.jpg
#     └── ...
dataset = ImageDataset(root='/home/ym/AIGC_udacity/exerise/diffusion/stanford_cars', 
                       transform=data_transform)

# 创建数据加载器
dataloader = DataLoader(
    dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    drop_last=False, 
    pin_memory=True, 
    num_workers=multiprocessing.cpu_count(),
    persistent_workers=True
)

数据集可视化如下:

在这里插入图片描述

噪声调度与预计算

在扩散模型的前向过程(Forward Process)中,我们需要根据一个噪声调度策略向数据中添加随机噪声。为了方便模型训练和推理过程中的计算,我们需要预先定义并计算一些常量。

1. 噪声调度的定义

# Define beta schedule
T = 512  # number of diffusion steps
# YOUR CODE HERE
betas = torch.linspace(start=0.0001, end=0.02, steps=T)  # linear schedule

plt.plot(range(T), betas.numpy(), label='Beta Values')
plt.xlabel('Diffusion Step')
plt.ylabel('Beta Value')
_ = plt.title('Beta Schedule over Diffusion Steps')

在这里插入图片描述

  1. 噪声调度的定义
  • 代码中设置 $ T = 512 $​,表示扩散过程共有512个步骤。
  • 使用 torch.linspace(start=0.0001, end=0.02, steps=T) 创建了一个线性调度器,其中噪声参数 $ \beta_t $​ 从 0.0001 均匀递增至 0.02,生成512个数值,代表每个扩散步骤中的噪声幅度。
  1. 解释图像
  • 图像显示, β t \beta_t βt 随扩散步骤呈线性增长,这意味着在扩散模型的每一步,添加到样本中的噪声量逐渐增加。
  • 这种线性调度策略通常用于扩散模型的训练,帮助模型在生成过程中的前期添加较小噪声,后期添加较大噪声,使得模型能够更稳定地学习数据分布。
  1. 备注:后续改进的DDPM使用余弦噪声调度,能获取更好性能。原因:线性调度可能会导致输入图像中的信息快速丢失。因此,这通常会导致突然的扩散过程。相比之下,余弦调度提供了更平滑的退化。因此,允许后续步骤对没有被噪声完全淹没的图像进行操作。
    a. 在这里插入图片描述
2. 重新参数化公式

前向过程中的重新参数化允许我们在任意步骤生成带噪声的图像,而无需按顺序遍历所有的前面步骤。这通过以下公式实现:

α ˉ t = ∏ s = 1 t ( 1 − β s ) \bar{\alpha}_t = \prod_{s=1}^t (1 - \beta_s) αˉt=s=1t(1βs)

q ( x t ∣ x 0 ) = N ( α ˉ t x 0 , ( 1 − α ˉ t ) I ) q(x_t|x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_t} x_0, (1 - \bar{\alpha}_t) I) q(xtx0)=N(αˉt x0,(1αˉt)I)

这意味着我们可以直接根据 x 0 x_0 x0 生成任意步骤 x t x_t xt 的带噪声样本,而不必每次从 x 1 , x 2 , . . . , x t − 1 x_1, x_2, ..., x_{t-1} x1,x2,...,xt1 逐步采样。

3. 推理阶段公式

推理阶段的反向过程使用以下公式生成样本 x t − 1 x_{t-1} xt1

x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) + σ t z x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) + \sigma_t z xt1=αt 1(xt1αˉt 1αtϵθ(xt,t))+σtz

其中

σ t 2 = ( 1 − α ˉ t − 1 ) ( 1 − α ˉ t ) β t \sigma_t^2 = \frac{(1 - \bar{\alpha}_{t-1})}{(1 - \bar{\alpha}_t)} \beta_t σt2=(1αˉt)(1αˉt1)βt

这个公式描述了如何利用当前带噪声的样本 x t x_t xt 生成上一步的样本 x t − 1 x_{t-1} xt1

4. 代码实现与预计算

为了高效地进行前向和反向过程,我们在代码中对上述所有常数进行了预先计算:

# 预先计算在封闭形式中用到的各项
alphas = 1. - betas  # 计算每个步骤的 α_t
alphas_cumprod = torch.cumprod(alphas, axis=0)  # 计算累积乘积 \bar{\alpha}_t

# 计算 \bar{\alpha}_{t-1}
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) 

# 计算 \sqrt{\bar{\alpha}_t}
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) 

# 计算 1 / \sqrt{\alpha_t}
sqrt_recip_alphas = torch.sqrt(1.0 / alphas) 

# 计算 \sqrt{1 - \bar{\alpha}_t}
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod) 

# 计算推理阶段的 σ_t^2
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
4. 解释代码中的每个步骤
  • alphas = 1. - betas:计算每个步骤的 α t \alpha_t αt
  • alphas_cumprod = torch.cumprod(alphas, axis=0):计算 α ˉ t = ∏ s = 1 t α s \bar{\alpha}_t = \prod_{s=1}^t \alpha_s αˉt=s=1tαs
  • alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0):计算 α ˉ t − 1 \bar{\alpha}_{t-1} αˉt1 并在序列开始处补1,使其长度与其他参数匹配。
  • sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod):计算 α ˉ t \sqrt{\bar{\alpha}_t} αˉt
  • sqrt_recip_alphas = torch.sqrt(1.0 / alphas):计算 1 α t \frac{1}{\sqrt{\alpha_t}} αt 1
  • sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod):计算 1 − α ˉ t \sqrt{1 - \bar{\alpha}_t} 1αˉt
  • posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod):计算推理阶段 x t − 1 x_{t-1} xt1 所需的方差 σ t 2 \sigma_t^2 σt2

前向扩散过程(添加噪声)与反向采样过程(去除噪声)

  • forward_diffusion_viz(): 实现了输入图像从原始样本到完全随机噪声的前向扩散过程。
  • make_inference(): 实现了扩散模型的逆过程,即将随机噪声逐步还原为原始图像。
@torch.no_grad()
def forward_diffusion_viz(image, device='cpu', num_images=16, dpi=75, interleave=False):
    """
    Generate the forward sequence of noisy images taking the input image to pure noise
    """
    # Visualize only num_images diffusion steps, instead of all of them
    stepsize = int(T/num_images)
    
    imgs = []
    noises = []
    
    for i in range(0, T, stepsize):
        t = torch.full((1,), i, device=device, dtype=torch.long)

        # Forward diffusion process
        bs = image.shape[0]
        noise = torch.randn_like(image, device=device)
        img = (
            sqrt_alphas_cumprod[t].view(bs, 1, 1, 1) * image + 
            sqrt_one_minus_alphas_cumprod[t].view(bs, 1, 1, 1) * noise
        )

        imgs.append(torch.clamp(img, -1, 1).squeeze(dim=0))
        noises.append(torch.clamp(noise, -1, 1).squeeze(dim=0))
    
    if interleave:
        imgs = [item for pair in zip(imgs, noises) for item in pair]
        
    fig = display_sequence(imgs, dpi=dpi)
    
    return fig, imgs[-1]


@torch.no_grad()
def make_inference(input_noise, return_all=False):
    """
    Implements the sampling algorithm from the DDPM paper
    """
    
    x = input_noise
    bs = x.shape[0]
    
    imgs = []
    
    # YOUR CODE HERE
    for time_step in range(0, T)[::-1]:
        
        noise = torch.randn_like(x) if time_step > 0 else 0
        
        t = torch.full((bs,), time_step, device=device, dtype=torch.long)
        
        # YOUR CODE HERE
        x = sqrt_recip_alphas[t].view(bs, 1, 1, 1) * (
            x - betas[t].view(bs, 1, 1, 1) * model(x, t) / 
            sqrt_one_minus_alphas_cumprod[t].view(bs, 1, 1, 1)
        ) + torch.sqrt(posterior_variance[t].view(bs, 1, 1, 1)) * noise
        
        imgs.append(torch.clamp(x, -1, 1))
    
    if return_all:
        return imgs
    else:
        return imgs[-1]
    
    return x
1. 前向扩散过程:forward_diffusion_viz

该函数实现了将输入图像通过扩散过程逐步变成纯噪声,并在过程中可视化各步骤的噪声生成效果。

代码解析:

  • 装饰器 @torch.no_grad():表示在该函数中不会计算梯度,节省内存并加快计算速度,这是因为我们只是想观察生成过程而不是进行训练。
  • 输入参数
    • image: 输入图像张量,通常是一个单个样本。
    • num_images: 想要可视化的扩散步骤数,默认是16。
    • dpi: 绘图的分辨率。
    • interleave: 是否交错显示生成的噪声图像。

主要步骤

  1. 计算扩散步长stepsize = int(T / num_images) 用于确定可视化过程中要间隔多少步。
  2. 在循环中实现前向扩散过程
  • 对每个步骤 t t t,计算当前噪声 n o i s e noise noise 并利用公式生成带噪声的图像 img

x t = α ˉ t x 0 + 1 − α ˉ t ϵ x_t = \sqrt{\bar{\alpha}_t} x_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon xt=αˉt x0+1αˉt ϵ

  • 将结果添加到 imgs 列表中,noises 列表中存储对应的噪声。

  • 交错选项 interleave:如果 interleave=True,将噪声图像和带噪声图像交替存储在 imgs 列表中。

  • 显示结果:调用 display_sequence(imgs, dpi=dpi) 来显示扩散过程,并返回最终图像。

2. 反向采样过程:make_inference

该函数实现了扩散模型的反向采样过程,将随机噪声转化为生成的图像样本。这个过程基于DDPM论文中的采样算法。

代码解析:

  • 输入参数
    • input_noise: 输入噪声,用于开始反向采样过程。
    • return_all: 控制是否返回所有步骤的结果。

主要步骤

  1. 初始化:将输入噪声赋值给 x
  2. 反向扩散过程
  • T T T 到 0,逐步反向遍历扩散步骤。
  • 在每一步 t t t,根据扩散模型的公式生成上一步的样本 x t − 1 x_{t-1} xt1

x t − 1 = 1 α t ( x t − 1 − α t 1 − α ˉ t ϵ θ ( x t , t ) ) + σ t z x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) + \sigma_t z xt1=αt 1(xt1αˉt 1αtϵθ(xt,t))+σtz

  • 使用 model(x, t) 预测当前步骤的噪声,使用 torch.sqrt(posterior_variance[t]) 计算加权噪声项。

  • 将生成的结果 x 添加到 imgs 列表中。

  • 返回结果:如果 return_all=True,则返回所有步骤的结果,否则只返回最后一个步骤的图像。

前向扩散过程可视化效果

在这里插入图片描述

开始训练

  • 下述代码实现了一个扩散模型的训练过程,主要包括模型初始化、数据处理、训练循环、损失计算和梯度更新等。
  • 特别注意了混合精度训练的部分,使用了 torch.cuda.amp 中的 autocastGradScaler 来实现自动混合精度训练,从而提高了训练效率并减少显存使用。(消耗10.5G显存,不使用混合精度训练估计需要18G显存)
  • 同时加入了学习率预热(warmup)和余弦退火(Cosine Annealing)的调度策略,以确保模型训练的稳定性和效果。
# 导入自定义的UNet模型
from unet import UNet

# 初始化UNet模型,使用默认的通道倍增数
model = UNet(ch_mults = (1, 2, 1, 1))

# 如果想进行非常长时间的训练,可以选择注释掉上面的模型初始化
# 并启用下面这一行的模型初始化
# model = UNet(ch_mults = (1, 2, 2, 2))

# 计算模型参数数量,并打印
n_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {n_params:,}")

# 设置设备为GPU或CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# 将模型移动到对应设备
model.to(device)

# 将所有需要的参数移动到设备(GPU/CPU)
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device)
alphas = alphas.to(device)
alphas_cumprod = alphas_cumprod.to(device)
alphas_cumprod_prev = alphas_cumprod_prev.to(device)
sqrt_recip_alphas = sqrt_recip_alphas.to(device)
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device)
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device)
posterior_variance = posterior_variance.to(device)
betas = betas.to(device)

# 定义损失函数
criterion = torch.nn.MSELoss()

# 设置训练相关参数
base_lr = 0.0006 # 基础学习率
epochs = 10 # 总的训练轮次
T_max = epochs  # Cosine Annealing 的最大步数
warmup_epochs = 2  # 预热训练轮数

# 如果想要进行非常长时间的训练,可以启用以下设置
# base_lr = 0.0001 # 基础学习率
# epochs = 300 # 总训练轮数
# T_max = epochs  # Cosine Annealing 的最大步数
# warmup_epochs = 10  # 预热训练轮数

# 初始化优化器和学习率调度器
optimizer = Adam(model.parameters(), lr=base_lr)
scheduler = CosineAnnealingLR(
    optimizer, 
    T_max=T_max - warmup_epochs,  # 调度器的最大步数
    eta_min=base_lr / 10  # 学习率最小值
)

# 导入用于混合精度训练的模块
from torch.cuda.amp import autocast, GradScaler

# 初始化GradScaler用于缩放梯度
scaler = GradScaler()

# 生成固定的噪声,用于在训练过程中检查模型的生成效果
fixed_noise = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)

# 设置 EMA 损失平滑因子
alpha = 0.1  # EMA(指数移动平均)平滑因子
ema_loss = None  # 初始化EMA损失

# 开始训练循环
for epoch in range(epochs):
    
    if epoch < warmup_epochs:
        # 线性预热学习率
        lr = base_lr * (epoch + 1) / warmup_epochs
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    else:
        # 预热完成后使用余弦退火学习率
        scheduler.step()

    current_lr = optimizer.param_groups[0]['lr']  # 当前学习率
        
    for batch in tqdm(dataloader):
        
        batch = batch.to(device)  # 将当前batch移动到设备
        bs = batch.shape[0]  # 获取当前batch大小
        
        optimizer.zero_grad()  # 清空优化器的梯度
        
        # 混合精度训练开始
        with autocast():
            # 随机选择t时刻
            t = torch.randint(0, T, (batch.shape[0],), device=device).long()
            
            # 生成目标噪声并添加到图像中
            noise = torch.randn_like(batch, device=device)
            x_noisy = (
                sqrt_alphas_cumprod[t].view(bs, 1, 1, 1) * batch + 
                sqrt_one_minus_alphas_cumprod[t].view(bs, 1, 1, 1) * noise
            )
            
            # 通过模型预测噪声
            noise_pred = model(x_noisy, t)
            loss = criterion(noise, noise_pred)  # 计算损失
        
        # 使用 scaler 进行混合精度训练的反向传播
        scaler.scale(loss).backward()
        
        # 更新优化器的权重
        scaler.step(optimizer)
        
        # 更新 scaler 
        scaler.update()
        
        if ema_loss is None:
            # 第一个 batch 初始化 ema_loss
            ema_loss = loss.item()
        else:
            # 计算损失的指数移动平均
            ema_loss = alpha * loss.item() + (1 - alpha) * ema_loss
    
    if epoch == epochs-1:
        with torch.no_grad():
            # 在训练结束时对固定噪声进行推理,查看生成结果
            imgs = make_inference(fixed_noise, return_all=True)
            fig = display_sequence([imgs[0].squeeze(dim=0)] + [x.squeeze(dim=0) for x in imgs[63::64]], nrow=9, dpi=150)
            plt.show(fig)
        
        # 保存结果图像
        os.makedirs("diffusion_output_long", exist_ok=True)
        fig.savefig(f"diffusion_output_long/frame_{epoch:05d}.png")
    
    # 打印当前轮次的损失和学习率
    print(f"epoch {epoch+1}: loss: {ema_loss:.3f}, lr: {current_lr:.6f}")

训练结果:

在这里插入图片描述

效果展示

考虑到这个模型如此之小,我们对它的训练又如此之少,这个结果已经相当不错了。我们已经可以看出,它确实在创建汽车,而且还带有挡风玻璃和车轮,尽管这还只是初步阶段。如果我们训练的时间更长,和/或使用更大的模型(例如,上面注释行中定义的模型有 5,500 万个参数),并让它训练几个小时,我们会得到更好的结果,就像这样:

在这里插入图片描述


http://www.kler.cn/a/322407.html

相关文章:

  • C++AVL平衡树
  • Vue3 pinia使用
  • vue2动态导出多级表头表格
  • css数据不固定情况下,循环加不同背景颜色
  • 【Docker】在 Ubuntu 上安装 Docker 的详细指南
  • IntelliJ IDEA 2024.3(Ultimate Edition)免费化教学
  • 论文阅读 | 一种基于潜在向量优化的可证明安全的图像隐写方法(TMM 2023)
  • 端上自动化测试平台实践
  • Go 实现:椭圆曲线数字签名算法ECDSA
  • 50道渗透测试面试题,全懂绝对是高手
  • 边裁员边收购,思科逐渐变身软件并购之王
  • Java 入门指南:并发设计模式 —— 两端终止模式
  • C++之STL—常用排序算法
  • JavaScript 操作 DOM元素CSS 样式的几种方法
  • MATLAB软件开发通用控制的软件架构参考
  • (JAVA)浅尝关于 “栈” 数据结构
  • Android 增加宏开关控制android.bp
  • MySQL查询语句优化
  • DataGrip远程连接Hive
  • Python中列表常用方法
  • C语言 15 预处理
  • vue3 TagInput 实现
  • 监控易监测对象及指标之:Kubernetes(K8s)集群的全方位监控策略
  • webpack与vite读取base64图片
  • django开发流程1
  • manim中实现文字换行和设置字体格式