使用生成对抗网络(GAN)进行人脸老化生成的Python示例
以下是一个使用生成对抗网络(GAN)进行人脸老化生成的Python示例,我们将使用PyTorch库来实现。GAN由生成器和判别器两部分组成,生成器尝试生成逼真的老化人脸图像,判别器则尝试区分生成的图像和真实的老化人脸图像。
步骤概述
- 数据准备:准备包含不同年龄段人脸的数据集。
- 定义生成器和判别器:构建生成器和判别器的神经网络模型。
- 训练GAN:交替训练生成器和判别器。
- 生成老化人脸图像:使用训练好的生成器生成老化人脸图像。
代码实现
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
import numpy as np
import matplotlib.pyplot as plt
# 定义生成器
class Generator(nn.Module):
def __init__(self, z_dim=100, img_dim=784):
super(Generator, self).__init__()
self.gen = nn.Sequential(
nn.Linear(z_dim, 256),
nn.LeakyReLU(0.1),
nn.Linear(256, img_dim),
nn.Tanh()
)
def forward(self, x):
return self.gen(x)
# 定义判别器
class Discriminator(nn.Module):
def __init__(self, img_dim=784):
super(Discriminator, self).__init__()
self.disc = nn.Sequential(
nn.Linear(img_dim, 128),
nn.LeakyReLU(0.1),
nn.Linear(128, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.disc(x)
# 超参数设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 3e-4
z_dim = 100
img_dim = 28 * 28
batch_size = 32
num_epochs = 50
# 数据加载
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 这里假设使用MNIST数据集作为示例,实际应用中需要使用人脸数据集
dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 初始化生成器和判别器
gen = Generator(z_dim, img_dim).to(device)
disc = Discriminator(img_dim).to(device)
# 定义优化器和损失函数
opt_gen = optim.Adam(gen.parameters(), lr=lr)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
criterion = nn.BCELoss()
# 训练GAN
for epoch in range(num_epochs):
for batch_idx, (real, _) in enumerate(dataloader):
real = real.view(-1, 784).to(device)
batch_size = real.shape[0]
### 训练判别器
noise = torch.randn(batch_size, z_dim).to(device)
fake = gen(noise)
disc_real = disc(real).view(-1)
lossD_real = criterion(disc_real, torch.ones_like(disc_real))
disc_fake = disc(fake.detach()).view(-1)
lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
lossD = (lossD_real + lossD_fake) / 2
disc.zero_grad()
lossD.backward()
opt_disc.step()
### 训练生成器
output = disc(fake).view(-1)
lossG = criterion(output, torch.ones_like(output))
gen.zero_grad()
lossG.backward()
opt_gen.step()
print(f"Epoch [{epoch + 1}/{num_epochs}] Loss D: {lossD.item():.4f}, Loss G: {lossG.item():.4f}")
# 生成老化人脸图像(这里只是简单示例,实际需要更复杂的处理)
num_samples = 16
noise = torch.randn(num_samples, z_dim).to(device)
generated_images = gen(noise).cpu().detach().view(num_samples, 28, 28).numpy()
# 可视化生成的图像
fig, axes = plt.subplots(4, 4, figsize=(4, 4))
axes = axes.flatten()
for i in range(num_samples):
axes[i].imshow(generated_images[i], cmap='gray')
axes[i].axis('off')
plt.show()
代码解释
- 数据准备:使用
torchvision
库加载MNIST数据集作为示例,实际应用中需要使用包含不同年龄段人脸的数据集。 - 生成器和判别器:
Generator
:将随机噪声向量映射到图像空间。Discriminator
:判断输入图像是真实的还是生成的。
- 训练过程:
- 交替训练判别器和生成器。
- 判别器的目标是最大化区分真实图像和生成图像的能力。
- 生成器的目标是生成能够欺骗判别器的图像。
- 生成图像:使用训练好的生成器生成老化人脸图像,并进行可视化。
注意事项
- 此示例使用MNIST数据集作为演示,实际应用中需要使用包含不同年龄段人脸的数据集,如CACD、IMDB-WIKI等。
- 可以根据实际情况调整超参数和网络结构,以获得更好的生成效果。