当前位置: 首页 > article >正文

【论文阅读】Persistent Homology Based Generative Adversarial Network

Persistent Homology Based Generative Adversarial Network

摘要

现有的生成模型不能充分捕捉图像的全局结构信息,使得图像生成过程中难以协调全局结构特征和局部细节特征。该文提出了一种基于持续同调的生成对抗网络(PHGAN)。本文基于持久同调方法设计了拓扑特征变换算法,并通过全连通层模块和自注意模块将拓扑特征集成到遗传神经网络的鉴别器中,使PHGAN具有良好的全局结构信息捕获能力,提高了模型的生成性能。在CIFAR10数据集和STL10数据集上对PHGAN进行了实验评估,并与几种经典的生成式对抗网络模型进行了比较。实验结果表明我们的PHGAN模型具有较好的图像生成能力。

基于卷积神经网络的GAN依赖卷积运算进行特征提取,而卷积运算的卷积核大小有限。它的感受域是有限的,一些长距离的依赖关系不能被捕获,因此模型在整体结构中表现不佳。

相关工作

Khrulkov et al.(Khrulkov和Oseledets,2018)使用生成图像和真实的图像的拓扑特征的近似值作为衡量GAN生成性能的指标。巧合的是,Horak等人(Horak等人,2021)也提出了一种基于持久同源方法的不同GANs生成性能评价指标。但他们提出的只是用生成图像和真实的图像的拓扑特征来衡量GAN的生成性能,并没有用真实的图像的拓扑特征来指导生成器生成图像。

Br¨uel-Gabrielsson等(Gabrielsson等人,2020)提出利用持久同调得到的拓扑特征来指导GAN生成图像,但作者只对输入到生成器的噪声做了显式拓扑特征优化,生成器并没有学习到真实的图像的拓扑特征分布。

此外,也有一些关于持久同源性在其他生成模型中的应用的研究。例如,Schiff et al.(Schiff等人,2022)提出了一种基于持久同源性的变分自动编码器模型,利用拓扑特征作为新的重构损失项来优化变分自动编码器模型的生成性能。

我们提出的PHGAN使用真实的图像的拓扑特征来指导GAN的生成器生成图像,使得生成器能够学习真实的图像的拓扑特征分布。

method

从高斯分布中采样随机噪声,然后将此噪声馈送到生成器中以生成图像。生成的图像和真实的图像被输入到鉴别器中以鉴别真实的和伪造的。在卷积神经网络中,输入图像不仅经过卷积模块处理,得到卷积神经网络的特征,而且经过持久同源模块和拓扑特征变换模块处理,得到拓扑特征。将这两个特征串联起来,用于区分真实的图像和伪图像。

在这里插入图片描述

从图像构造立方体复形,将得到的拓扑特征表示为持久图。然后进行矢量化,即简单计算 lifespan(是不是过于简单了?)
在这里插入图片描述

鉴别器
在这里插入图片描述

https://blog.csdn.net/weixin_44790306/article/details/119256806
与使用全连接层网络结构不同,使用自注意网络会在训练过程中学习卷积神经网络特征与拓扑特征之间的相关性,进而判别图像的真伪。我们将向量ν输入自注意网络,然后卷积神经网络特征与拓扑特征相互作用后得到向量 v s a v_{sa} vsa。我们使用残差网络将向量 v s a v_{sa} vsa添加到原始向量ν
在这里插入图片描述
(γ表示可学习的参数)
以获得矢量v '。最后,将向量v '输入到全连通层,以判断图像的真实性。

损失函数
在这里插入图片描述

原始GAN损失函数

https://blog.csdn.net/weixin_48320163/article/details/128603886
在这里插入图片描述
即先从判别器的角度令损失最大化,又从生成器的角度令损失最小化,即可让判别器和生成器在共享损失的情况下实现对抗。

第一个期望 E x ∼ P d a t a ( x ) [ log ⁡ D ( x ) ] \mathbb{E}_{x \sim{P_{data}} (x)} [\log D(x)] ExPdata(x)[logD(x)]是所有 x x x都是真实数据时 log ⁡ D ( x ) \log D(x) logD(x)的期望;第二个期望 E z ∼ P z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \mathbb{E}_{z \sim{P _{z}}(z)}[\log (1 - D(G(z)))] EzPz(z)[log(1D(G(z)))]是所有数据都是生成数据时 log ⁡ ( 1 − D ( G ( z ) ) ) \log (1 - D(G(z))) log(1D(G(z)))的期望。当真实数据、生成数据的样本点固定时,期望就等于均值。

实验

使用CIFAR10数据集和STL10数据集进行实验评估,并与DCGAN(Radford等,2015)、WGAN-GP(Gulrajani等,2017)以及WGAN(Arjovsky等,2017)进行了比较分析。

CIFAR10数据集由10类32x32彩色图像组成。每个类别包含6000张图像,其中5000张作为训练集,1000张作为测试集。STL10数据集由10类96x96彩色图像组成,每类包含1300张图像,其中500张用于训练,800张用于测试。

在本文的实验中,首先通过中心裁剪将原始图像裁剪成32x32大小的图像,然后使用训练集训练生成器。

使用FID(Fréchet Inception Distance)(Heusel等, 2017)和IS(Inception Score)(Salimans等, 2016)来评估生成模型的性能外,我们还使用了GS(Geometry Score)(Khrulkov和Osеledets, 2018)评价指标:基于拓扑特征相似性生成对抗网络模型生成性能评估。

FID(Fréchet Inception Distance)是一种评估生成模型(特别是生成对抗网络)性能的指标。它通过对比生成图像和真实图像在特征空间中的分布差异来评估生成图像的质量和多样性。具体来说,FID计算了生成图像和真实图像通过预训练的Inception网络获得的嵌入特征的均值和协方差矩阵之间的Fréchet距离。FID值越低,说明生成图像与真实图像的分布越接近,生成模型的性能越好。
IS(Inception Score)也是一种评估生成模型性能的指标。它通过量化生成图像的清晰度和多样性来评估模型的表现。具体来说,IS基于预训练的Inception网络的输出,将生成图像分类为不同类别,并计算每个图像的预测类别分布的熵值。同时,它还考虑了所有生成图像整体的类别分布的熵值。IS值越高,说明生成图像既具有高质量(明确的类别)又具有多样性(跨越不同类别),生成模型的性能越好。
GS(Geometry Score)是一个基于拓扑特征相似性的生成对抗网络模型生成性能评估指标。它通过比较生成图像和真实图像的数据流形的拓扑特征来评估生成图像的质量。具体来说,GS使用持久同调(persistent homology)方法计算数据集的Betti数,用以描述数据的拓扑结构,比如连通分支的数量、环的数量等。GS值越低,说明生成图像的数据流形与真实图像的数据流形的拓扑特征越接近,生成模型的性能越好。

这些指标各有侧重,通过综合使用,可以更全面地评价生成模型的表现。

实验在Linux服务器上进行,操作系统为Ubuntu 18.04,使用Nvidia Tesla P40 24GB单显卡。采用Adam优化器(Kingma和Ba, 2014),设置 β 1 = 0.5 \beta_1=0.5 β1=0.5 β 2 = 0.999 \beta_2=0.999 β2=0.999,批量大小设置为64,训练时的学习率为0.0002。

在这里插入图片描述
在这里插入图片描述


http://www.kler.cn/a/371985.html

相关文章:

  • JS中的鼠标事件和键盘事件基础
  • 大模型WebUI:Gradio全解系列9——Additional Features:补充特性(下)
  • 前端如何排查内存泄漏
  • 【MySQL】踩坑笔记——保存带有换行符等特殊字符的数据,需要进行转义保存
  • 【Linux命令】su、sudo、sudo su、sudo -i、sudo -l的用法和区别
  • Go-知识 模板
  • CSS flex布局- 最后一个元素占满剩余可用高度转载
  • Rust 力扣 - 59. 螺旋矩阵 II
  • 正则表达式笔记
  • Windows目录共享到Linux
  • vue2和vue3在html中引用组件component方式不一样
  • 聊聊AI时代的新岗位
  • 软件测试-覆盖率测试-四关全
  • JavaScript的本地存储知识点详解Cookie、SessionStorage、LocalStorage、IndexedDB
  • SQL 数据汇总与透视的实用案例
  • mixin的基本用法
  • 达梦数据库创建oracle dblink
  • SUSE发布云安全行业趋势报告,中国市场释放积极信号
  • Google Recaptcha V2 简单使用
  • 【网络原理】——图解HTTPS如何加密(通俗简单易懂)
  • 【Pytorch】Pytorch的安装
  • 实现PC端和安卓手机的局域网内文件共享
  • OpenCV视觉分析之运动分析(5)背景减除类BackgroundSubtractorMOG2的使用
  • Oracle视频基础1.1.3练习
  • 《Linux系统编程篇》exec族函数——基础篇
  • _csv.Error: field larger than field limit (131072)