当前位置: 首页 > article >正文

手撕Diffusion系列 - 第四期 - Diffusion前向扩散

手撕Diffusion系列 - 第四期 - Diffusion前向扩散

目录

  • 手撕Diffusion系列 - 第四期 - Diffusion前向扩散
  • DDPM 原理图
  • DDPM 前向扩散介绍
  • Diffusion(前向扩散) 代码
    • Part1 引入相关库函数
    • Part2 去噪的一些参数初始化
    • Part3 定义前向传播函数
    • Part4 测试
  • 参考

DDPM 原理图

​ DDPM包括两个过程:前向过程(forward process)反向过程(reverse process),其中前向过程又称为扩散过程(diffusion process),如下图所示。无论是前向过程还是反向过程都是一个参数化的马尔可夫链(Markov chain),其中反向过程可以用来生成图片。

在这里插入图片描述

DDPM 整体大概流程

​ 图中,由高斯随机噪声 x T x_T xT 生成原始图片 x 0 x_0 x0 为反向过程,反之为前向过程(噪音扩散)。

DDPM 前向扩散介绍

一句话概括,前向过程就是对原始图片 x 0 x_0 x0 不断加高斯噪声最后生成随机噪声 x T x_T xT 的过程,如下图所示

在这里插入图片描述

Diffusion前向传播过程(加噪)

前向传播的过程中,上个时刻图像( x t − 1 x_{t−1} xt1)到下个时刻图像( x t x_{t} xt)的加噪传播函数如下所示:
x t = α t x t − 1 + 1 − α t ϵ t − 1 x_t=\sqrt{\alpha_t}x_{t-1} + \sqrt{1-\alpha_t}\epsilon_{t-1} xt=αt xt1+1αt ϵt1
其中 α t \alpha_t αt 是一个很小值的超参数, ϵ t − 1 ∼ N ( 0 , 1 ) \epsilon_{t-1}∼N(0,1) ϵt1N(0,1) 是高斯噪声。由公式推导,最终可以得到 x 0 x0 x0 x t xt xt 的公式,表示如下:
x t = α t ˉ x 0 + 1 − α t ˉ ϵ x_t=\sqrt{\bar{\alpha_t}}x_{0} + \sqrt{1-\bar{\alpha_t}}\epsilon xt=αtˉ x0+1αtˉ ϵ
其中 α t ˉ = ∏ i = 1 t α i , ϵ ∼ N ( 0 , 1 ) \bar{\alpha_{t}} = \prod_{i=1}^t\alpha_i,\epsilon∼N(0,1) αtˉ=i=1tαiϵN(0,1) 也是一个高斯噪声,利用这个公式,我们可以随机生成t,然后利用 x 0 x_0 x0生成 x t x_t xt

Diffusion(前向扩散) 代码

Part1 引入相关库函数

# 该模块实现的过程是利用函数输入批次的图像和批次的t,对图像利用公式进行正向扩散,得到批次的噪声图,
# 注意,这里实现的时候不是类,因为这个严格来说不算网络结构,也就是不在计算图的构建范围内,只是构造训练数据集的其中一个处理过程,所以写成函数即可。

'''
# Part1 引入相关的库函数
'''
from config import * # 一些基础参数
from dataset import minist_train, minist_loader,TenosrtoPil_action # 数据的测试
import torch # 用一系列的数据处理
import matplotlib.pyplot as plt # 用于绘图

Part2 去噪的一些参数初始化

'''
# Part2 初始化一些beta和alpha参数用于函数中的加噪过程。
'''

# 首先获取Beta_t,主要是T步里面的参数,这个一般是直接定好的,(left,right)中取样T次(单调增加的)
beta_t=torch.linspace(start=0.0001,end=0.02,steps=T) # (T,) # 单维度tensor

# 然后是获取alpha_t ,因为和beta_t相加为1,所以直接和1相减 (单调递减)
alpha_t=1-beta_t # (T,)

# 然后需要得到alpha_bar,这个是累乘得到的
alpha_bar=torch.cumprod(alpha_t,dim=-1) # alpha_t累乘 (T,)    [a1,a2,a3,....] ->  [a1,a1*a2,a1*a2*a3,.....]

Part3 定义前向传播函数

def forward_diffusion(batch_x0,batch_t): # (batch_size,chanal,imag_sie,imag_size) , (batch_size,)
    # 第一步首先要获取,整个batch图像的,t-1时刻的噪声。
    batch_noise_t=torch.randn_like(input=batch_x0) # (batch_size,chanal,imag_sie,imag_size),默认是标准正态分布
    # 首先需要利用batch_t,从alpha_bar里面取出对应的t,并且为了便于广播机制,需要对batch_t进行形状的转换,至少保持同纬度
    alpha_bar_t=alpha_bar[batch_t].reshape(batch_t.size()[0],1,1,1)
    # 计算得到噪声后的图
    batch_xt=torch.sqrt(alpha_bar_t)*batch_x0+torch.sqrt(1-alpha_bar_t)*batch_noise_t
    return batch_xt,batch_noise_t

Part4 测试

if __name__ == '__main__':
    batch_x = next(iter(minist_loader))[0]  # 2个图片拼batch, (2,1,48,48)

    # 加噪前的样子
    plt.figure(figsize=(10, 10))
    plt.subplot(1, 2, 1)
    plt.imshow(TenosrtoPil_action(batch_x[0]))
    plt.subplot(1, 2, 2)
    plt.imshow(TenosrtoPil_action(batch_x[1]))
    plt.show()
    # 如果需要将噪声(通常是从[-1, 1]范围生成的)加到图像上,你需要将图像数据重新缩放到[-1, 1]范围,以便它与噪声匹配,能够平衡噪声对训练的影响。
    # 虽然正态分布不是严格[-1,1]之间,但是通过三sigema定理,我们这里可以初略定在一个sigema之间。
    # 总之就是没有偏差
    batch_x = batch_x * 2 - 1  # [0,1]像素值调整到[-1,1]之间,以便与高斯噪音值范围匹配
    batch_t = torch.randint(0, T, size=(batch_x.size(0),))  # 每张图片随机生成diffusion步数
    # batch_t=torch.tensor([5,100],dtype=torch.long)
    print('batch_t:', batch_t)

    batch_x_t, batch_noise_t = forward_diffusion(batch_x, batch_t)
    print('batch_x_t:', batch_x_t.size())
    print('batch_noise_t:', batch_noise_t.size())

    # 加噪后的样子
    plt.figure(figsize=(10, 10))
    plt.subplot(1, 2, 1)
    plt.imshow(TenosrtoPil_action((batch_x_t[0] + 1) / 2)) # 返回原来的图像范围
    plt.subplot(1, 2, 2)
    plt.imshow(TenosrtoPil_action((batch_x_t[1] + 1) / 2))
    plt.show()

参考

视频讲解:diffusion前向扩散_哔哩哔哩_bilibili

原理博客:手撕Diffusion系列 - 第一期 - DDPM原理-CSDN博客


http://www.kler.cn/a/514774.html

相关文章:

  • QT:QTabWidget设置tabPosition为West时,文字向上
  • 基于 Spring Boot 和 Vue.js 的全栈购物平台开发实践
  • thinkphp8在使用apidoc时, 4层的接口会有问题 解决办法
  • NextJs - ServerAction获取文件并处理Excel
  • 在centos上编译安装opensips【初级-默认安装】
  • TCP断开通信前的四次挥手(为啥不是三次?)
  • grafana新增email告警
  • .net 项目引用与 .NET Framework 项目引用之间的区别和相同
  • React的响应式
  • Python聚合的概念与实现
  • 告别繁琐的Try-Catch!优雅的异常处理解决方案
  • 2024电赛H题参考方案(+视频演示+核心控制代码)——自动行驶小车
  • Java多线程中Condition类的详细介绍、应用场景和示例代码
  • 大模型GUI系列论文阅读 DAY2续2:《使用指令微调基础模型的多模态网页导航》
  • leetcode136.寻找重复数
  • Python开发GUI的方法汇总
  • 基于springboot实验室信息管理系统
  • 2024:成长和学习之旅
  • 改写中断例程,用中断响应外设
  • 专业又简单:Geotiff文件转Cesium影像切片教程
  • stm32 L051 adc配置及代码实例解析
  • 2025-01学习笔记
  • 物联网常见的传感器和执行器-带表格整理
  • 多线程之旅:开启多线程安全之门的钥匙
  • 如何使用CRM数据分析优化销售和客户关系?
  • 【搞机】GMK-G3因特尔n100处理器核显直通win10虚拟机