当前位置: 首页 > article >正文

深度生成模型(四)——VAE 简单项目实战 VAE on CelebA

用 VAE 做一个简单的人脸图像生成任务
使用 PyTorch 训练一个基于 VAE 的模型,对 CelebA 数据集进行训练,并生成新的人脸图像
项目开源地址:VAE-on-CelebA

目录

  • 1 README
    • 1.1 特性
    • 1.2 环境安装
    • 1.3 数据准备
    • 1.4 运行项目
  • 2 更多分析演示
    • 2.1 目录结构
    • 2.2 模型训练
      • 2.2.1 训练日志分析
      • 2.2.2 结果查看与可视化
    • 2.3 新头像生成
  • 3 模型分析
    • 3.1 编码器
      • 3.1.1 参数说明
      • 3.1.2 设计逻辑
      • 3.1.3小结
    • 3.2 解码器
      • 3.2.1 参数说明
      • 3.2.2 工作原理
      • 3.2.3小结

1 README

本项目使用 PyTorch 对 CelebA 数据集进行训练,构建一个简单的 Variational Autoencoder (VAE),并生成新的头像图像

1.1 特性

  • 使用自定义的 VAE 模型 (PyTorch)
  • 支持对 CelebA 数据进行裁剪/缩放等预处理
  • 训练后可直接从先验分布采样生成新的人脸图像

1.2 环境安装

  1. 克隆本仓库:

    git clone https://github.com/YemuRiven/VAE-on-CelebA.git
    cd vae-celeba
    
  2. 安装依赖 (Conda 或 pip 方式均可):

    conda env create -f environment.yml
    conda activate vae-celeba
    

1.3 数据准备

  1. 从 CelebA 官方地址 下载数据集,放置到 data/ 文件夹下,结构类似:
data/CelebA
   ├── Anno        (标注信息)
   ├── Eval        (验证/测试信息)
   └── Img
      └── img_align_celeba
         ├── 000001.jpg
         ├── 000002.jpg
         ├── ...
  1. main.py 中配置数据集路径

1.4 运行项目

python main.py \
  --data_path ./data/CelebA/Img \
  --epochs 10 \
  --batch_size 64 \
  --lr 0.001 \
  --image_size 64 \
  --latent_dim 128 \
  --out_dir ./outputs

2 更多分析演示

2.1 目录结构

项目目录结构以及脚本作用分别如下:

vae-celeba
├── README.md
├── environment.yml         # Conda 环境文件
├── main.py                 # 核心训练脚本,包含训练循环、损失函数定义、以及推理生成示例
├── generate.py         # 使用解码器生成新的头像
├── models
	│── vae.py             # 定义了一个简单的 VAE 模型,包括编码器、解码器和重参数化等关键逻辑
├── utils
	│── dataset.py         # 主要负责数据集的加载与预处理操作,如裁剪、缩放、归一化等
├── outputs                # 训练输出目录(保存模型、生成图像等)

在这里插入图片描述

2.2 模型训练

训练过程:控制台会输出每个 epoch 的损失、重构误差、KL 散度等信息

Alt

2.2.1 训练日志分析

  1. Loss
    总体损失 Loss = 重构损失 Recon + KL 散度 KL
    可以看到,从最初的 231.5161 下降到最后的 171.7694,说明模型的整体目标函数在逐步被优化
  2. Recon
    重构损失主要衡量模型生成图像与真实图像的差异(MSE 或 BCE 等)
    从 171.2477 下降到 113.7698,说明模型的重构能力在增强,输出图像和原图的差异越来越小
  3. KL
    KL 散度用来约束编码器输出的隐变量分布与先验(通常是标准正态分布)保持一致
    一般来说,KL 在训练早期会比较高,随着模型学习,会逐步下降或维持在一个平衡值
    KL 下降到 64.1615,说明模型在重构与正则化之间取得了一定平衡。在 VAE 中,KL 散度不一定会无限下降,保持在一个合理范围内即可

2.2.2 结果查看与可视化

  1. 查看生成图像
    代码设计了 sample_images() 函数,每个 Epoch 都会采样潜在变量并生成图像,保存为 outputs/epoch_{epoch}_samples.png
  2. 模型保存
    最终的模型权重会保存为 outputs/vae_celebA.pth,可用于后续的推理或微调
  3. 更多可视化
    除了在每个 Epoch 保存的示例图像,也写了一个单独的推理脚本 generate.py,反复采样不同的潜在向量 z,生成更多图像,观察其多样性和质量
    Alt

2.3 新头像生成

在训练结束后,VAE 已经学到了一种对人脸图像分布的近似表示,并在隐空间(例如 128 维标准正态分布)中对人脸的潜在特征进行建模。只要从该分布中随机采样一些隐变量 z,再利用已经训练好的解码器,就可以生成新的 64×64 尺寸头像。这些头像并不在训练集中出现过,而是通过模型对人脸特征分布的学习生成出来的
运行 generate.py 文件:
在这里插入图片描述
生成新的头像:
Alt

3 模型分析

3.1 编码器

在 vae-celeba 中,编码器的网络结构(见下表或对应代码)可以分为 三层卷积 + Flatten,然后输出到两条全连接分支(均值与对数方差)。其流程如下所示:

层 (类型)输入形状输出形状参数说明作用
Conv2d(nc=3, 32, k=4, s=2, p=1)B×3×64×64B×32×32×32卷积核=4×4, 步幅=2, 填充=1将输入通道从3扩展到32
ReLUB×32×32×32B×32×32×32-非线性激活,增强表达能力
Conv2d(32, 64, k=4, s=2, p=1)B×32×32×32B×64×16×16卷积核=4×4, 步幅=2, 填充=1进一步提取更深层次的特征
ReLUB×64×16×16B×64×16×16-同上
Conv2d(64, 128, k=4, s=2, p=1)B×64×16×16B×128×8×8卷积核=4×4, 步幅=2, 填充=1最后一层卷积,增大通道数
ReLUB×128×8×8B×128×8×8-同上
FlattenB×128×8×8B×(128×8×8)-拉直为全连接层的输入 (8192维)
fc_mu (Linear)B×8192B×latent_dim (128)线性层 (8192→128)输出均值 μ
fc_logvar (Linear)B×8192B×latent_dim (128)线性层 (8192→128)输出对数方差 logσ²

3.1.1 参数说明

  • B 代表批大小 (batch size)
  • nc=3 表示输入图像的通道数为 3(RGB)
  • k=4, s=2, p=1 分别表示卷积核大小 (kernel_size=4)、步幅 (stride=2) 和填充 (padding=1)
  • “Flatten” 将二维特征图拉直为一维向量,以便进入全连接层

在代码中,这部分网络被写成 self.encoder(三层卷积 + Flatten),再通过 fc_mu 与 fc_logvar 两个全连接层分别输出均值和对数方差

3.1.2 设计逻辑

  1. 卷积核大小与步幅
    每层卷积核都设为 4、步幅设为 2、填充设为 1,这样每经过一层卷积,特征图的宽高都减半。64 → 32 → 16 → 8,最终 Flatten 得到 128×8×8=8192 维
  2. 激活函数
    采用 ReLU 作为非线性激活函数,简洁且常用于卷积网络。如果想要更快收敛或更稳定,可以考虑换成 LeakyReLU 或 ELU 等
  3. BatchNorm / Dropout
    此项目没有使用批归一化(BatchNorm)或 Dropout。对人脸数据而言,若出现过拟合,可以考虑在卷积层之间插入 BatchNorm 或 Dropout
  4. 潜在空间维度
    由于人脸数据分布复杂,因此设置了 latent_dim=128。如果想要更丰富的生成能力,可以提高到 256 或 512;若显存有限,则可尝试减小维度

3.1.3小结

在 vae-celeba 项目中,编码器采用了三层卷积 + Flatten 的结构,每层卷积核大小为 4、步幅为 2,激活函数为 ReLU,最终输出 latent_dim=128 维的均值与对数方差。这种设计与针对灰度小尺寸数据(如 FashionMNIST)的自编码器相比,通道数更多、卷积深度更大、潜在维度更高,以便更好地适配彩色人脸数据的复杂度。若要进一步提高生成质量,可以在卷积层中加入 BatchNorm、LeakyReLU、Dropout 或提升网络深度与潜在空间维度

3.2 解码器

在 vae-celeba 项目的解码器(Decoder)中,采用了反卷积(转置卷积)来从潜在空间采样得到的特征向量恢复到原始分辨率的人脸图像。具体来说,先通过一个全连接层(decoder_input)将潜在向量映射到二维特征图,再依次通过多层反卷积逐步上采样,从 8×8 还原到 64×64。解码器的网络结构如下表所示,列出了解码器各层的类型、输入形状、输出形状以及参数量

层 (类型)输入形状输出形状参数说明作用
decoder_input (Linear)B×latent_dim (128)B×8192线性层 (128→8192)将潜在向量映射到更高维度
View/ReshapeB×8192B×128×8×8-将一维向量重塑为二维特征图
ConvTranspose2d(128→64, k=4, s=2, p=1)B×128×8×8B×64×16×16反卷积核=4×4, 步幅=2, 填充=1上采样,扩大特征图尺寸
ReLUB×64×16×16B×64×16×16-激活函数,提供非线性能力
ConvTranspose2d(64→32, k=4, s=2, p=1)B×64×16×16B×32×32×32反卷积核=4×4, 步幅=2, 填充=1进一步上采样
ReLUB×32×32×32B×32×32×32-同上
ConvTranspose2d(32→3, k=4, s=2, p=1)B×32×32×32B×3×64×64反卷积核=4×4, 步幅=2, 填充=1还原到 RGB 通道 (3) 的原始尺寸 (64×64)
SigmoidB×3×64×64B×3×64×64-输出范围限制在 [0,1]

3.2.1 参数说明

  • B 表示批大小 (batch size)
  • k=4, s=2, p=1 分别代表卷积核大小 (kernel_size=4)、步幅 (stride=2)、填充 (padding=1)
  • 参数量为近似计算值,实际可使用框架的 model.summary() 或 torchinfo 等工具自动统计
  • latent_dim (128) 表示潜在向量维度可根据需求调整(如 64、128、256 等)
  • “ConvTranspose2d” 用于逆卷积或上采样,将特征图逐层还原到原始图像大小
  • 最后一层 Sigmoid 将像素值限制在 [0,1] 区间,以配合 MSE 或 BCE 损失进行重构

3.2.2 工作原理

  1. 潜在向量映射 (Linear + Reshape)
    解码器首先接收来自编码器的潜在向量 z(维度为 latent_dim),通过一个全连接层将其映射到更高维度(8192),再 Reshape 成 (128,8,8) 的特征图,作为后续转置卷积的输入
  2. 反卷积上采样 (ConvTranspose2d)
    反卷积层(又称转置卷积)可以理解为卷积操作的“逆过程”,其作用是逐步还原空间分辨率:
    第一层将 (128,8,8) 上采样到 (64,16,16),
    第二层将 (64,16,16) 上采样到 (32,32,32),
    最后一层将 (32,32,32) 上采样到 (3,64,64),恢复到与原始人脸图像相同的尺寸与通道数(RGB)
  3. 激活函数 (ReLU / Sigmoid)
    反卷积层之间使用 ReLU 激活,为网络提供非线性能力。
    最后一层使用 Sigmoid,将输出像素值限制在 [0,1] 区间,适合与 MSE 或 BCE 重构损失配合使用

3.2.3小结

在 vae-celeba 项目的解码器设计中,通过线性映射 + 多层反卷积逐步上采样到原始图像分辨率,并在输出层使用 Sigmoid 激活,以确保生成像素值在 [0,1] 区间。与针对灰度小分辨率数据(如 FashionMNIST)的模型相比,该解码器需要更大的卷积通道深度来处理彩色人脸图像的丰富细节。通过这种结构设置,VAE 可以在潜在空间采样出多样化的向量 z,并解码为不同风格和特征的人脸图像


http://www.kler.cn/a/572795.html

相关文章:

  • 06 HarmonyOS Next性能优化之LazyForEach 列表渲染基础与实现详解 (一)
  • Pytorch的一小步,昇腾芯片的一大步
  • 演示汉字笔顺的工具
  • 构建一个Django的应用程序
  • MATLAB仿真:涡旋光束光强和相位分布同时展示
  • 图漾PercipioIPTool软件使用
  • setlocale()的参数,“zh_CN.UTF-8“, “chs“, “chinese-simplified“的差异。
  • 人工智能神经网络基本原理
  • STM32---FreeRTOS中断管理试验
  • KIKKKKKKK::::::::::::::
  • MR 1. 孟德尔随机化在生物医学研究中的应用概述
  • 探秘鸿蒙 HarmonyOS NEXT:权限申请策略指南
  • Linux网络 NAT、代理服务、内网穿透
  • c语言中的主要知识点
  • Qt:事件
  • 大模型在呼吸衰竭预测及围手术期方案制定中的应用研究
  • C语言-一维数组及综合案例
  • 鸿蒙NEXT开发-端云一体化开发概念开发准备
  • mysql下载与安装
  • SpringMVC控制器定义:@Controller注解详解