当前位置: 首页 > article >正文

生成对抗网络(GAN)入门与编程实现

生成对抗网络(Generative Adversarial Networks, 简称 GAN)自 2014 年由 Ian Goodfellow 等人提出以来,迅速成为机器学习和深度学习领域的重要工具之一。GAN 以其在图像生成、风格转换、数据增强等领域的出色表现,吸引了广泛的研究兴趣和应用探索。本文将介绍 GAN 的基本概念、工作原理以及如何通过代码实现一个简单的 GAN 模型。

什么是生成对抗网络(GAN)?

GAN 是一种生成模型,旨在通过学习数据的潜在分布,生成与真实数据相似的样本。它由两个核心部分组成:

  • 生成器(Generator):输入一个随机噪声向量,通过一系列的变换生成假数据,目标是让生成的假数据尽可能接近真实数据。
  • 判别器(Discriminator):输入真实数据和生成器生成的假数据,输出判断其真实性的概率,目标是尽可能准确地区分真实数据和生成数据。
    二者在训练过程中相互对抗,形成一个博弈过程。

在这里插入图片描述

GAN 的工作原理

GAN 的训练过程可以看作是生成器和判别器之间的"零和博弈":

  1. 生成器:
  • 输入随机噪声向量 z z z(通常服从正态分布)。
  • 输出生成的样本 G ( z ) G(z) G(z)
  • 目标是让判别器无法区分 G ( z ) G(z) G(z) 和真实数据。
  1. 判别器:
  • 输入真实样本 x x x 和生成器生成的假样本 G ( z ) G(z) G(z)
  • 输出区分真假样本的概率。
  • 目标是最大化对真实样本和生成样本的区分能力。

通过对模型进行训练,生成器逐渐生成更接近真实分布的样本,而判别器也不断提高其判别能力,直到达到平衡。
在这里插入图片描述
完整的训练过程如下:
在这里插入图片描述

GAN 的代码实现

接下来,我们通过 PyTorch 实现一个简单的 GAN 模型,生成 MNIST 手写数字图片。

  1. 数据加载与预处理
    MNIST 是一个常用的手写数字数据集,每张图片的大小为 28x28,灰度范围为 0-1。
# data_loader
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5), std=(0.5))
])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True)

使用 torchvision 的 datasets.MNIST 下载MNIST数据集。之后,将图片转换为Tensor格式,并对像素值进行归一化(均值0.5,标准差0.5)。

  1. 构建生成器与判别器
    生成器和判别器都是多层全连接神经网络。
# G(z)
class generator(nn.Module):
    # initializers
    def __init__(self, input_size=32, n_class = 10):
        super(generator, self).__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, 512)
        self.fc3 = nn.Linear(self.fc2.out_features, 1024)
        self.fc4 = nn.Linear(self.fc3.out_features, n_class)

    # forward method
    def forward(self, input):
        x = F.leaky_relu(self.fc1(input), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.tanh(self.fc4(x))
        x = x.squeeze(-1)

        return x

class discriminator(nn.Module):
    # initializers
    def __init__(self, input_size=32, n_class=10):
        super(discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, 512)
        self.fc3 = nn.Linear(self.fc2.out_features, 256)
        self.fc4 = nn.Linear(self.fc3.out_features, n_class)

    # forward method
    def forward(self, input):
        x = F.leaky_relu(self.fc1(input), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.sigmoid(self.fc4(x))
        x = x.squeeze(-1)
        return x
        
# network
G = generator(input_size=100, n_class=28*28)
D = discriminator(input_size=28*28, n_class=1)
  • 生成器 (generator):

    • 输入:一个大小为100的噪声向量。
    • 结构:包含4个全连接层(fc1到fc4),每层后面跟随一个激活函数:
      • 前三层使用 LeakyReLU 激活函数,最后一层使用 tanh。
      • 输出大小为 28×28(MNIST图片的尺寸)。
    • 功能:将随机噪声映射为类似于手写数字的图片。
  • 判别器 (discriminator):

    • 输入:展平的MNIST图片(大小为 28×28)。
    • 结构:包含4个全连接层(fc1到fc4),每层后面跟随:
      • LeakyReLU 激活函数和 Dropout(用于防止过拟合)。
      • 最后一层使用 sigmoid 激活函数。
    • 输出:一个介于0和1之间的值,表示输入是“真实图片”的概率。
  1. 定义训练参数以及损失函数和优化器
# training parameters
batch_size = 256
lr = 0.0002
train_epoch = 200
device = torch.cuda.is_available()
if device:
    print("running on GPU!")
    
# Binary Cross Entropy loss
BCE_loss = nn.BCELoss()

#move to cuda
if device:
    G.cuda()
    D.cuda()
    BCE_loss = BCE_loss.cuda()

# Adam optimizer
G_optimizer = optim.Adam(G.parameters(), lr=lr)
D_optimizer = optim.Adam(D.parameters(), lr=lr)
4. 训练过程
在训练过程中,我们交替训练判别器和生成器。
train_hist = {}
train_hist['D_losses'] = []
train_hist['G_losses'] = []

for epoch in range(train_epoch):
    D_losses = []
    G_losses = []
    #生成任务,不需要标签
    for x_, _ in train_loader:
        #训练图像展平
        x_ = x_.view(-1, 28 * 28)

        mini_batch = x_.size()[0]

        y_real_ = torch.ones(mini_batch)
        y_fake_ = torch.zeros(mini_batch)

        # train discriminator D
        D.zero_grad()
        z_ = torch.randn((mini_batch, 100))
        
        if device:
            x_, y_real_, y_fake_ = x_.cuda(), y_real_.cuda(), y_fake_.cuda()
            z_ = z_.cuda()

        #真数据loss
        D_result = D(x_)
        D_real_loss = BCE_loss(D_result, y_real_)
        D_real_score = D_result

        #假数据loss
        G_result = G(z_)
        D_result = D(G_result)
        D_fake_loss = BCE_loss(D_result, y_fake_)
        D_fake_score = D_result

        D_train_loss = D_real_loss + D_fake_loss

        D_train_loss.backward()
        D_optimizer.step()

        D_losses.append(D_train_loss.item())

        # train generator G
        G.zero_grad()
        # z_ = torch.randn((mini_batch, 100))
        # if device:
        #     z_ = z_.cuda()
        G_result = G(z_)
        D_result = D(G_result)

        G_train_loss = BCE_loss(D_result, y_real_)
        G_train_loss.backward()
        G_optimizer.step()

        G_losses.append(G_train_loss.item())

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
        (epoch + 1), train_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))

    if epoch %10 == 0:
        p = 'MNIST_GAN_results/Random_results/MNIST_GAN_' + str(epoch + 1) + '.png'
        fixed_p = 'MNIST_GAN_results/Fixed_results/MNIST_GAN_' + str(epoch + 1) + '.png'
        show_result((epoch+1), save=True, path=p, isFix=False)
        show_result((epoch+1), save=True, path=fixed_p, isFix=True)
        train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))
        train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))

采用交叉熵损失函数(BCE)计算Loss,即
在这里插入图片描述
其中判别器的loss计算如下:
在这里插入图片描述
生成器的loss计算如下:
在这里插入图片描述

  1. 保存模型及数据
    将生成器和判别器的模型参数进行保存,保存训练过程的loss数据。
print("Training finish!... save training results")
torch.save(G.state_dict(), "MNIST_GAN_results/generator_param.pkl")
torch.save(D.state_dict(), "MNIST_GAN_results/discriminator_param.pkl")
with open('MNIST_GAN_results/train_hist.pkl', 'wb') as f:
    pickle.dump(train_hist, f)
  1. 数据可视化
show_train_hist(train_hist, save=True, path='MNIST_GAN_results/MNIST_GAN_train_hist.png')

images = []
for e in range(train_epoch):
    img_name = 'MNIST_GAN_results/Fixed_results/MNIST_GAN_' + str(e + 1) + '.png'
    images.append(imageio.imread(img_name))
imageio.mimsave('MNIST_GAN_results/generation_animation.gif', images, fps=5)
  1. 完整代码
import os
import matplotlib.pyplot as plt
import itertools
import pickle
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
# from torch.autograd import Variable

# G(z)
class generator(nn.Module):
    # initializers
    def __init__(self, input_size=32, n_class = 10):
        super(generator, self).__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, 512)
        self.fc3 = nn.Linear(self.fc2.out_features, 1024)
        self.fc4 = nn.Linear(self.fc3.out_features, n_class)

    # forward method
    def forward(self, input):
        x = F.leaky_relu(self.fc1(input), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.tanh(self.fc4(x))
        x = x.squeeze(-1)

        return x

class discriminator(nn.Module):
    # initializers
    def __init__(self, input_size=32, n_class=10):
        super(discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, 512)
        self.fc3 = nn.Linear(self.fc2.out_features, 256)
        self.fc4 = nn.Linear(self.fc3.out_features, n_class)

    # forward method
    def forward(self, input):
        x = F.leaky_relu(self.fc1(input), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.sigmoid(self.fc4(x))
        x = x.squeeze(-1)
        return x

fixed_z_ = torch.randn((5 * 5, 100))    # fixed noise
with torch.no_grad():
    fixed_z_ = fixed_z_.cuda()

def show_result(num_epoch, show = False, save = False, path = 'result.png', isFix=False):
    z_ = torch.randn((5*5, 100))
    with torch.no_grad():
        z_ = z_.cuda()
    # z_ = Variable(z_.cuda(), volatile=True)

    G.eval()
    if isFix:
        test_images = G(fixed_z_)
    else:
        test_images = G(z_)
    G.train()

    size_figure_grid = 5
    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
    for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)

    for k in range(5*5):
        i = k // 5
        j = k % 5
        ax[i, j].cla()
        ax[i, j].imshow(test_images[k, :].cpu().data.view(28, 28).numpy(), cmap='gray')

    label = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')
    plt.savefig(path)

    if show:
        plt.show()
    else:
        plt.close()

def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):
    x = range(len(hist['D_losses']))

    y1 = hist['D_losses']
    y2 = hist['G_losses']

    plt.plot(x, y1, label='D_loss')
    plt.plot(x, y2, label='G_loss')

    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    if save:
        plt.savefig(path)

    if show:
        plt.show()
    else:
        plt.close()

# training parameters
batch_size = 256
lr = 0.0002
train_epoch = 200
device = torch.cuda.is_available()
if device:
    print("running on GPU!")

# data_loader
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5), std=(0.5))
])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True)

# network
G = generator(input_size=100, n_class=28*28)
D = discriminator(input_size=28*28, n_class=1)

# Binary Cross Entropy loss
BCE_loss = nn.BCELoss()

#move to cuda
if device:
    G.cuda()
    D.cuda()
    BCE_loss = BCE_loss.cuda()

# Adam optimizer
G_optimizer = optim.Adam(G.parameters(), lr=lr)
D_optimizer = optim.Adam(D.parameters(), lr=lr)

# results save folder
if not os.path.isdir('MNIST_GAN_results'):
    os.mkdir('MNIST_GAN_results')
if not os.path.isdir('MNIST_GAN_results/Random_results'):
    os.mkdir('MNIST_GAN_results/Random_results')
if not os.path.isdir('MNIST_GAN_results/Fixed_results'):
    os.mkdir('MNIST_GAN_results/Fixed_results')

train_hist = {}
train_hist['D_losses'] = []
train_hist['G_losses'] = []

for epoch in range(train_epoch):
    D_losses = []
    G_losses = []
    #生成任务,不需要标签
    for x_, _ in train_loader:
        #训练图像展平
        x_ = x_.view(-1, 28 * 28)

        mini_batch = x_.size()[0]

        y_real_ = torch.ones(mini_batch)
        y_fake_ = torch.zeros(mini_batch)

        # train discriminator D
        D.zero_grad()
        z_ = torch.randn((mini_batch, 100))
        
        if device:
            x_, y_real_, y_fake_ = x_.cuda(), y_real_.cuda(), y_fake_.cuda()
            z_ = z_.cuda()

        #真数据loss
        D_result = D(x_)
        D_real_loss = BCE_loss(D_result, y_real_)
        D_real_score = D_result

        #假数据loss
        G_result = G(z_)
        D_result = D(G_result)
        D_fake_loss = BCE_loss(D_result, y_fake_)
        D_fake_score = D_result

        D_train_loss = D_real_loss + D_fake_loss

        D_train_loss.backward()
        D_optimizer.step()

        D_losses.append(D_train_loss.item())

        # train generator G
        G.zero_grad()
        # z_ = torch.randn((mini_batch, 100))
        # if device:
        #     z_ = z_.cuda()
        G_result = G(z_)
        D_result = D(G_result)

        G_train_loss = BCE_loss(D_result, y_real_)
        G_train_loss.backward()
        G_optimizer.step()

        G_losses.append(G_train_loss.item())

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
        (epoch + 1), train_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))

    if epoch %10 == 0:
        p = 'MNIST_GAN_results/Random_results/MNIST_GAN_' + str(epoch + 1) + '.png'
        fixed_p = 'MNIST_GAN_results/Fixed_results/MNIST_GAN_' + str(epoch + 1) + '.png'
        show_result((epoch+1), save=True, path=p, isFix=False)
        show_result((epoch+1), save=True, path=fixed_p, isFix=True)
    
    train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))
    train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))


print("Training finish!... save training results")
torch.save(G.state_dict(), "MNIST_GAN_results/generator_param.pkl")
torch.save(D.state_dict(), "MNIST_GAN_results/discriminator_param.pkl")
with open('MNIST_GAN_results/train_hist.pkl', 'wb') as f:
    pickle.dump(train_hist, f)

show_train_hist(train_hist, save=True, path='MNIST_GAN_results/MNIST_GAN_train_hist.png')

images = []
for e in range(train_epoch):
    img_name = 'MNIST_GAN_results/Fixed_results/MNIST_GAN_' + str(e + 1) + '.png'
    images.append(imageio.imread(img_name))
imageio.mimsave('MNIST_GAN_results/generation_animation.gif', images, fps=5)

训练结果

在这里插入图片描述

以上是训练190个epoch后得到的结果,可以看到其中某些图片已经有了数字的模样。这里仅仅是使用了全连接层来搭建模型,如果使用卷积神经网络,效果会有更好的提升,大家可以尝试一下。

遇到的问题

可以适当地提高batch size来提高训练速度,也可以切换更简单的loss函数来提高训练速度。
建议batch size从底到高慢慢调节,若batch size过高,可能导致模型训练出现问题。


http://www.kler.cn/a/515078.html

相关文章:

  • Linux(LAMP)
  • GPT 结束语设计 以nanogpt为例
  • 初阶5 排序
  • UE5 开启“Python Remote Execution“
  • 2025-1-21 Newstar CTF web week1 wp
  • Mysql触发器(学习自用)
  • LeetCode:53. 最大子序和
  • 初始Transformer
  • C++ STL(8)map
  • 正则表达式的艺术:轻松驾驭 Python 的 re 库
  • 智能鞋利用机器学习和深度学习技术进行患者监测和步态分析的演变与挑战
  • Roland 键盘合成器接声卡(福克斯特/雅马哈)声音小/音质异常的问题
  • insight在线需求分析系统概要介绍
  • redis离线安装部署详解(包括一键启动)
  • 为什么要申请专利
  • LiveBench:AI 模型基准测试与评估工具解析与实战指南
  • 复位信号的同步与释放(同步复位、异步复位、异步复位同步释放)
  • 【网络协议】【http】【https】TLS解决了HTTP存在的问题-加密通信+摘要,数字签名+CA证书
  • HTTP post请求工具类
  • 博客之星2024年度总评选——我的年度创作回顾与总结
  • Django项目的创建及运行——Django学习日志(一)
  • Ubuntu环境 nginx 源码 编译安装
  • 吴恩达深度学习——神经网络介绍
  • 最新版pycharm如何配置conda环境
  • 考研408笔记之数据结构(七)——排序
  • 使用easyimages部署个人图床服务