当前位置: 首页 > article >正文

第十四站:生成对抗网络(GAN)

前言:生成对抗网络(GAN)是由 Ian Goodfellow 在 2014 年提出的,它是一种 无监督学习 方法,广泛应用于图像生成、图像修复、图像超分辨率等任务。GAN 的核心思想是通过 两个神经网络的对抗训练,使得一个网络生成数据,另一个网络判断生成数据的真实度,最终实现数据生成。


1. GAN 的基本原理:

GAN 由 两个神经网络 组成:

  • 生成器(Generator,G):负责生成数据(如图像),它试图生成尽可能真实的数据来骗过判别器。
  • 判别器(Discriminator,D):负责判断数据是真实的(来自训练数据)还是生成的(来自生成器)。

这两个网络通过 对抗训练 来优化。具体过程如下:

  1. 生成器 生成一个假的数据(比如假图像)。
  2. 判别器 判断这个数据是否真实。
  3. 判别器的反馈信息传递给生成器,生成器根据这个反馈优化自己的生成方式,生成更逼真的数据。
  4. 最终,生成器生成的假数据越来越真实,判别器越来越难以区分。
2. GAN 的目标函数:

生成器和判别器的目标函数如下:

  • 生成器的目标:最大化判别器对生成数据的误判率,使得判别器认为生成数据是真实的。

    生成器的损失函数:
    L G = − log ⁡ ( D ( G ( z ) ) ) L_G = - \log(D(G(z))) LG=log(D(G(z)))

    其中 D ( G ( z ) ) D(G(z)) D(G(z))是判别器对生成数据的预测值, G ( z ) G(z) G(z)是生成器生成的假数据。

  • 判别器的目标:最大化其对真实数据和生成数据的判断正确率。

    判别器的损失函数:
    L D = − log ⁡ ( D ( x ) ) − log ⁡ ( 1 − D ( G ( z ) ) ) L_D = - \log(D(x)) - \log(1 - D(G(z))) LD=log(D(x))log(1D(G(z)))
    其中 D ( x ) D(x) D(x) 是判别器对真实数据的预测值, D ( G ( z ) ) D(G(z)) D(G(z))是判别器对生成数据的预测值。

3. GAN 的训练过程:
  1. 训练判别器:让判别器学会区分真实数据和生成数据。
  2. 训练生成器:通过反向传播让生成器生成更逼真的数据,骗过判别器。

这个过程会不断进行,直到生成器生成的假数据越来越真实,判别器难以区分真假数据。


4. GAN 的应用:

GAN 的应用非常广泛,以下是一些典型的应用领域:

  • 图像生成:GAN 可以生成非常逼真的图像,比如人脸生成、艺术风格生成等。
  • 图像超分辨率:通过训练 GAN 来将低分辨率图像恢复到高分辨率图像。
  • 图像修复:填补图像中的缺失部分(比如去除图片中的噪声或缺失区域)。
  • 语音生成:生成与给定文本相对应的语音。
  • 风格迁移:将一种艺术风格应用到图像上(例如把一张照片转换成油画风格)。

5. 生成对抗网络(GAN)的代码示例:

下面是一个简单的 GAN 代码示例,用于生成手写数字(基于 MNIST 数据集):

import torch
import torch.nn as nn  # 引入 PyTorch 的神经网络模块
import torch.optim as optim  # 引入 PyTorch 的优化器模块
import torchvision  # 引入 PyTorch 的计算机视觉工具包
import torchvision.transforms as transforms  # 用于图像数据的变换
from torch.utils.data import DataLoader  # 用于批量加载数据
import numpy as np

# 生成器网络
class Generator(nn.Module):  # 定义生成器类,继承自 nn.Module
    def __init__(self):
        super(Generator, self).__init__()  # 调用父类的构造函数
        self.fc1 = nn.Linear(100, 256)  # 输入为 100 维噪声,输出 256 维特征
        self.fc2 = nn.Linear(256, 512)  # 将 256 维特征转换为 512 维
        self.fc3 = nn.Linear(512, 1024)  # 将 512 维特征转换为 1024 维
        self.fc4 = nn.Linear(1024, 28 * 28)  # 最终输出 28x28 的图像数据(展平为一维)
        self.relu = nn.ReLU()  # 定义 ReLU 激活函数
        self.tanh = nn.Tanh()  # 定义 Tanh 激活函数,确保输出范围在 [-1, 1]

    def forward(self, x):  # 定义前向传播逻辑
        x = self.relu(self.fc1(x))  # 第一个全连接层后接 ReLU
        x = self.relu(self.fc2(x))  # 第二个全连接层后接 ReLU
        x = self.relu(self.fc3(x))  # 第三个全连接层后接 ReLU
        x = self.tanh(self.fc4(x))  # 最后一层输出接 Tanh,确保生成的像素值在 [-1, 1] 范围内
        return x.view(-1, 1, 28, 28)  # 将输出 reshape 为 (batch_size, 1, 28, 28) 的图像格式

# 判别器网络
class Discriminator(nn.Module):  # 定义判别器类,继承自 nn.Module
    def __init__(self):
        super(Discriminator, self).__init__()  # 调用父类的构造函数
        self.fc1 = nn.Linear(28 * 28, 1024)  # 输入为展平的 28x28 图像数据,输出 1024 维特征
        self.fc2 = nn.Linear(1024, 512)  # 将 1024 维特征转换为 512 维
        self.fc3 = nn.Linear(512, 256)  # 将 512 维特征转换为 256 维
        self.fc4 = nn.Linear(256, 1)  # 最终输出一个标量,表示是否为真实图像
        self.relu = nn.ReLU()  # 定义 ReLU 激活函数
        self.sigmoid = nn.Sigmoid()  # 定义 Sigmoid 激活函数,将输出值映射到 [0, 1]

    def forward(self, x):  # 定义前向传播逻辑
        x = x.view(-1, 28 * 28)  # 将输入图像展平成一维
        x = self.relu(self.fc1(x))  # 第一个全连接层后接 ReLU
        x = self.relu(self.fc2(x))  # 第二个全连接层后接 ReLU
        x = self.relu(self.fc3(x))  # 第三个全连接层后接 ReLU
        x = self.sigmoid(self.fc4(x))  # 最后一层输出接 Sigmoid,生成概率值
        return x

# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 检查是否有 GPU 可用,否则使用 CPU

# 创建生成器和判别器实例,并将它们移动到设备上
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# 定义损失函数和优化器
criterion = nn.BCELoss()  # 使用二进制交叉熵损失函数,适合二分类任务
lr = 0.0002  # 学习率
optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))  # 为生成器定义 Adam 优化器
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))  # 为判别器定义 Adam 优化器

# 加载 MNIST 数据集
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 将像素值归一化到 [-1, 1] 范围
])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)  # 下载 MNIST 数据集
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)  # 使用 DataLoader 按批量加载数据

# 训练 GAN
num_epochs = 10  # 训练轮数
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(trainloader):  # 遍历训练数据集
        real_images = real_images.to(device)  # 将真实图像移动到设备上
        batch_size = real_images.size(0)  # 获取当前批次的大小

        # 训练判别器
        optimizer_d.zero_grad()  # 清空判别器的梯度
        real_labels = torch.ones(batch_size, 1).to(device)  # 定义真实标签 (1)
        fake_labels = torch.zeros(batch_size, 1).to(device)  # 定义假标签 (0)

        outputs = discriminator(real_images)  # 判别器处理真实图像
        d_loss_real = criterion(outputs, real_labels)  # 计算判别器对真实图像的损失
        d_loss_real.backward()  # 反向传播计算梯度

        noise = torch.randn(batch_size, 100).to(device)  # 生成随机噪声向量
        fake_images = generator(noise)  # 使用生成器生成假图像
        outputs = discriminator(fake_images.detach())  # 判别器处理假图像(使用 detach 不计算生成器的梯度)
        d_loss_fake = criterion(outputs, fake_labels)  # 计算判别器对假图像的损失
        d_loss_fake.backward()  # 反向传播计算梯度

        optimizer_d.step()  # 更新判别器的参数

        # 训练生成器
        optimizer_g.zero_grad()  # 清空生成器的梯度
        outputs = discriminator(fake_images)  # 判别器处理假图像
        g_loss = criterion(outputs, real_labels)  # 生成器希望判别器将假图像判定为真实
        g_loss.backward()  # 反向传播计算梯度

        optimizer_g.step()  # 更新生成器的参数

    # 打印每个 epoch 的损失
    print(f"Epoch [{epoch+1}/{num_epochs}], D Loss: {d_loss_real.item() + d_loss_fake.item()}, G Loss: {g_loss.item()}")

print("Finished Training")  # 训练完成
  • 生成器(Generator):生成假图像,通过 fc 层将随机噪声(100 维)映射为 28x28 的图像。
  • 判别器(Discriminator):判定输入图像是真实的还是生成的,输出一个概率值。
  • 训练过程
    • 训练判别器去分辨真实和生成的图像。
    • 训练生成器去骗过判别器,让判别器认为生成的图像是真实的。


http://www.kler.cn/a/567446.html

相关文章:

  • 基于SpringBoot的美妆购物网站系统设计与实现现(源码+SQL脚本+LW+部署讲解等)
  • Spark 介绍
  • final 关键字在不同上下文中的用法及其名称
  • Ubuntu 下 nginx-1.24.0 源码分析 - ngx_open_file
  • 性能测试监控工具jmeter+grafana
  • ave-form.vue 组件中 如何将产品名称发送给后端 ?
  • Unity插件-Mirror使用方法(二)组件介绍
  • 【学术会议论文投稿】Spring Boot实战:零基础打造你的Web应用新纪元
  • C++之 “” 用法(总结)
  • 【Oracle脚本】消耗CPU高的SQL抓取
  • JavaPro _JVM 知识点速记 JVM大全
  • 【AVL树】—— 我与C++的不解之缘(二十三)
  • GitCode 助力 python-office:开启 Python 自动化办公新生态
  • 机器学习的通用工作流程
  • 若依框架修改为多租户
  • OptiTrack光学跟踪系统:引领工厂机器人应用的革新浪潮
  • 克隆项目到本地
  • C++(Qt)软件调试---Linux 性能分析器perf(29)
  • lua学习(二)
  • Compose笔记(七)--Modifier