Pytorch实现之特征损失与残差结构稳定GAN训练,并训练自己的数据集
简介
简介:生成器和鉴别器分别采用了4个新颖设计的残差结构实现,同时在损失中结合了鉴别器层的特征损失来提高模型性能。
论文题目:Image Generation by Residual Block Based Generative Adversarial Networks(基于残留块的生成对抗网络产生图像)
会议:2022 IEEE International Conference on Consumer Electronics (ICCE)
摘要:生成对抗网络是一种用于解决人工智能任务的流行深度学习技术,并且已广泛研究并应用于处理图像,声音,文本等。 特别是,在图像处理领域(例如图像样式传输,图像恢复,图像超分辨率等)采用了生成对抗网络。 尽管生成的对抗网络在图像生成方面表现出色,但训练过程通常是不稳定和受过训练的模型崩溃的,许多生成的图像可能包含相同的颜色或纹理模式。 在本文中,修改了生成器和鉴别器的网络,并将残留块添加到生成对抗网络体系结构中,以学习更好的图像功能。 为了减少训练过程中图像功能的丢失并获得更多功能以稳定图像生成,我们使用功能匹配来最大程度地减少真实图像和生成的图像之间的特征损失,以进行稳定训练。 在实验中,可以通过采用我们提出的方法来提高性能,这也比某些最先进的方法更好。
模型结构
总体架构
生成器残差架构与鉴别器残差架构
class ResidualBlockG(nn.Module):
def __init__(self, in_channels, out_channels, scale_factor=2):
super(ResidualBlockG, self).__init__()
self.path1_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=True)
self.path1_conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
self.path2_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
self.prelu = nn.PReLU()
def forward(self, x):
path1 = self.path1_conv1(x)
path1 = self.upsample(path1)
path1 = self.path1_conv2(path1)
path2 = self.path2_conv(x)
path2 = self.upsample(path2)
out = self.prelu(path1 + path2)
return out
# 定义鉴别器的残差块
class ResidualBlockD(nn.Module):
def __init__(self, in_channels, out_channels, scale_factor=2):
super(ResidualBlockD, self).__init__()
sel