当前位置: 首页 > article >正文

Pytorch实现之结合SE注意力和多种损失的特征金字塔架构GAN的图像去模糊方法

简介

简介:提出了一种利用特征金字塔作为框架代替多尺度输入的一种方法来构建生成器模型,减少了模型规模并加快了训练速度。在模型架构中还融合了通道注意力方法来提高训练能力。作者在生成器中采用了三种常见的损失计算,在鉴别器中结合了最小二乘和相对论损失来改善模型训练。

论文题目:Image Deblurring Based on Generative Adversarial Networks(基于生成对抗网络的图像去模糊)

会议:International Conference on Intelligent Computing and Signal Processing (ICSP)

摘要:图像去模糊技术利用深度学习方法解决单幅图像的模糊问题,这是计算机视觉领域的一个具有挑战性的问题。 近年来,深度学习和计算机视觉的快速发展,提高了模糊处理算法的性能。 本文从深度学习的角度研究图像去模糊问题,利用卷积神经网络实现图像去模糊的目的。 针对多尺度网络单次去模糊处理规模庞大,重要特征信息未得到充分利用的问题,提出了一种基于生成对抗网络的去模糊算法。 该模型采用特征金字塔网络作为框架代替多尺度输入,有效地减小了网络规模,加快了训练速度。 为了更好地利用特征信息,在网络中引入了注意机制和双尺度判别器。 为了使训练过程更加稳定,该算法采用最小二乘和相对论相结合的方法改善了鉴别器的损失。 实验结果表明,基于生成对抗网络的图像去模糊算法比其他算法具有更好的恢复效果。

模型结构

生成器架构

生成器设计介绍

作者提到,目前,在现有的图像去模糊任务中,骨干网通常使用类似ResNet的网络。 大多数处理不同程度模糊图像的先进方法都使用多尺度输入方法来消除模糊。

然而,多尺度模式下的输入法往往会消耗更多的时间和占用大量的内存,因此该模型中的生成器使用特征金字塔网络而不是多尺度网络。设计的特征金字塔网络结构是一种编解码的形式,它包括两条路径。从浅到深的路径可以看作是编码部分,主要用于提取输入图像的特征。分辨率降低了,但它可以提取高级特征并压缩更多的上下文语义信息。从深到浅的路径可以看作是解码部分,通过上采样恢复空间分辨率,并结合高级特征和丰富的语义信息生成清晰的图像。此外,两条路径之间的水平连接补充了高分辨率的细节,有助于恢复更清晰的图像。

在论文的提议中,网络模型通过激活权值的方式,在生成器中加入关注模块,加强对重要特征的关注。 当卷积层数较浅时,加入注意力模块会使计算量过大,而卷积层数较深时,特征图之间的差异较小。  

 

class ConvBlock(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(ConvBlock, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 4, 2, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.model(x)
        return x

class ConvBlock_1(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(ConvBlock_1, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.model(x)
        return x

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.ConvBlock1 = ConvBlock(3, 32)
        self.ConvBlock2 = ConvBlock(32, 32)
        self.ConvBlock3 = ConvBlock(32, 32)
        self.ConvBlock4 = ConvBlock(32, 32)
        self.ConvBlock5 = ConvBlock(32, 32)
        self.ConvBlock6 = ConvBlock_1(32, 3)
        self.ConvBlock1_1 = ConvBlock_1(32, 32)
        self.con1_1 = nn.Conv2d(32, 32, 1)
        self.SE = SE(32, 8)
        self.Up = nn.Upsample(scale_factor=2)
        self.Up4 = nn.Upsample(scale_factor=8)
        self.Up3 = nn

http://www.kler.cn/a/557855.html

相关文章:

  • js如何直接下载文件流
  • #渗透测试#批量漏洞挖掘#Progress Software Flowmon命令执行漏洞(CVE-2024-2389)
  • STM32MP157A单片机驱动--控制拓展版的灯实现流水效果
  • 从函数到神经网络
  • Elasticsearch常用的查询条件
  • [Android]使用WorkManager循环执行任务
  • 【开放词汇分割】Image Segmentation Using Text and Image Prompts
  • 设计心得——解耦的实现技术
  • 打开Firefox自动打开hao360.hjttif.com标签解决方案
  • java Web
  • 【论文解析】Fast prediction mode selection and CU partition for HEVC intra coding
  • 【漫话机器学习系列】100.L2 范数(L2 Norm,欧几里得范数)
  • .NET MVC实现电影票管理
  • 电商API安全防护:JWT令牌与XSS防御实战
  • android 快速定位当前页面
  • 设计模式之组合设计模式实战 文件展示 树叶子节点
  • chrome扩展程序如何实现国际化
  • springboot3.x整合fastdfs
  • Wireshark详解
  • cs106x-lecture14(Autumn 2017)-SPL实现