当前位置: 首页 > article >正文

【深度学习】常见模型-生成对抗网络(Generative Adversarial Network, GAN)

生成对抗网络(Generative Adversarial Network, GAN)是一种深度学习模型框架,由 Ian Goodfellow 等人在 2014 年提出。GAN 由 生成器(Generator)判别器(Discriminator) 两个对抗网络组成,通过彼此博弈的方式训练,从而生成与真实数据分布极为相似的高质量数据。GAN 在图像生成、文本生成、数据增强等领域中有广泛应用。


核心思想

GAN 的核心是两个神经网络之间的对抗:

  1. 生成器(Generator)

    • 输入随机噪声,生成“假数据”。
    • 学习目标是欺骗判别器,让生成的数据尽可能接近真实数据。
  2. 判别器(Discriminator)

    • 接收真实数据和生成器生成的数据,判断数据是真实的还是伪造的。
    • 学习目标是准确区分真实数据和伪造数据。

两者通过对抗学习(min-max 游戏)达到一个动态平衡,使得生成器生成的数据逐渐逼近真实数据分布。


数学公式

GAN 的目标是通过以下损失函数进行优化:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]G

其中:

  • G(z):生成器的输出,输入随机噪声 z。
  • D(x):判别器的输出,表示数据 x 为真实数据的概率。
  • p_{\text{data}}(x):真实数据的分布。
  • p_z(z):随机噪声的分布(通常为正态分布或均匀分布)。

生成器 G 的目标是最小化 D(G(z)) 的概率(使判别器认为生成数据是真实的),而判别器 D 的目标是最大化其正确判断的概率。


GAN 的训练过程

  1. 初始化模型:随机初始化生成器和判别器的参数。

  2. 训练判别器 D

    • 从真实数据中抽取样本,计算 D(x)。
    • 从生成器 G 生成伪造样本 G(z),计算 D(G(z))。
    • 优化 D 的目标函数,使得其能够正确区分真实和伪造数据。
  3. 训练生成器 G

    • 生成伪造数据 G(z)。
    • 优化生成器的目标函数,使得 D(G(z)) 尽可能接近真实数据。
  4. 重复以上过程,直到生成数据的质量达到目标。


代码实现

以下是一个使用 TensorFlow/Keras 实现简单 GAN 的代码示例:

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LeakyReLU, Flatten, Reshape

# 创建生成器
def build_generator(latent_dim):
    model = Sequential([
        Dense(128, input_dim=latent_dim),
        LeakyReLU(alpha=0.2),
        Dense(256),
        LeakyReLU(alpha=0.2),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dense(784, activation='tanh'),
        Reshape((28, 28, 1))
    ])
    return model

# 创建判别器
def build_discriminator():
    model = Sequential([
        Flatten(input_shape=(28, 28, 1)),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dense(256),
        LeakyReLU(alpha=0.2),
        Dense(1, activation='sigmoid')
    ])
    return model

# 定义损失函数和优化器
latent_dim = 100
generator = build_generator(latent_dim)
discriminator = build_discriminator()
discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# 构建 GAN 模型
discriminator.trainable = False
gan = Sequential([generator, discriminator])
gan.compile(optimizer='adam', loss='binary_crossentropy')

# 数据准备(MNIST 数据集)
(X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
X_train = X_train / 127.5 - 1.0  # 归一化到 [-1, 1]
X_train = np.expand_dims(X_train, axis=-1)

# 训练 GAN
batch_size = 64
epochs = 10000
for epoch in range(epochs):
    # 随机选取真实样本
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    real_imgs = X_train[idx]

    # 生成伪造样本
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    fake_imgs = generator.predict(noise)

    # 训练判别器
    real_y = np.ones((batch_size, 1))
    fake_y = np.zeros((batch_size, 1))
    d_loss_real = discriminator.train_on_batch(real_imgs, real_y)
    d_loss_fake = discriminator.train_on_batch(fake_imgs, fake_y)

    # 训练生成器
    g_loss = gan.train_on_batch(noise, real_y)

    # 输出训练进度
    if epoch % 1000 == 0:
        print(f"Epoch {epoch}: D Loss Real: {d_loss_real[0]}, D Loss Fake: {d_loss_fake[0]}, G Loss: {g_loss}")

运行结果

2/2 [==============================] - 0s 2ms/step
Epoch 0: D Loss Real: 0.829259991645813, D Loss Fake: 0.6967335343360901, G Loss: 0.8764752149581909
2/2 [==============================] - 0s 3ms/step
2/2 [==============================] - 0s 2ms/step
2/2 [==============================] - 0s 3ms/step
2/2 [==============================] - 0s 2ms/step
2/2 [==============================] - 0s 2ms/step
......

应用领域

  1. 图像生成
    • 生成高清人脸图像(如 StyleGAN)。
    • 图像上色、去噪等。
  2. 文本生成
    • 生成自然语言文本。
  3. 数据增强
    • 增强数据集的多样性,特别是在样本稀缺的情况下。
  4. 艺术创作
    • 生成艺术画作、音乐等。

优缺点

优点
  • 能够生成逼真的高维数据。
  • 应用范围广泛,尤其在生成任务中。
缺点
  • 训练不稳定:可能出现模式崩溃(Mode Collapse)。
  • 训练复杂:需要调参并找到生成器和判别器的平衡。
  • 难以评估:生成数据的质量通常需要人工判断。

扩展与变体

  1. DCGAN(Deep Convolutional GAN):引入卷积层,提高图像生成的质量。
  2. WGAN(Wasserstein GAN):优化损失函数,解决模式崩溃问题。
  3. CycleGAN:用于图像到图像的转换(如风格迁移)。
  4. StyleGAN:高质量图像生成技术,支持对生成图像样式的控制。

GAN 的对抗思想极具创新性,为生成任务提供了一种全新的解决方案,是深度学习领域的里程碑技术之一。


http://www.kler.cn/a/518039.html

相关文章:

  • 代码随想录——二叉树(二)
  • 【ArcGIS微课1000例】0141:提取多波段影像中的单个波段
  • 03-机器学习-数据获取
  • Android GLSurfaceView 覆盖其它控件问题 (RK平台)
  • openlava/LSF 用户组管理脚本
  • SocketCAN
  • 【优选算法】10----无重复字符的最长子串
  • Vue.js组件开发-如何实现带有搜索功能的下拉框
  • CASAIM与友达光电达成深度合作,CASAIM IS自动化蓝光测量技术为创新显示技术发展注入新的活力
  • Poetry shell --> poetry-plugin-shell
  • Hnu电子电路实验4
  • 基于数智立体化V2.0体系构建医疗综合智能体:理论、实践与展望
  • C语言内存管理详解
  • LKT4304新一代算法移植加密芯片,守护 物联网设备和云服务安全
  • leetcode——最大子数组和(java)
  • 15.7k!DISM++一款快捷的系统优化工具
  • 使用RocketMQ 的业务系统怎么处理消息的积压?
  • kafka-保姆级配置说明(broker)
  • 计算机视觉-卷积
  • Qt调用ffmpeg库实现简易视频播放器示例
  • 嵌入式音视频开发——视频篇(三)
  • 如何在Linux中找到MySQL的安装目录
  • python实现http文件服务器访问下载
  • YOLOv11改进,YOLOv11添加ASFF检测头,并添加小目标检测层(四头检测),适合目标检测、分割等任务,全网首发
  • 微信小程序云开发服务端存储API 从云存储空间删除文件
  • DeepSeek R1 模型详解与微调