当前位置: 首页 > article >正文

生成对抗网络(GAN,Generative Adversarial Network)

生成对抗网络(GAN,Generative Adversarial Network)是一种深度学习模型,由两部分组成:生成器(Generator)和判别器(Discriminator)。生成器的目标是生成逼真的样本,而判别器的目标是区分真实样本与生成样本。它们通过对抗过程相互训练,最终使生成器能够生成高度逼真的数据。

基本概念

  1. 生成器:从随机噪声(通常是高斯噪声)生成数据,表示为 G ( z ) G(z) G(z),其中 z z z 是潜在变量(噪声)。

  2. 判别器:判断输入数据是否真实,表示为 D ( x ) D(x) D(x),其中 x x x 是输入数据。判别器输出一个值,表示其对输入数据为真实的概率。

目标函数

GAN 的目标是通过最小化以下对抗损失来训练生成器和判别器:

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

  • p d a t a ( x ) p_{data}(x) pdata(x):真实数据分布。
  • p z ( z ) p_z(z) pz(z):潜在变量分布(通常为高斯分布)。
  • D ( x ) D(x) D(x):判别器对真实样本的判别概率。
  • G ( z ) G(z) G(z):生成器生成的样本。

训练过程

  1. 判别器训练:通过真实样本和生成样本的损失来优化判别器。
  2. 生成器训练:通过判别器的反馈,优化生成器,使得生成的样本更逼近真实样本。

通过不断的对抗训练,生成器最终能够生成接近真实数据的样本,判别器则不断提高其区分能力。

以下是一个使用 PyTorch 实现的简单 GAN 案例,目标是生成手写数字(MNIST 数据集)。代码包括生成器和判别器的定义,以及训练过程。

GAN 案例代码

如果你要使用自定义的 MNISTDataset 类来加载数据,可以将它集成到之前的 GAN 示例中。以下是完整的代码示例,结合你的 MNISTDataset 实现。

完整的 GAN 示例代码

epoch设置为20,以作示例。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

# 自定义 MNIST 数据集类
class MNISTDataset(Dataset):
    def __init__(self, images_path, labels_path, transform=None):
        self.images = self.load_images(images_path)
        self.labels = self.load_labels(labels_path)
        self.transform = transform

    def load_images(self, path):
        with open(path, 'rb') as f:
            f.read(16)  # 跳过前16个字节
            images = np.frombuffer(f.read(), np.uint8).reshape(-1, 1, 28, 28)
        return torch.tensor(images, dtype=torch.float32) / 255.0  # 归一化到 [0, 1]

    def load_labels(self, path):
        with open(path, 'rb') as f:
            f.read(8)  # 跳过前8个字节
            labels = np.frombuffer(f.read(), np.uint8)
        return labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

# 超参数
latent_size = 100
num_epochs = 20
batch_size = 64
learning_rate = 0.0002

# 数据准备
data_root = r'./MNIST'
train_dataset = MNISTDataset(
    images_path=os.path.join(data_root, 'train-images-idx3-ubyte'),
    labels_path=os.path.join(data_root, 'train-labels-idx1-ubyte')
)

data_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# 生成器模型
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_size, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),
            nn.Tanh()  # 输出在[-1, 1]范围内
        )

    def forward(self, z):
        return self.model(z).view(-1, 1, 28, 28)

# 判别器模型
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # 输出在[0, 1]范围内
        )

    def forward(self, x):
        return self.model(x.view(-1, 28 * 28))

# 初始化模型和优化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # 检查是否有 GPU
generator = Generator().to(device)  # 移动到 GPU
discriminator = Discriminator().to(device)  # 移动到 GPU
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)

# 训练过程
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        images = images.to(device)  # 移动输入数据到 GPU

        # 标签
        real_labels = torch.ones(images.size(0), 1).to(device)  # 移动到 GPU
        fake_labels = torch.zeros(images.size(0), 1).to(device)  # 移动到 GPU

        # 判别器训练
        optimizer_D.zero_grad()
        outputs = discriminator(images)
        d_loss_real = criterion(outputs, real_labels)

        z = torch.randn(images.size(0), latent_size).to(device)  # 随机噪声,移动到 GPU
        fake_images = generator(z)
        outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # 生成器训练
        optimizer_G.zero_grad()
        outputs = discriminator(fake_images)
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        optimizer_G.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

# 可视化生成的图像
with torch.no_grad():
    z = torch.randn(64, latent_size).to(device)  # 随机噪声,移动到 GPU
    fake_images = generator(z)

# 显示生成的图像
grid_img = fake_images.cpu().numpy()  # 移动到 CPU 以便绘图
grid_img = grid_img.reshape(-1, 28, 28)
plt.figure(figsize=(8, 8))
for i in range(64):
    plt.subplot(8, 8, i + 1)
    plt.imshow(grid_img[i], cmap='gray')
    plt.axis('off')
plt.show()

在这里插入图片描述

代码说明

  1. 自定义数据集MNISTDataset 类用于从指定的路径加载 MNIST 数据。
  2. 数据归一化:在 load_images 方法中,将图像数据归一化到 [0, 1] 范围。
  3. 数据加载:使用 DataLoader 创建训练数据集的加载器。
  4. GAN 模型:包含生成器和判别器的定义。
  5. 训练过程:判别器和生成器交替更新。
  6. 可视化生成图像:训练结束后生成并显示手写数字图像。

你可以运行这个代码,并观察生成的手写数字。确保 MNIST 数据集文件在指定的路径下。


http://www.kler.cn/news/342880.html

相关文章:

  • 鸿蒙开发(NEXT/API 12)【安全单元访问开发】网络篇
  • 股市入门常见术语介绍
  • C#中ref关键字和out关键字
  • 微服务es+Kibana解析部署使用全流程
  • 千兆超薄lan transformer H82412S应用主板英特尔光仟网卡
  • 【Linux】来查看当前系统的架构
  • 【目标检测】木制地板缺陷破损数据集338张6类VOC+YOLO格式
  • 【系统架构设计师】案例专题四:嵌入式系统考点梳理
  • 网络嗅探:网络安全中的关键概念
  • 传知代码-自动车牌识别检测系统(论文复现)
  • 【HTML】制作一个简易图片轮播器
  • 简单的网络爬虫爬取视频
  • PyQt 的Tree Widget中拖放和点击的异常行为
  • 【LeetCode】动态规划—673. 最长递增子序列的个数(附完整Python/C++代码)
  • 014 属性分组
  • 牛客SQL29详解 计算用户的平均次日留存率
  • MySQL数据库表分区
  • DBO-BP回归预测 | MATLAB实现DBO-BP蜣螂优化算法优化神经网络多输入单输出回归预测
  • 20241011给荣品RD-RK3588-AHD开发板刷荣品预编译的Buildroot之后打开AP6275P的BT【命令行】
  • 单通道 LVDS 差分线路接收器MS21112S