第十四站:生成对抗网络(GAN)
前言:生成对抗网络(GAN)是由 Ian Goodfellow 在 2014 年提出的,它是一种 无监督学习 方法,广泛应用于图像生成、图像修复、图像超分辨率等任务。GAN 的核心思想是通过 两个神经网络的对抗训练,使得一个网络生成数据,另一个网络判断生成数据的真实度,最终实现数据生成。
1. GAN 的基本原理:
GAN 由 两个神经网络 组成:
- 生成器(Generator,G):负责生成数据(如图像),它试图生成尽可能真实的数据来骗过判别器。
- 判别器(Discriminator,D):负责判断数据是真实的(来自训练数据)还是生成的(来自生成器)。
这两个网络通过 对抗训练 来优化。具体过程如下:
- 生成器 生成一个假的数据(比如假图像)。
- 判别器 判断这个数据是否真实。
- 判别器的反馈信息传递给生成器,生成器根据这个反馈优化自己的生成方式,生成更逼真的数据。
- 最终,生成器生成的假数据越来越真实,判别器越来越难以区分。
2. GAN 的目标函数:
生成器和判别器的目标函数如下:
-
生成器的目标:最大化判别器对生成数据的误判率,使得判别器认为生成数据是真实的。
生成器的损失函数:
L G = − log ( D ( G ( z ) ) ) L_G = - \log(D(G(z))) LG=−log(D(G(z)))其中 D ( G ( z ) ) D(G(z)) D(G(z))是判别器对生成数据的预测值, G ( z ) G(z) G(z)是生成器生成的假数据。
-
判别器的目标:最大化其对真实数据和生成数据的判断正确率。
判别器的损失函数:
L D = − log ( D ( x ) ) − log ( 1 − D ( G ( z ) ) ) L_D = - \log(D(x)) - \log(1 - D(G(z))) LD=−log(D(x))−log(1−D(G(z)))
其中 D ( x ) D(x) D(x) 是判别器对真实数据的预测值, D ( G ( z ) ) D(G(z)) D(G(z))是判别器对生成数据的预测值。
3. GAN 的训练过程:
- 训练判别器:让判别器学会区分真实数据和生成数据。
- 训练生成器:通过反向传播让生成器生成更逼真的数据,骗过判别器。
这个过程会不断进行,直到生成器生成的假数据越来越真实,判别器难以区分真假数据。
4. GAN 的应用:
GAN 的应用非常广泛,以下是一些典型的应用领域:
- 图像生成:GAN 可以生成非常逼真的图像,比如人脸生成、艺术风格生成等。
- 图像超分辨率:通过训练 GAN 来将低分辨率图像恢复到高分辨率图像。
- 图像修复:填补图像中的缺失部分(比如去除图片中的噪声或缺失区域)。
- 语音生成:生成与给定文本相对应的语音。
- 风格迁移:将一种艺术风格应用到图像上(例如把一张照片转换成油画风格)。
5. 生成对抗网络(GAN)的代码示例:
下面是一个简单的 GAN 代码示例,用于生成手写数字(基于 MNIST 数据集):
import torch
import torch.nn as nn # 引入 PyTorch 的神经网络模块
import torch.optim as optim # 引入 PyTorch 的优化器模块
import torchvision # 引入 PyTorch 的计算机视觉工具包
import torchvision.transforms as transforms # 用于图像数据的变换
from torch.utils.data import DataLoader # 用于批量加载数据
import numpy as np
# 生成器网络
class Generator(nn.Module): # 定义生成器类,继承自 nn.Module
def __init__(self):
super(Generator, self).__init__() # 调用父类的构造函数
self.fc1 = nn.Linear(100, 256) # 输入为 100 维噪声,输出 256 维特征
self.fc2 = nn.Linear(256, 512) # 将 256 维特征转换为 512 维
self.fc3 = nn.Linear(512, 1024) # 将 512 维特征转换为 1024 维
self.fc4 = nn.Linear(1024, 28 * 28) # 最终输出 28x28 的图像数据(展平为一维)
self.relu = nn.ReLU() # 定义 ReLU 激活函数
self.tanh = nn.Tanh() # 定义 Tanh 激活函数,确保输出范围在 [-1, 1]
def forward(self, x): # 定义前向传播逻辑
x = self.relu(self.fc1(x)) # 第一个全连接层后接 ReLU
x = self.relu(self.fc2(x)) # 第二个全连接层后接 ReLU
x = self.relu(self.fc3(x)) # 第三个全连接层后接 ReLU
x = self.tanh(self.fc4(x)) # 最后一层输出接 Tanh,确保生成的像素值在 [-1, 1] 范围内
return x.view(-1, 1, 28, 28) # 将输出 reshape 为 (batch_size, 1, 28, 28) 的图像格式
# 判别器网络
class Discriminator(nn.Module): # 定义判别器类,继承自 nn.Module
def __init__(self):
super(Discriminator, self).__init__() # 调用父类的构造函数
self.fc1 = nn.Linear(28 * 28, 1024) # 输入为展平的 28x28 图像数据,输出 1024 维特征
self.fc2 = nn.Linear(1024, 512) # 将 1024 维特征转换为 512 维
self.fc3 = nn.Linear(512, 256) # 将 512 维特征转换为 256 维
self.fc4 = nn.Linear(256, 1) # 最终输出一个标量,表示是否为真实图像
self.relu = nn.ReLU() # 定义 ReLU 激活函数
self.sigmoid = nn.Sigmoid() # 定义 Sigmoid 激活函数,将输出值映射到 [0, 1]
def forward(self, x): # 定义前向传播逻辑
x = x.view(-1, 28 * 28) # 将输入图像展平成一维
x = self.relu(self.fc1(x)) # 第一个全连接层后接 ReLU
x = self.relu(self.fc2(x)) # 第二个全连接层后接 ReLU
x = self.relu(self.fc3(x)) # 第三个全连接层后接 ReLU
x = self.sigmoid(self.fc4(x)) # 最后一层输出接 Sigmoid,生成概率值
return x
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 检查是否有 GPU 可用,否则使用 CPU
# 创建生成器和判别器实例,并将它们移动到设备上
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# 定义损失函数和优化器
criterion = nn.BCELoss() # 使用二进制交叉熵损失函数,适合二分类任务
lr = 0.0002 # 学习率
optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999)) # 为生成器定义 Adam 优化器
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999)) # 为判别器定义 Adam 优化器
# 加载 MNIST 数据集
transform = transforms.Compose([
transforms.ToTensor(), # 将图像转换为张量
transforms.Normalize((0.5,), (0.5,)) # 将像素值归一化到 [-1, 1] 范围
])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform) # 下载 MNIST 数据集
trainloader = DataLoader(trainset, batch_size=64, shuffle=True) # 使用 DataLoader 按批量加载数据
# 训练 GAN
num_epochs = 10 # 训练轮数
for epoch in range(num_epochs):
for i, (real_images, _) in enumerate(trainloader): # 遍历训练数据集
real_images = real_images.to(device) # 将真实图像移动到设备上
batch_size = real_images.size(0) # 获取当前批次的大小
# 训练判别器
optimizer_d.zero_grad() # 清空判别器的梯度
real_labels = torch.ones(batch_size, 1).to(device) # 定义真实标签 (1)
fake_labels = torch.zeros(batch_size, 1).to(device) # 定义假标签 (0)
outputs = discriminator(real_images) # 判别器处理真实图像
d_loss_real = criterion(outputs, real_labels) # 计算判别器对真实图像的损失
d_loss_real.backward() # 反向传播计算梯度
noise = torch.randn(batch_size, 100).to(device) # 生成随机噪声向量
fake_images = generator(noise) # 使用生成器生成假图像
outputs = discriminator(fake_images.detach()) # 判别器处理假图像(使用 detach 不计算生成器的梯度)
d_loss_fake = criterion(outputs, fake_labels) # 计算判别器对假图像的损失
d_loss_fake.backward() # 反向传播计算梯度
optimizer_d.step() # 更新判别器的参数
# 训练生成器
optimizer_g.zero_grad() # 清空生成器的梯度
outputs = discriminator(fake_images) # 判别器处理假图像
g_loss = criterion(outputs, real_labels) # 生成器希望判别器将假图像判定为真实
g_loss.backward() # 反向传播计算梯度
optimizer_g.step() # 更新生成器的参数
# 打印每个 epoch 的损失
print(f"Epoch [{epoch+1}/{num_epochs}], D Loss: {d_loss_real.item() + d_loss_fake.item()}, G Loss: {g_loss.item()}")
print("Finished Training") # 训练完成
- 生成器(Generator):生成假图像,通过
fc
层将随机噪声(100 维)映射为 28x28 的图像。 - 判别器(Discriminator):判定输入图像是真实的还是生成的,输出一个概率值。
- 训练过程:
- 训练判别器去分辨真实和生成的图像。
- 训练生成器去骗过判别器,让判别器认为生成的图像是真实的。