当前位置: 首页 > article >正文

VAE的原理及MNIST数据生成

⭐️ 变分自编码器(VAE)

变分自编码器(Variational Autoencoder,VAE)是一种生成模型,通过学习训练数据的概率分布,可以生成与训练数据分布相似的新样本。与经典的自编码器不同,VAE 的目标不仅仅是学习压缩和重构数据,而是通过学习数据的潜在分布来进行概率建模,适合图像生成、异常检测、数据缺失填补等任务。本文将深入探讨 VAE 的原理和实现步骤。

在这里插入图片描述

⭐️ VAE 的基本原理

VAE 是一种概率生成模型。与传统的自编码器类似,VAE 也包含一个编码器和一个解码器:

  • 编码器:将输入数据(如图像)映射到一个潜在空间的概率分布中。
  • 解码器:从潜在空间中采样的点生成新的数据。

VAE 的关键在于引入了 概率分布的思想,即假设潜在空间中的数据服从某种分布(通常是高斯分布),并通过学习这个分布来对输入数据进行生成建模。

在生成模型的框架下,VAE 的目标是找到一组参数,使得模型生成的样本分布尽可能接近训练数据的分布。这需要解决的问题是如何从潜在空间中采样数据并计算样本的重构误差。


⭐️ 编码器与解码器结构

VAE 的编码器和解码器结构如下:

  • 编码器:将输入 x x x 映射到潜在空间的均值 μ ( x ) \mu(x) μ(x) 和标准差 σ ( x ) \sigma(x) σ(x)。这样,我们就可以从正态分布中采样出一个潜在变量 z z z
    z ∼ N ( μ ( x ) , σ ( x ) 2 ) z \sim \mathcal{N}(\mu(x), \sigma(x)^2) zN(μ(x),σ(x)2)

  • 解码器:从潜在变量 z z z 生成重构的样本 x ′ x' x,即 p ( x ∣ z ) p(x|z) p(xz)。解码器的目标是让生成的 x ′ x' x 尽量接近原始输入 x x x


⭐️ VAE 的损失函数

VAE 的损失函数由两部分组成:

  1. 重构误差:度量重构数据与原始数据之间的相似度。通常使用二元交叉熵或均方误差来计算。
    Reconstruction Loss = − E q ( z ∣ x ) [ log ⁡ p ( x ∣ z ) ] \text{Reconstruction Loss} = -\mathbb{E}_{q(z|x)}[\log p(x|z)] Reconstruction Loss=Eq(zx)[logp(xz)]

  2. KL 散度:表示生成分布和标准正态分布之间的差异。它将潜在变量的分布约束为标准正态分布,从而使采样出的点在潜在空间上形成连续分布,能够生成平滑的图像。
    KL Divergence = D K L ( q ( z ∣ x ) ∣ ∣ p ( z ) ) \text{KL Divergence} = D_{KL}(q(z|x) || p(z)) KL Divergence=DKL(q(zx)∣∣p(z))
    其中 D K L D_{KL} DKL 是 KL 散度。

综上,VAE 的总损失可以表示为:
L = Reconstruction Loss + KL Divergence \mathcal{L} = \text{Reconstruction Loss} + \text{KL Divergence} L=Reconstruction Loss+KL Divergence


⭐️ 重参数化技巧

VAE 的训练难点在于采样过程不具有可微性。为了解决这个问题,引入了 重参数化技巧

  1. 通过编码器输出均值 μ \mu μ 和方差 σ \sigma σ,得到一个标准正态分布的噪声变量 ϵ \epsilon ϵ
  2. 使用公式 z = μ + σ ⋅ ϵ z = \mu + \sigma \cdot \epsilon z=μ+σϵ 得到潜在变量 z z z

这样,我们可以通过反向传播来训练模型,因为这个过程可以微分。


⭐️ 训练VAE并随机生成MNIST数据

代码如下

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt


# 定义VAE模型
class VAE(nn.Module):

    def __init__(self, latent_dim=20):
        super(VAE, self).__init__()
        # 编码器部分
        self.encoder = nn.Sequential(
            nn.Linear(784, 400),
            nn.ReLU(),
            nn.Linear(400, 2 * latent_dim)  # 输出均值和对数方差
        )
        # 解码器部分
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 400),
            nn.ReLU(),
            nn.Linear(400, 784),
            nn.Sigmoid()  # 将输出值约束到0-1之间
        )
        self.latent_dim = latent_dim

    def encode(self, x):
        h = self.encoder(x)
        mu, log_var = h.chunk(2, dim=-1)  # 分割成均值和对数方差
        return mu, log_var

    # 重参数化
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    # 前向传播
    def forward(self, x):
        mu, log_var = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var


# 定义VAE的损失函数
def vae_loss(recon_x, x, mu, log_var):
    # 重构损失:二元交叉熵
    recon_loss = nn.functional.binary_cross_entropy(recon_x,
                                                    x.view(-1, 784),
                                                    reduction='sum')
    # KL 散度
    kl_divergence = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss + kl_divergence


# 数据加载
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data',
                               train=True,
                               transform=transform,
                               download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# 初始化模型和优化器
latent_dim = 20
vae = VAE(latent_dim=latent_dim).to('cuda')
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

# 训练VAE模型,这里训练100轮
epochs = 100
vae.train()
for epoch in range(epochs):
    train_loss = 0
    for x, _ in train_loader:
        x = x.to('cuda')
        optimizer.zero_grad()
        recon_x, mu, log_var = vae(x)
        loss = vae_loss(recon_x, x, mu, log_var)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    print(f"Epoch {epoch + 1}, Loss: {train_loss / len(train_loader.dataset):.4f}")


# 保存模型
torch.save(vae.state_dict(), 'vae_gen_mnist_image.pth')

# 生成图像
vae.eval()
with torch.no_grad():
    z = torch.randn(64, latent_dim).to('cuda')  # 生成64个随机latent向量
    generated_images = vae.decode(z).cpu()
    generated_images = generated_images.view(-1, 1, 28, 28)  # 调整维度适应MNIST格式

    # 可视化生成的图像
    grid = make_grid(generated_images, nrow=8, padding=2, normalize=True)
    plt.figure(figsize=(8, 8))
    plt.imshow(grid.permute(1, 2, 0), cmap='gray')
    plt.axis('off')
    # plt.show()
    plt.savefig('gen_mnist.png', dpi=800)

训练100轮后的结果

在这里插入图片描述


⭐️ 总结与展望

VAE 通过概率建模在潜在空间中进行有效采样,生成数据的能力优于经典自编码器。这种方法使得 VAE 在图像生成、数据增强、异常检测和生成对抗网络的预训练等任务中表现出色。通过不断调整网络结构和损失函数,VAE 还可以扩展到其他复杂任务,如自然语言生成、音频生成等。

尽管 VAE 的生成效果可能比 GAN 略逊色,但其稳定的训练过程和概率模型框架使得 VAE 在多个领域得到了广泛应用。


http://www.kler.cn/a/389016.html

相关文章:

  • PostgreSQL中的COPY命令:高效数据导入与导出
  • [ComfyUI]Flux:繁荣生态魔盒已开启,6款LORA已来,更有MJ6写实动漫风景艺术迪士尼全套
  • #渗透测试#SRC漏洞挖掘#云技术基础02之容器与云
  • 5G 现网信令参数学习(3) - RrcSetup(1)
  • 基于STM32的智能充电桩:集成RTOS、MQTT与SQLite的先进管理系统设计思路
  • 动态规划 —— dp 问题-买卖股票的最佳时机IV
  • 【计算机网络】基础知识,常识应用知识
  • Webpack知识点—publicPath
  • JVM 进阶:深入理解与高级调优
  • YOLOv6-4.0部分代码阅读笔记-engine.py
  • Skyeye 云智能制造 v3.14.12 发布,ERP + 商城
  • 【AI技术】PaddleSpeech部署方案
  • Python实现SSA智能麻雀搜索算法优化BP神经网络分类模型(优化权重和阈值)项目实战
  • 数据结构之排序补充
  • 12.UE5朝向鼠标攻击,状态机入门
  • fabric操作canvas绘图(1)共32节
  • 计算机毕业设计Python流量检测可视化 DDos攻击流量检测与可视化分析 SDN web渗透测试系统 网络安全 信息安全 大数据毕业设计
  • Mysql COUNT() 函数详解
  • 手动实现promise的all,race,finally方法
  • 深入理解Linux内核中的虚拟文件系统(VFS)
  • Mac中禁用系统更新
  • g++/gcc版本切换
  • 传输协议设计与牧村摆动(Makimoto‘s Wave)
  • 18、论文阅读:AOD-Net:一体化除雾网络
  • 【系统架构设计师】高分论文:论企业集成平合的技术与应用
  • Linux五种IO模型和fctnl的使用