当前位置: 首页 > article >正文

扩散模型实战:从零开始训练手写数字生成模型

Diffusion models代码解读:入门与实战

前言:手写数字的数据集是绝大部分炼丹师的深度学习初恋,这篇博客以代码为主,手把手带读者搭建一个基于扩散模型的手写数字生成模型,非常适合刚接触扩散模型的初学者学习。

目录

环境准备

数据集准备

加噪过程

模型

模型训练

推理


环境准备

pip install -q diffusers
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

数据集准备

在这里,我们将使用一个非常小的经典数据集mnist来进行测试。如果您想在不改变任何其他内容的情况下给模型一个稍微困难一点的挑战,请使用torchvision.dataset。FashionMNIST应作为替代品。

dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor())
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');

该数据集中的每张图都是一个数字的28x28像素的灰度图,像素值的范围事从0到1。

加噪过程

假设你没有读过任何扩散模型的论文,但你知道这个过程会增加噪声。你会怎么做?

我们可能想要一个简单的方法来控制损坏的程度。那么,如果我们要引入一个参数来控制输入的“噪声量”,那么我们会这么做:

noise = torch.rand_like(x)

noisy_x = (1-amount)*x + amount*noise

如果 amount = 0,则返回输入而不做任何更改。如果 amount = 1,我们将得到一个纯粹的噪声。通过这种方式将输入与噪声混合,我们将输出保持在相同的范围(0 to 1)。

我们可以很容易地实现这一点(但是要注意tensor的shape,以防被广播(broadcasting)机制不正确的影响到):

def corrupt(x, amount):
  """Corrupt the input `x` by mixing it with noise according to `amount`"""
  noise = torch.rand_like(x)
  amount = amount.view(-1, 1, 1, 1) # Sort shape so broadcasting works
  return x*(1-amount) + noise*amount 

让我们来可视化一下输出的结果,以了解是否符合我们的预期:

# Plotting the input data
fig, axs = plt.subplots(2, 1, figsize=(12, 5))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')

# Adding noise
amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
noised_x = corrupt(x, amount)

# Plottinf the noised version
axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys');

当噪声量接近1时,我们的数据开始看起来像纯随机噪声。但对于大多数的噪声情况下,您还是可以很好地识别出数字。你认为这是最佳的吗?

模型

我们想要一个模型,它可以接收28px的噪声图像,并输出相同形状的预测。一个比较流行的选择是一个叫做UNet的架构。最初被发明用于医学图像中的分割任务,UNet由一个“压缩路径”和一个“扩展路径”组成。“压缩路径”会使通过该路径的数据被压缩,而通过“扩展路径”会将数据扩展回原始维度(类似于自动编码器)。模型中的残差连接也允许信息和梯度在不同层级之间流动。

一些UNet的设计在每个阶段都有复杂的blocks,但对于这个玩具demo,我们只会构建一个最简单的示例,它接收一个单通道图像,并通过下行路径上的三个卷积层(图和代码中的down_layers)和上行路径上的3个卷积层,在下行和上行层之间具有残差连接。我们将使用max pooling进行下采样和nn.Upsample用于上采样。某些比较复杂的UNets的设计会使用带有可学习参数的上采样和下采样layer。下面的结构图大致展示了每个layer的输出通道数:

class BasicUNet(nn.Module):
    """A minimal UNet implementation."""
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_layers = torch.nn.ModuleList([ 
            nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
        ])
        self.up_layers = torch.nn.ModuleList([
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, out_channels, kernel_size=5, padding=2), 
        ])
        self.act = nn.SiLU() # The activation function
        self.downscale = nn.MaxPool2d(2)
        self.upscale = nn.Upsample(scale_factor=2)

    def forward(self, x):
        h = []
        for i, l in enumerate(self.down_layers):
            x = self.act(l(x)) # Through the layer and the activation function
            if i < 2: # For all but the third (final) down layer:
              h.append(x) # Storing output for skip connection
              x = self.downscale(x) # Downscale ready for the next layer
              
        for i, l in enumerate(self.up_layers):
            if i > 0: # For all except the first up layer
              x = self.upscale(x) # Upscale
              x += h.pop() # Fetching stored output (skip connection)
            x = self.act(l(x)) # Through the layer and the activation function
            
        return x

我们可以验证输出shape是否如我们期望的那样与输入相同:

net = BasicUNet()
x = torch.rand(8, 1, 28, 28)
net(x).shape

模型训练

那么,模型到底应该做什么呢?同样,对这个问题有各种不同的看法,但对于这个演示,让我们选择一个简单的框架:给定一个损坏的输入noisy_x,模型应该输出它对原本x的最佳猜测。我们将通过均方误差将预测与真实值进行比较。

我们现在可以尝试训练网络了。

  • 获取一批数据
  • 添加随机噪声
  • 将数据输入模型
  • 将模型预测与干净图像进行比较,以计算loss
  • 更新模型的参数。

你可以自由进行修改来尝试获得更好的结果!

# Dataloader (you can mess with batch size)
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# How many runs through the data should we do?
n_epochs = 3

# Create the network
net = BasicUNet()
net.to(device)

# Our loss finction
loss_fn = nn.MSELoss()

# The optimizer
opt = torch.optim.Adam(net.parameters(), lr=1e-3) 

# Keeping a record of the losses for later viewing
losses = []

# The training loop
for epoch in range(n_epochs):

    for x, y in train_dataloader:

        # Get some data and prepare the corrupted version
        x = x.to(device) # Data on the GPU
        noise_amount = torch.rand(x.shape[0]).to(device) # Pick random noise amounts
        noisy_x = corrupt(x, noise_amount) # Create our noisy x

        # Get the model prediction
        pred = net(noisy_x)

        # Calculate the loss
        loss = loss_fn(pred, x) # How close is the output to the true 'clean' x?

        # Backprop and update the params:
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Store the loss for later
        losses.append(loss.item())

    # Print our the average of the loss values for this epoch:
    avg_loss = sum(losses[-len(train_dataloader):])/len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')

# View the loss curve
plt.plot(losses)
plt.ylim(0, 0.1);

我们可以尝试通过抓取一批数据,以不同的数量损坏数据,然后喂进模型获得预测来观察结果:

你可以看到,对于较低的噪声水平数量,预测的结果相当不错!但是,当噪声水平非常高时,模型能够获得的信息就开始逐渐减少。而当我们达到amount = 1时,模型会输出一个模糊的预测,该预测会很接近数据集的平均值。模型通过这样的方式来猜测原始输入。

推理

如果我们在高噪声水平下的预测不是很好,我们如何才能生成图像呢?

如果我们从完全随机的噪声开始,检查一下模型预测的结果,然后只朝着预测方向移动一小部分,比如说20%。现在我们有一个噪声很多的图像,其中可能隐藏了一些关于输入数据的结构的提示,我们可以将其输入到模型中以获得新的预测。希望这个新的预测比第一个稍微好一点(因为我们这一次的输入稍微减少了一点噪声),所以我们可以用这个新的更好的预测再往前迈出一小步。

如果一切顺利的话,以上过程重复几次以后我们就会得到一个新的图像!以下图例是迭代了五次以后的结果,左侧是每个阶段的模型输入的可视化,右侧则是预测的去噪图像。Note that even though the model predicts the denoised image even at step 1, we only move x part of the way there. 重复几次以后,图像的结构开始逐渐出现并得到改善,直到获得我们的最终结果为止。

结果并不是非常好,但是已经出现了一些可以被认出来的数字!您可以尝试训练更长时间(例如,10或20个epoch),并调整模型配置、学习率、优化器等。此外,如果您想尝试稍微困难一点的数据集,您可以尝试一下fashionMNIST,只需要一行代码的替换就可以了。

原链接

https://colab.research.google.com/github/darcula1993/diffusion-models-class-CN/blob/main/unit1/02_diffusion_models_from_scratch_CN.ipynb#scrollTo=1mr5W7P-Ybhq


http://www.kler.cn/news/318714.html

相关文章:

  • ★ C++进阶篇 ★ 二叉搜索树
  • service 命令:管理系统服务
  • AI学习指南深度学习篇-Adagrad超参数调优与性能优化
  • C语言 | Leetcode C语言题解之第435题无重叠区间
  • 编译原理3——词法分析
  • Pytest-如何将allure报告发布至公司内网
  • 微生物多样性数据的可视化技巧
  • 新能源汽车数据大全(产销数据\充电桩\专利等)
  • brpc之io事件分发器
  • 【会议征稿通知】第三届图像处理、计算机视觉与机器学习国际学术会议(ICICML 2024)
  • Java使用Map数据结构配合函数式接口存储方法引用
  • 洛谷P2571.传送带
  • request库的使用 | get请求
  • 微软Active Directory:组织身份与访问管理的基石
  • 字符串——String
  • 量子计算如何引发第四次工业革命——解读加来道雄的量子物理观
  • Android平台使用VIA创建语音交互应用
  • 【ArcGIS微课1000例】0122:经纬网、方里网、参考格网绘制案例教程
  • 0基础带你入门Linux之使用
  • 初识C#(一)
  • 2.以太网
  • 毕设基于SSM+Vue3实现设备维修管理系统四:后台框架及基础增删改查功能实现
  • 常见区块链数据模型介绍
  • [leetcode]113_路径总和II_输出所有路径
  • 【AIGC】ChatGPT RAG提取文档内容,高效制作PPT、论文
  • day-59 四数之和
  • 【React】(推荐项目)使用 React、Socket.io、Nodejs、Redux-Toolkit、MongoDB 构建聊天应用程序 (2024)
  • 数据集-目标检测系列-海洋鱼类检测数据集 fish>> DataBall
  • short-link笔记
  • ubuntu 24.04 输入设备显示没有,系统没有找到电脑麦克风