PyTorch深度学习框架60天进阶学习计划第16天:循环神经网络进阶!
PyTorch深度学习框架60天进阶学习计划 - 第16天:生成对抗网络原理
学习目标
今天我们将深入探讨生成对抗网络(GAN)的基本原理和数学基础,重点解析GAN的minimax博弈公式,推导生成器与判别器的损失函数,分析Wasserstein GAN的改进方案以及DCGAN的架构设计规范。
1. GAN的基本原理
生成对抗网络(Generative Adversarial Networks, GAN)是由Ian Goodfellow在2014年提出的一种生成模型框架。GAN由两个网络组成:
- 生成器(Generator, G):学习生成逼真的数据样本
- 判别器(Discriminator, D):学习区分真实数据和生成器生成的数据
这两个网络通过对抗训练相互提升。
1.1 GAN的博弈过程
GAN的训练过程可以看作是一个两人零和博弈:
网络角色 | 目标 | 策略 |
---|---|---|
生成器(G) | 生成逼真的假样本欺骗判别器 | 最小化判别器正确分类的概率 |
判别器(D) | 准确区分真实样本和生成样本 | 最大化判别器正确分类的概率 |
2. GAN的数学表达:Minimax博弈公式
2.1 经典GAN的价值函数
GAN的核心可以用以下数学公式表示:
min G max D V ( D , G ) = E x ∼ p d a t a ( x ) [ log D ( x ) ] + E z ∼ p z ( z ) [ log ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(D,G)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]
这个公式表达了什么?让我们逐步分解:
- V ( D , G ) V(D, G) V(D,G) 是价值函数,判别器D试图最大化它,而生成器G试图最小化它
- E x ∼ p d a t a ( x ) [ log D ( x ) ] \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] Ex∼pdata(x)[logD(x)] 表示判别器对真实数据的平均预测概率的对数
- E z ∼ p z ( z ) [ log ( 1 − D ( G ( z ) ) ) ] \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] Ez∼pz(z)[log(1−D(G(z)))] 表示判别器对生成数据的平均预测概率的对数的负值
2.2 损失函数的推导
从minimax公式中,我们可以分别推导出判别器和生成器的损失函数:
判别器的损失函数
判别器的目标是最大化 V ( D , G ) V(D, G) V(D,G),即最小化以下损失函数:
L D = − E x ∼ p d a t a ( x ) [ log D ( x ) ] − E z ∼ p z ( z ) [ log ( 1 − D ( G ( z ) ) ) ] L_D = -\mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] - \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] LD=−Ex∼pdata(x)[logD(x)]−Ez∼pz(z)[log(1−D(G(z)))]
在实际实现中,我们通常使用二元交叉熵损失:
L D = − 1 m ∑ i = 1 m [ log D ( x ( i ) ) + log ( 1 − D ( G ( z ( i ) ) ) ) ] L_D = -\frac{1}{m}\sum_{i=1}^{m}[\log D(x^{(i)}) + \log(1 - D(G(z^{(i)})))] LD=−m1i=1∑m[logD(x(i))+log(1−D(G(z(i))))]
生成器的损失函数
生成器的目标是最小化 V ( D , G ) V(D, G) V(D,G),即最小化以下损失函数:
L G = E z ∼ p z ( z ) [ log ( 1 − D ( G ( z ) ) ) ] L_G = \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] LG=Ez∼pz(z)[log(1−D(G(z)))]
然而,在训练初期,当生成器的输出质量较差时,判别器可以轻松区分真假样本,这会导致梯度消失问题。因此,实践中通常使用一个非饱和的损失函数:
L G = − E z ∼ p z ( z ) [ log D ( G ( z ) ) ] L_G = -\mathbb{E}_{z \sim p_z(z)}[\log D(G(z))] LG=−Ez∼pz(z)[logD(G(z))]
在代码实现中:
L G = − 1 m ∑ i = 1 m [ log D ( G ( z ( i ) ) ) ] L_G = -\frac{1}{m}\sum_{i=1}^{m}[\log D(G(z^{(i)}))] LG=−m1i=1∑m[logD(G(z(i)))]
3. Wasserstein GAN (WGAN) 改进方案
3.1 传统GAN的问题
传统GAN存在以下问题:
- 训练不稳定
- 模式崩溃(Mode Collapse)
- 梯度消失或爆炸
3.2 Wasserstein距离的引入
Wasserstein GAN引入了Wasserstein距离(也称为Earth Mover’s Distance, EMD)来衡量两个概率分布之间的差异。其数学表达式为:
W ( p r , p g ) = inf γ ∈ Π ( p r , p g ) E ( x , y ) ∼ γ [ ∥ x − y ∥ ] W(p_r, p_g) = \inf_{\gamma \in \Pi(p_r, p_g)} \mathbb{E}_{(x,y) \sim \gamma}[\|x-y\|] W(pr,pg)=γ∈Π(pr,pg)infE(x,y)∼γ[∥x−y∥]
WGAN的价值函数变为:
min G max D ∈ D E x ∼ p d a t a ( x ) [ D ( x ) ] − E z ∼ p z ( z ) [ D ( G ( z ) ) ] \min_G \max_{D \in \mathcal{D}} \mathbb{E}_{x \sim p_{data}(x)}[D(x)] - \mathbb{E}_{z \sim p_z(z)}[D(G(z))] GminD∈DmaxEx∼pdata(x)[D(x)]−Ez∼pz(z)[D(G(z))]
其中 D \mathcal{D} D 是所有1-Lipschitz函数的集合。
3.3 梯度惩罚(Gradient Penalty)实现原理
为了满足Lipschitz约束,WGAN-GP提出了梯度惩罚的方法:
L D = E z ∼ p z ( z ) [ D ( G ( z ) ) ] − E x ∼ p d a t a ( x ) [ D ( x ) ] + λ E x ^ ∼ p x ^ [ ( ∥ ∇ x ^ D ( x ^ ) ∥ 2 − 1 ) 2 ] L_D = \mathbb{E}_{z \sim p_z(z)}[D(G(z))] - \mathbb{E}_{x \sim p_{data}(x)}[D(x)] + \lambda \mathbb{E}_{\hat{x} \sim p_{\hat{x}}}[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2] LD=Ez∼pz(z)[D(G(z))]−Ex∼pdata(x)[D(x)]+λEx^∼px^[(∥∇x^D(x^)∥2−1)2]
其中 x ^ \hat{x} x^ 是真实样本和生成样本之间的随机插值:
x ^ = ϵ x + ( 1 − ϵ ) G ( z ) \hat{x} = \epsilon x + (1 - \epsilon) G(z) x^=ϵx+(1−ϵ)G(z)
ϵ \epsilon ϵ 是从均匀分布 U [ 0 , 1 ] U[0,1] U[0,1] 采样的随机数。
4. DCGAN架构设计规范
Deep Convolutional GAN (DCGAN) 是GAN在计算机视觉领域的一个重要应用。它提出了一系列架构设计规范:
4.1 DCGAN主要设计规范
规范 | 生成器(G) | 判别器(D) |
---|---|---|
池化层 | 使用转置卷积进行上采样,不使用池化层 | 使用带步长的卷积替代池化层进行下采样 |
批量归一化 | 在除输出层外的所有层使用 | 在除输入层外的所有层使用 |
激活函数 | 隐藏层使用ReLU激活函数,输出层使用Tanh | 所有层使用LeakyReLU |
全连接层 | 最后一层可以使用全连接层 | 最后一层可以使用全连接层 |
4.2 批量归一化在生成器中的特殊应用
在生成器中,批量归一化具有以下特殊应用:
- 促进网络收敛:通过归一化特征,加速训练过程
- 防止模式崩溃:帮助不同的生成样本保持多样性
- 减轻内部协变量偏移:保持各层的输入分布相对稳定
- 特殊位置:通常不在生成器的输出层应用批量归一化,以保留生成数据的原始分布特性
5. PyTorch实现标准GAN
下面是一个使用PyTorch实现标准GAN的代码示例:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
# 设置随机种子
torch.manual_seed(42)
# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 超参数
batch_size = 64
z_dim = 100
lr = 0.0002
beta1 = 0.5
epochs = 30
image_size = 64
# 数据集预处理
transform = transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# 加载MNIST数据集
dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
# 判别器网络
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
# 输入: 1 x 64 x 64
nn.Conv2d(1, 64, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# 32 x 32
nn.Conv2d(64, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace=True),
# 16 x 16
nn.Conv2d(128, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
# 8 x 8
nn.Conv2d(256, 512, 4, 2, 1, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
# 4 x 4
nn.Conv2d(512, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, x):
return self.model(x).view(-1, 1).squeeze(1)
# 生成器网络
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
# 输入: z_dim x 1 x 1
nn.ConvTranspose2d(z_dim, 512, 4, 1, 0, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(True),
# 4 x 4
nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(True),
# 8 x 8
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(True),
# 16 x 16
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(True),
# 32 x 32
nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
nn.Tanh()
# 输出: 1 x 64 x 64
)
def forward(self, z):
return self.model(z)
# 初始化网络
netG = Generator().to(device)
netD = Discriminator().to(device)
# 权重初始化
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
netG.apply(weights_init)
netD.apply(weights_init)
# 设置损失函数和优化器
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
# 训练GAN
for epoch in range(epochs):
for i, data in enumerate(dataloader, 0):
############################
# (1) 更新判别器: 最大化 log(D(x)) + log(1 - D(G(z)))
###########################
# 训练真实样本
netD.zero_grad()
real_images = data[0].to(device)
batch_size = real_images.size(0)
labels = torch.full((batch_size,), 1, dtype=torch.float, device=device)
output = netD(real_images)
errD_real = criterion(output, labels)
errD_real.backward()
# 训练生成样本
noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
fake_images = netG(noise)
labels.fill_(0)
output = netD(fake_images.detach())
errD_fake = criterion(output, labels)
errD_fake.backward()
errD = errD_real + errD_fake
optimizerD.step()
############################
# (2) 更新生成器: 最大化 log(D(G(z)))
###########################
netG.zero_grad()
labels.fill_(1) # 生成器希望判别器将生成的图像判为真
output = netD(fake_images)
errG = criterion(output, labels)
errG.backward()
optimizerG.step()
# 输出训练状态
if i % 100 == 0:
print(f'[{epoch}/{epochs}][{i}/{len(dataloader)}] Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f}')
# 保存一些生成的图像
with torch.no_grad():
fixed_noise = torch.randn(64, z_dim, 1, 1, device=device)
fake = netG(fixed_noise).detach().cpu()
img_grid = torchvision.utils.make_grid(fake, padding=2, normalize=True)
plt.figure(figsize=(8, 8))
plt.imshow(np.transpose(img_grid, (1, 2, 0)))
plt.axis('off')
plt.savefig(f'fake_images_epoch_{epoch}.png')
plt.close()
print("Training finished!")
6. WGAN-GP的PyTorch实现
下面是WGAN-GP的PyTorch实现示例:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
# 设置随机种子
torch.manual_seed(42)
# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 超参数
batch_size = 64
z_dim = 100
lr = 0.0002
beta1 = 0.5
beta2 = 0.9
epochs = 30
image_size = 64
n_critic = 5 # 判别器训练次数
lambda_gp = 10 # 梯度惩罚系数
# 数据集预处理
transform = transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
])
# 加载MNIST数据集
dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
# 判别器网络 (Critic)
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
# 输入: 1 x 64 x 64
nn.Conv2d(1, 64, 4, 2, 1),
nn.LeakyReLU(0.2, inplace=True),
# 32 x 32
nn.Conv2d(64, 128, 4, 2, 1),
nn.InstanceNorm2d(128), # 使用Instance Normalization替代Batch Normalization
nn.LeakyReLU(0.2, inplace=True),
# 16 x 16
nn.Conv2d(128, 256, 4, 2, 1),
nn.InstanceNorm2d(256),
nn.LeakyReLU(0.2, inplace=True),
# 8 x 8
nn.Conv2d(256, 512, 4, 2, 1),
nn.InstanceNorm2d(512),
nn.LeakyReLU(0.2, inplace=True),
# 4 x 4
nn.Conv2d(512, 1, 4, 1, 0),
# 注意: 没有Sigmoid激活函数,因为WGAN直接输出Wasserstein距离
)
def forward(self, x):
return self.model(x).view(-1)
# 生成器网络
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.model = nn.Sequential(
# 输入: z_dim x 1 x 1
nn.ConvTranspose2d(z_dim, 512, 4, 1, 0),
nn.BatchNorm2d(512),
nn.ReLU(True),
# 4 x 4
nn.ConvTranspose2d(512, 256, 4, 2, 1),
nn.BatchNorm2d(256),
nn.ReLU(True),
# 8 x 8
nn.ConvTranspose2d(256, 128, 4, 2, 1),
nn.BatchNorm2d(128),
nn.ReLU(True),
# 16 x 16
nn.ConvTranspose2d(128, 64, 4, 2, 1),
nn.BatchNorm2d(64),
nn.ReLU(True),
# 32 x 32
nn.ConvTranspose2d(64, 1, 4, 2, 1),
nn.Tanh()
# 输出: 1 x 64 x 64
)
def forward(self, z):
return self.model(z)
# 计算梯度惩罚
def compute_gradient_penalty(D, real_samples, fake_samples):
# 随机权重项: 在真实样本和生成样本之间进行插值
alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=device)
# 获取插值样本
interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
# 计算判别器对插值样本的输出
d_interpolates = D(interpolates)
# 为反向传播创建虚拟标签
fake = torch.ones(real_samples.size(0), device=device, requires_grad=False)
# 计算梯度
gradients = torch.autograd.grad(
outputs=d_interpolates,
inputs=interpolates,
grad_outputs=fake,
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
# 计算梯度惩罚
gradients = gradients.view(gradients.size(0), -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
return gradient_penalty
# 初始化网络
netG = Generator().to(device)
netD = Discriminator().to(device)
# 权重初始化
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1 or classname.find('InstanceNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
netG.apply(weights_init)
netD.apply(weights_init)
# 设置优化器
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, beta2))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, beta2))
# 记录生成器和判别器的损失
G_losses = []
D_losses = []
# 训练WGAN-GP
for epoch in range(epochs):
for i, data in enumerate(dataloader, 0):
############################
# (1) 更新判别器
###########################
# 训练判别器多次
for _ in range(n_critic):
# 配置网络
netD.zero_grad()
# 训练真实样本
real_images = data[0].to(device)
batch_size = real_images.size(0)
# 生成噪声
noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
# 生成假样本
fake_images = netG(noise)
# 计算损失
real_validity = netD(real_images)
fake_validity = netD(fake_images.detach())
# 计算梯度惩罚
gradient_penalty = compute_gradient_penalty(netD, real_images, fake_images.detach())
# 判别器总损失
d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
# 反向传播
d_loss.backward()
optimizerD.step()
############################
# (2) 更新生成器
###########################
netG.zero_grad()
# 生成新的假样本
noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
fake_images = netG(noise)
# 判别器评估假样本
fake_validity = netD(fake_images)
# 生成器损失
g_loss = -torch.mean(fake_validity)
# 反向传播
g_loss.backward()
optimizerG.step()
# 保存损失
G_losses.append(g_loss.item())
D_losses.append(d_loss.item())
# 输出训练状态
if i % 50 == 0:
print(f'[{epoch}/{epochs}][{i}/{len(dataloader)}] Loss_D: {d_loss.item():.4f} Loss_G: {g_loss.item():.4f}')
# 每个epoch结束后保存一些生成的图像
with torch.no_grad():
fixed_noise = torch.randn(64, z_dim, 1, 1, device=device)
fake = netG(fixed_noise).detach().cpu()
img_grid = torchvision.utils.make_grid(fake, padding=2, normalize=True)
plt.figure(figsize=(8, 8))
plt.imshow(np.transpose(img_grid, (1, 2, 0)), cmap='gray')
plt.axis('off')
plt.title(f'WGAN-GP Generated Images - Epoch {epoch}')
plt.savefig(f'wgan_gp_images_epoch_{epoch}.png')
plt.close()
print("Training finished!")
# 绘制损失曲线
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig("wgan_gp_loss_curve.png")
plt.close()
# 展示最终生成的图像
with torch.no_grad():
noise = torch.randn(16, z_dim, 1, 1, device=device)
generated_images = netG(noise).detach().cpu()
# 去归一化
generated_images = (generated_images + 1) / 2
# 创建图像网格
img_grid = torchvision.utils.make_grid(generated_images, nrow=4, padding=2, normalize=False)
plt.figure(figsize=(10, 10))
plt.imshow(np.transpose(img_grid, (1, 2, 0)), cmap='gray')
plt.axis('off')
plt.title("Final WGAN-GP Generated Images")
plt.savefig("final_wgan_gp_images.png")
plt.show()
7. GAN的训练流程图
以下是GAN训练的基本流程图:
8. WGAN-GP与标准GAN的对比
特性 | 标准GAN | WGAN-GP |
---|---|---|
损失函数 | 二元交叉熵 | Wasserstein距离 + 梯度惩罚项 |
判别器/评论家输出 | 概率值(0~1) | 实数值(Wasserstein距离) |
最后一层激活函数 | Sigmoid | 无(线性输出) |
参数裁剪 | 不需要 | 通过梯度惩罚实现Lipschitz约束 |
优化器推荐 | Adam(β1=0.5, β2=0.999) | Adam(β1=0.5, β2=0.9) |
训练稳定性 | 容易不稳定 | 更加稳定 |
模式崩溃 | 常见问题 | 大幅减轻 |
梯度消失 | 容易发生 | 基本解决 |
判别器训练次数 | 通常1:1 | 通常5:1(判别器:生成器) |
归一化层 | BatchNorm | 推荐使用LayerNorm或InstanceNorm |
训练速度 | 相对较快 | 相对较慢(需要多次训练判别器) |
超参数敏感度 | 较高 | 较低 |
理论基础 | JS散度 | Wasserstein距离 |
9. 批量归一化在GAN中的特殊应用分析
批量归一化(Batch Normalization)在GAN中具有重要作用,尤其在生成器中有特殊的应用方式:
9.1 批量归一化的基本原理
批量归一化通过以下公式对每个批次的数据进行标准化:
x ^ i = x i − μ B σ B 2 + ϵ \hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} x^i=σB2+ϵxi−μB
其中, μ B \mu_B μB 和 σ B 2 \sigma_B^2 σB2 分别是批次的均值和方差, ϵ \epsilon ϵ 是防止除零的小常数。
然后,通过可学习的参数 γ \gamma γ 和 β \beta β 调整标准化后的分布:
y i = γ x ^ i + β y_i = \gamma \hat{x}_i + \beta yi=γx^i+β
9.2 生成器中批量归一化的特殊应用
位置选择
在生成器网络中,批量归一化层的位置有以下特殊考虑:
-
输入层后不使用:生成器接收随机噪声作为输入,这些噪声通常已经是标准正态分布,不需要额外的归一化。
-
输出层前不使用:输出层通常使用Tanh激活函数将值映射到[-1,1]区间,输出前不应用批量归一化以保持生成数据的自然分布。
-
中间层广泛使用:在中间层广泛使用BatchNorm可以稳定训练过程,防止梯度消失和爆炸。
训练模式与评估模式的区别
在GAN中,批量归一化层在训练和评估模式下的行为有重要区别:
- 训练模式(train()):使用当前批次的统计值进行归一化
- 评估模式(eval()):使用整个训练过程累积的统计值进行归一化
在GAN训练中,正确切换这两种模式至关重要。在生成样本时,必须使用评估模式,确保输出的一致性。
应对小批量大小的策略
当批量大小较小时,BatchNorm可能导致统计不稳定。在GAN中,尤其是高分辨率图像生成时,可以考虑以下替代方案:
- 实例归一化(Instance Normalization):对每个样本独立归一化
- 层归一化(Layer Normalization):对每个特征通道独立归一化
- 组归一化(Group Normalization):将通道分组后归一化
批量归一化与条件GAN
在条件GAN(Conditional GAN)中,批量归一化可以结合条件信息:
class ConditionalBatchNorm2d(nn.Module):
def __init__(self, num_features, num_classes):
super().__init__()
self.num_features = num_features
self.bn = nn.BatchNorm2d(num_features, affine=False)
self.embed = nn.Embedding(num_classes, num_features * 2)
self.embed.weight.data[:, :num_features].normal_(1, 0.02) # 初始化为1
self.embed.weight.data[:, num_features:].zero_() # 初始化为0
def forward(self, x, y):
out = self.bn(x)
gamma, beta = self.embed(y).chunk(2, 1)
out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)
return out
这种条件批量归一化允许生成器根据标签条件调整其特征统计,从而生成特定类别的样本。
10. 实现DCGAN的最佳实践
下面是实现DCGAN时的一些最佳实践:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
# 设置随机种子,确保结果可复现
torch.manual_seed(42)
np.random.seed(42)
# 设备配置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# 超参数
batch_size = 128
z_dim = 100
lr = 0.0002
beta1 = 0.5
epochs = 25
image_size = 64
nc = 3 # 彩色图像的通道数
# 数据集预处理
transform = transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# 加载CIFAR10数据集
dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
# DCGAN判别器网络
class Discriminator(nn.Module):
def __init__(self, nc=3, ndf=64):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
# 输入: nc x 64 x 64
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# 状态尺寸: ndf x 32 x 32
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# 状态尺寸: (ndf*2) x 16 x 16
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
# 状态尺寸: (ndf*4) x 8 x 8
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
# 状态尺寸: (ndf*8) x 4 x 4
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, input):
return self.main(input).view(-1, 1).squeeze(1)
# DCGAN生成器网络
class Generator(nn.Module):
def __init__(self, nc=3, ngf=64, nz=100):
super(Generator, self).__init__()
self.main = nn.Sequential(
# 输入是一个nz维度的噪声向量,进入转置卷积
nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),
# 状态尺寸: (ngf*8) x 4 x 4
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# 状态尺寸: (ngf*4) x 8 x 8
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# 状态尺寸: (ngf*2) x 16 x 16
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
# 状态尺寸: (ngf) x 32 x 32
nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
nn.Tanh()
# 状态尺寸: nc x 64 x 64
)
def forward(self, input):
return self.main(input)
# 权重初始化
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
# 初始化网络
netG = Generator(nc, 64, z_dim).to(device)
netD = Discriminator(nc, 64).to(device)
# 初始化权重
netG.apply(weights_init)
netD.apply(weights_init)
# 打印模型结构
print(netG)
print(netD)
# 设置损失函数和优化器
criterion = nn.BCELoss()
# 使用Adam优化器,按照DCGAN论文中的建议设置参数
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
# 创建固定的噪声向量,用于可视化生成器的进展
fixed_noise = torch.randn(64, z_dim, 1, 1, device=device)
# 创建真和假的标签
real_label = 1.
fake_label = 0.
# 用于保存训练过程中的损失值
G_losses = []
D_losses = []
img_list = []
print("开始训练...")
# 训练循环
for epoch in range(epochs):
for i, data in enumerate(dataloader, 0):
############################
# (1) 更新判别器: 最大化 log(D(x)) + log(1 - D(G(z)))
###########################
## 用真实图像训练判别器
netD.zero_grad()
real_cpu = data[0].to(device)
batch_size = real_cpu.size(0)
label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
output = netD(real_cpu)
errD_real = criterion(output, label)
errD_real.backward()
D_x = output.mean().item()
## 用生成的假图像训练判别器
noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
fake = netG(noise)
label.fill_(fake_label)
output = netD(fake.detach())
errD_fake = criterion(output, label)
errD_fake.backward()
D_G_z1 = output.mean().item()
errD = errD_real + errD_fake
optimizerD.step()
############################
# (2) 更新生成器: 最大化 log(D(G(z)))
###########################
netG.zero_grad()
label.fill_(real_label) # 生成器希望判别器将假图像判为真
output = netD(fake)
errG = criterion(output, label)
errG.backward()
D_G_z2 = output.mean().item()
optimizerG.step()
# 输出训练状态
if i % 50 == 0:
print(f'[{epoch}/{epochs}][{i}/{len(dataloader)}] '
f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} '
f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}')
# 保存损失,用于以后绘图
G_losses.append(errG.item())
D_losses.append(errD.item())
# 检查生成器如何处理固定噪声向量
if (i % 500 == 0) or ((epoch == epochs-1) and (i == len(dataloader)-1)):
with torch.no_grad():
fake = netG(fixed_noise).detach().cpu()
img_list.append(torchvision.utils.make_grid(fake, padding=2, normalize=True))
# 在每个epoch结束后保存模型
torch.save(netG.state_dict(), f'dcgan_generator_epoch_{epoch}.pth')
torch.save(netD.state_dict(), f'dcgan_discriminator_epoch_{epoch}.pth')
# 在每个epoch结束后显示生成的图像
with torch.no_grad():
fake = netG(fixed_noise).detach().cpu()
img_grid = torchvision.utils.make_grid(fake, padding=2, normalize=True)
plt.figure(figsize=(8, 8))
plt.imshow(np.transpose(img_grid, (1, 2, 0)))
plt.axis('off')
plt.title(f'DCGAN Generated Images - Epoch {epoch}')
plt.savefig(f'dcgan_images_epoch_{epoch}.png')
plt.close()
print("训练完成!")
# 绘制生成器和判别器的损失曲线
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig("dcgan_loss_plot.png")
plt.close()
# 显示训练进程中生成的图像
fig = plt.figure(figsize=(8, 8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
plt.savefig("dcgan_generation_animation.png")
plt.close()
# 展示最终生成的图像
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Final DCGAN Generated Images")
plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
plt.savefig("final_dcgan_generated_images.png")
plt.close()
print("所有结果已保存!")
11. GAN训练中的常见问题及解决方案
以下是GAN训练中常见的问题及其解决方案:
问题 | 原因 | 解决方案 |
---|---|---|
模式崩溃 | 生成器只学习生成有限种类的样本 | 1. 使用WGAN或WGAN-GP 2. 小批量判别(Minibatch Discrimination) 3. 特征匹配(Feature Matching) 4. 在判别器中添加噪声 |
判别器过强 | 判别器学习速度太快,导致生成器没有有效梯度 | 1. 降低判别器的学习率 2. 减少判别器的更新频率 3. 添加标签平滑(Label Smoothing) 4. 在生成器损失中添加辅助任务 |
训练不稳定 | 损失函数波动大或不收敛 | 1. 使用WGAN或WGAN-GP 2. 梯度裁剪或梯度惩罚 3. 调整学习率 4. 使用适当的架构设计 |
梯度消失 | 在训练初期,判别器可以轻易区分真假样本 | 1. 使用WGAN 2. 使用非饱和生成器损失 3. 标签翻转(Label Flipping) |
梯度爆炸 | 网络权重更新过大 | 1. 梯度裁剪 2. 权重归一化 3. 调整批量大小 4. 使用适当的初始化 |
12. GAN中的损失函数比较
下表比较了不同GAN变体中的损失函数:
13. 结论与实践建议
GAN是深度学习领域的一个重要创新,它为生成模型带来了革命性的变化。通过本节的学习,我们深入理解了GAN的数学基础、损失函数推导、改进方案以及实现技巧。
在实践中,我建议遵循以下原则:
- 从简单开始:先实现标准GAN,了解其基本原理和训练行为
- 选择合适的架构:根据任务选择适当的网络架构,DCGAN是一个良好的起点
- 使用改进的损失函数:考虑使用WGAN-GP等改进的损失函数提高训练稳定性
- 批量归一化应用:在生成器中恰当使用批量归一化,但注意输出层前不要使用
- 监控训练过程:定期生成样本并检查,及时调整超参数
清华大学全三版的《DeepSeek教程》完整的文档需要的朋友,关注我私信:deepseek 即可获得。
怎么样今天的内容还满意吗?再次感谢朋友们的观看,关注GZH:凡人的AI工具箱,回复666,送您价值199的AI大礼包。最后,祝您早日实现财务自由,还请给个赞,谢谢!