
Implementing LGEGAN in PyTorch: A Lightweight EGAN with Statistic Global Information

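Introduction

The model builds on EGAN by adding a dimension-reduced self-attention mechanism and a novel selection operator. Individuals are picked by roulette-wheel selection, and a picked individual is discarded if its fitness satisfies fchild < VALUE. The aim is to find the best individual as quickly as possible in the early stage of evolution and to maintain population diversity in the later stages.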

Paper: LGEGAN: A Lightweight Evolutionary Generative Adversarial Network with Statistic Global Information

Conference: 2023 Chinese Control Conference

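Abstract: Generative adversarial networks (GANs) have been applied in many fields. However, existing GANs and their variants suffer from a number of problems, including mode collapse, unstable training, and getting trapped in local optima. We therefore construct a lightweight evolutionary generative adversarial network with statistic global information (LGEGAN). To address the difficulty shallow convolutional networks have in capturing long-range feature dependencies, and the tendency toward mode collapse during training, LGEGAN differs from EGAN in that an improved self-attention mechanism is added to the generator network. To address training instability, spectral normalization is added to LGEGAN, which makes training more stable in each generation. Finally, to evolve well-adapted individuals efficiently in a short time and to avoid getting trapped in local optima, we construct a novel selection operator and apply it in the selection stage of generator evolution. In the experiments, LGEGAN is evaluated on four aspects: the quality and diversity of generated images, mode collapse, training stability, and architectural robustness. The results show that LGEGAN outperforms EGAN, MOEGAN, SMOEGAN, LRGAN, ProbGAN, and other GAN models.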

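Model Structure

The paper gives the concrete LGEGAN architecture, shown in the figure below. A few points about it matter for the code that follows.

Figure: LGEGAN model framework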

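Notes on the Generator Architecture Code

The four transposed-convolution layers listed in the paper's table are supposed to end at a 128*128*3 image. Implemented as listed, however, the output I obtained was 45*45*3, which falls short of that target. The authors may have set padding or used some other size-enlarging operation that the table does not show. I therefore added two more transposed-convolution layers to the generator so that the output reaches 128*128*3 and matches what the rest of the pipeline expects.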
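The size claim for the six-layer generator in this post can be checked by hand with the standard output-size formula for `nn.ConvTranspose2d`. The following is a minimal sketch in plain Python; the helper names `deconv_out` and `trace_sizes` are made up for illustration, and the (kernel, stride, padding) triples are copied from the Generator code in this post, not from the paper's table.

```python
def deconv_out(size, kernel, stride, padding, output_padding=0, dilation=1):
    """Output side length of a transposed convolution for a square input."""
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel - 1) + output_padding + 1)


# (kernel, stride, padding) for layer1..layer6 of the Generator in this post
LAYERS = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]


def trace_sizes(start=1, layers=LAYERS):
    """Return the spatial size after each layer, starting from a 1x1 latent map."""
    sizes = [start]
    for kernel, stride, padding in layers:
        sizes.append(deconv_out(sizes[-1], kernel, stride, padding))
    return sizes


print(trace_sizes())  # [1, 4, 8, 16, 32, 64, 128]
```

With these settings each stride-2 layer exactly doubles the side length, so the self-attention block placed after layer3 operates on a 16*16 feature map.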

class Generator(nn.Module):
    def __init__(self, z_dim):
        super(Generator, self).__init__()
        self.layer1 = nn.Sequential(
            nn.ConvTranspose2d(z_dim, 512, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.layer2 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.layer3 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.attention = SelfAttention(128)
        self.layer4 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.layer5 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        self.layer6 = nn.Sequential(
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)
        out = self.layer1(z)
        #print(f"Generator layer1 output shape: {out.shape}")
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.attention(out)
        out = self.layer4(out)
        out = self.layer5(out)
        #print(f"Generator layer5 output shape: {out.shape}")
        out = self.layer6(out)
        #print(f"Generator layer6 output shape: {out.shape}")
        return out

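Improved Self-Attention Mechanism

A reduced-dimension projection is applied to the outputs of K and V to lower their dimensionality. I have a question about this step: once the dimension has been reduced, how does the result match the original input x? The figure in the paper does not make this clear, so in my implementation I project back up to the original dimension afterwards and then add the input x.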
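For reference, the computation performed by the `SelfAttention` class in this post can be written out as follows. This describes the post's implementation only, not necessarily the paper's formulation. The symbols are introduced here for illustration: $W_q, W_k, W_v$ stand for the 1x1 convolutions `query_conv`, `key_conv`, and `value_conv`; $R_k, R_v$ stand for the linear reductions `key_reduction` and `value_reduction` (from $C$ to $C/4$ channels); $R_o$ stands for `out_reduction` (from $C/4$ back to $C$); biases are omitted.

```latex
\begin{aligned}
X &\in \mathbb{R}^{N \times C}, \qquad N = W \cdot H \\
Q &= X W_q \in \mathbb{R}^{N \times C/4} \\
K &= (X W_k)\, R_k \in \mathbb{R}^{N \times C/4} \\
V &= (X W_v)\, R_v \in \mathbb{R}^{N \times C/4} \\
A &= \operatorname{softmax}\!\left(Q K^{\top}\right) \in \mathbb{R}^{N \times N} \\
Y &= \gamma \,(A V)\, R_o + X \in \mathbb{R}^{N \times C}
\end{aligned}
```

The softmax is taken over each row of $Q K^{\top}$. Because $\gamma$ is initialized to zero, the block starts out as the identity mapping and the attention term is blended in as training proceeds.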

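Improved Selection Operator

To address the problem of getting trapped in local optima, the authors construct a selection operator based on the roulette-wheel algorithm and add a discard strategy. A threshold VALUE is set, and a selected individual is discarded when its fitness is below that threshold.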

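Concretely, individuals are picked by roulette-wheel selection, and a picked individual is discarded if its fitness satisfies fchild < VALUE. The best individual should be found as quickly as possible in the early stage of evolution, and population diversity should be maintained in the later stages. The discard probability is therefore made negatively correlated with the number of training iterations. VALUE is defined as:

[Formula for VALUE: not recoverable from the source]

Here a is a constant, I is the number of training iterations, and the remaining term is the average fitness of the previous population. When the number of training iterations is less than 20000, a/I = 0.9.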
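To get a feel for how the discard probability shrinks with training, here is a minimal sketch of the schedule used by the `RouletteWheelSelector` code in this post, `decay_factor / (iteration + 1)`. Note that this is the post's own simplification: the paper's description keeps a/I at 0.9 while the iteration count is below 20000, whereas this schedule drops off much faster. The function name `discard_prob` is made up for illustration.

```python
def discard_prob(iteration, decay_factor=0.9):
    """Discard probability at a given 0-based training iteration (post's schedule)."""
    return decay_factor / (iteration + 1)


# Print the probability at a few iterations to see how quickly it decays
for it in (0, 9, 99, 999):
    print(f"iteration {it:4d}: discard probability = {discard_prob(it):.4f}")
```

Under this schedule the discard step is already below 1% after the first hundred iterations, so in practice it only affects the very start of training.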

Selection Operator Code

class RouletteWheelSelector:
    def __init__(self, threshold_value, decay_factor=0.9):
        self.threshold_value = threshold_value
        self.decay_factor = decay_factor

    def select(self, fitness_scores, iteration):
        # Compute the discard probability
        discard_prob = self.decay_factor / (iteration + 1)
        selected_indices = []

        for i, score in enumerate(fitness_scores):
            if score < self.threshold_value and np.random.rand() < discard_prob:
                continue  # discard
            selected_indices.append(i)

        return selected_indices
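The class above filters individuals by a threshold with a decaying probability, but it does not perform a fitness-proportional draw. The following is a minimal sketch of roulette-wheel selection combined with the discard rule fchild < VALUE, written in plain Python. It is not the paper's implementation: the function name `roulette_select`, the `max_draws` cap, the shift that makes weights positive, and the choice to return distinct indices are all assumptions made for illustration, and `value` is passed in directly because the formula for VALUE is not given here.

```python
import random


def roulette_select(fitness, n_select, value, rng=None, max_draws=1000):
    """Pick up to n_select distinct indices with probability proportional to
    (shifted) fitness, discarding any pick whose fitness is below `value`."""
    rng = rng or random.Random()
    # Shift fitness so every weight is positive (fitness may be negative logits)
    low = min(fitness)
    weights = [f - low + 1e-8 for f in fitness]
    indices = range(len(fitness))

    survivors = []
    draws = 0
    while len(survivors) < n_select and draws < max_draws:
        draws += 1
        idx = rng.choices(indices, weights=weights, k=1)[0]  # spin the wheel
        if fitness[idx] < value:
            continue  # discard: fitness below the threshold
        if idx not in survivors:
            survivors.append(idx)
    return survivors


if __name__ == "__main__":
    scores = [0.1, 0.9, 0.5, 0.7]
    print(roulette_select(scores, n_select=2, value=0.4, rng=random.Random(0)))
```

The `max_draws` cap keeps the loop from spinning forever when too few individuals clear the threshold; in that case fewer than `n_select` indices are returned.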

Training on Your Own Dataset

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
import numpy as np

class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 4, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.key_reduction = nn.Linear(in_channels, in_channels // 4)
        self.value_reduction = nn.Linear(in_channels, in_channels // 4)
        self.out_reduction = nn.Linear(in_channels // 4, in_channels)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, width, height = x.size()
        query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        key = self.key_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        key = self.key_reduction(key)
        key = key.permute(0, 2, 1)

        # Check that the query and key shapes are compatible
        assert query.size(2) == key.size(1), f"query dim 2 ({query.size(2)}) must match key dim 1 ({key.size(1)})"

        energy = torch.bmm(query, key)
        attention = F.softmax(energy, dim=-1)

        value = self.value_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        value = self.value_reduction(value)
        value = value.permute(0, 2, 1)

        out = torch.bmm(value, attention.permute(0, 2, 1))
        #print(f"out shape before out_reduction: {out.shape}, num elements: {out.numel()}")
        # Project out back up to the original channel dimension
        out = out.permute(0, 2, 1)
        out = self.out_reduction(out)
        out = out.permute(0, 2, 1)
        #print(f"out shape after out_reduction: {out.shape}, num elements: {out.numel()}")
        #print(f"Target shape: [batch_size={batch_size}, C={C}, width={width}, height={height}], num elements: {batch_size * C * width * height}")
        out = out.view(batch_size, C, width, height)
        out = self.gamma * out + x
        return out

class Generator(nn.Module):
    def __init__(self, z_dim):
        super(Generator, self).__init__()
        self.layer1 = nn.Sequential(
            nn.ConvTranspose2d(z_dim, 512, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.layer2 = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.layer3 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.attention = SelfAttention(128)
        self.layer4 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.layer5 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        self.layer6 = nn.Sequential(
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        z = z.view(z.size(0), z.size(1), 1, 1)
        out = self.layer1(z)
        #print(f"Generator layer1 output shape: {out.shape}")
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.attention(out)
        out = self.layer4(out)
        out = self.layer5(out)
        #print(f"Generator layer5 output shape: {out.shape}")
        out = self.layer6(out)
        #print(f"Generator layer6 output shape: {out.shape}")
        return out

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.layer1 = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(3, 128, kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2)
        )
        self.layer2 = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2)
        )
        self.layer3 = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2)
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1)
        )
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))  # global average pooling layer

    def forward(self, x):
        #print(f"Discriminator input shape: {x.shape}")  # print the input shape
        out = self.layer1(x)
        #print(f"Discriminator layer1 output shape: {out.shape}")
        out = self.layer2(out)
        #print(f"Discriminator layer2 output shape: {out.shape}")
        out = self.layer3(out)
        #print(f"Discriminator layer3 output shape: {out.shape}")
        out = self.layer4(out)
        #print(f"Discriminator layer4 output shape: {out.shape}")
        out = self.global_avg_pool(out)  # apply global average pooling
        #print(f"Discriminator global_avg_pool output shape: {out.shape}")
        out = out.view(out.size(0), -1)  # reshape the output to a 2-D tensor
        return out

# Improved selection operator
class RouletteWheelSelector:
    def __init__(self, threshold_value, decay_factor=0.9):
        self.threshold_value = threshold_value
        self.decay_factor = decay_factor

    def select(self, fitness_scores, iteration):
        # Compute the discard probability
        discard_prob = self.decay_factor / (iteration + 1)
        selected_indices = []

        for i, score in enumerate(fitness_scores):
            if score < self.threshold_value and np.random.rand() < discard_prob:
                continue  # discard
            selected_indices.append(i)

        return selected_indices

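# Data preprocessing and loading
transform = transforms.Compose([
    transforms.Resize(128),      # resizes the shorter side to 128
    transforms.CenterCrop(128),  # crop to 128x128 so non-square images can be batched
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])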

dataset = datasets.ImageFolder(root=r"C:\Users\Wu Meishun\Desktop\02", transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Initialize the models
z_dim = 100
generator = Generator(z_dim)
discriminator = Discriminator()

# Define the loss function and optimizers
criterion = nn.BCEWithLogitsLoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Training loop
num_epochs = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)

# Initialize the selection operator
selector = RouletteWheelSelector(threshold_value=0.5)

for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(dataloader):
        real_images = real_images.to(device)
        batch_size = real_images.size(0)

        # Train the discriminator
        optimizer_D.zero_grad()
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # Loss on real images
        real_outputs = discriminator(real_images)
        d_loss_real = criterion(real_outputs, real_labels)

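        # Generate fake images
        z = torch.randn(batch_size, z_dim).to(device)
        fake_images = generator(z)

        # Loss on fake images. detach() stops gradients from reaching the generator
        # while still letting the discriminator learn from this loss; wrapping the
        # call in torch.no_grad() would cut the discriminator off from it as well.
        fake_outputs = discriminator(fake_images.detach())
        d_loss_fake = criterion(fake_outputs, fake_labels)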

        # Total loss
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # Train the generator
        optimizer_G.zero_grad()
        fake_outputs = discriminator(fake_images)
        g_loss = criterion(fake_outputs, real_labels)
        g_loss.backward()
        optimizer_G.step()

        # Apply the selection operator
        fitness_scores = fake_outputs.squeeze().detach().cpu().numpy()
        selected_indices = selector.select(fitness_scores, epoch * len(dataloader) + i)

        if i % 100 == 0:
            print(f"Epoch [{epoch}/{num_epochs}] Batch {i}/{len(dataloader)} "
                  f"Loss D: {d_loss.item():.4f}, Loss G: {g_loss.item():.4f}")

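    # Save generated images
    if epoch % 10 == 0:
        with torch.no_grad():
            fake_images = generator(torch.randn(64, z_dim).to(device))
        # The output filename is an assumption; normalize=True rescales the Tanh output for saving
        save_image(fake_images, f"fake_images_epoch_{epoch}.png", normalize=True)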

    if epoch % 100 == 0:
        torch.save(generator.state_dict(), f"generator_epoch_{epoch}.pth")

 

