当前位置: 首页 > article >正文

论文《Graph Neural Networks with convolutional ARMA filters》笔记

【ARMA 2021 PAMI】本文介绍了一种新型的基于**自回归移动平均(Auto-Regression Moving Average,ARMA)**滤波器的图卷积层。与多项式滤波器相比,ARMA滤波器提供了更灵活的频率响应,对噪声更鲁棒,能更好地捕获全局图结构。

本文发表在2021年PAMI期刊上,第一作者学校:UiT the Arctic University of Norway,引用量:494。

PAMI期刊简介:全称IEEE Transactions on Pattern Analysis and Machine Intelligence(IEEE模式分析与机器学习智能汇刊),影响因子很高,计算机视觉顶刊,CCF A。

查询会议:

  • 会伴:https://www.myhuiban.com/

  • CCF deadline:https://ccfddl.github.io/

原文和开源代码链接:

  • paper原文:https://arxiv.org/abs/1901.01343
  • 开源代码:https://github.com/xnuohz/ARMA-dgl

0、核心内容

本文介绍了一种新型的基于**自回归移动平均(Auto-Regression Moving Average,ARMA)**滤波器的图卷积层。

研究背景与动机

  • 传统的图神经网络在图上实施卷积操作,通常基于多项式谱滤波器。
  • 作者指出多项式滤波器的局限性,如对噪声的敏感性、不能很好地捕捉全局图结构等。

ARMA滤波器的卷积层

  • 提出了一种基于ARMA滤波器的新型图卷积层,与多项式滤波器相比,ARMA滤波器提供了更灵活的频率响应,对噪声更鲁棒,能更好地捕获全局图结构。
  • 介绍了ARMA滤波器的图神经网络实现,包括递归和分布式公式,使得卷积层训练高效、节点空间局部化,并能在测试时迁移到新图上。

理论分析

  • 通过谱分析研究了所提出的ARMA层的滤波效果。
  • 论文还讨论了ARMA滤波器与多项式滤波器的数学性质和性能对比。

实验验证

  • 在四个下游任务上进行了实验:半监督节点分类、图信号分类、图分类和图回归。
  • 实验结果显示,基于ARMA层的图神经网络在各项任务重相较于基于多项式滤波器的网络有显著的性能提升。

方法细节

  • 论文详细介绍了ARMA滤波器的理论基础,包括其在图信号处理中的应用。
  • 讨论了ARMA滤波器的非线性和可训练性,以及如何通过优化任务相关的损失函数来学习参数。

实验设置与结果

  • 论文描述了实验的设置,包括数据集、模型架构、超参数选择等。
  • 提供了详细的实验结果,证明了ARMA层在不同任务上的有效性。

结论

  • 论文总结了ARMA层的主要贡献,并强调了其在图机器学习任务中的优越性能。

1、先验知识
① 什么是加权移动平均(weighted moving average)?

在图信号处理和GNN的上下文中,加权移动平均是一种对图信号进行平滑处理的方法。

移动平均(Moving Average,MA)

  • 移动平均是一种常用的信号处理技术,用于减少数据的波动性,突出其趋势。在图信号的背景下,它涉及对图中每个节点的特征值进行局部平均,以获得更平滑的信号表示。
  • 在最简单的形式中,每个节点的新特征值是其自身和其邻居节点特征值的算术平均。

加权移动平均(Weighted Moving Average,WMA)

  • 加权移动平均是移动平均的一种扩展,其中每个邻居节点的贡献被一个权重系数所调整。这些权重反映了节点间连接的重要性或节点特征的相关性。
  • 在图信号处理中,这意味着每个节点的新特征值使其自身和邻居节点特征值的加权和。权重通常基于图的拓扑结构(例如,边的强度或节点的相似性)或其他学习到的参数。

2、从谱滤波器到多项式滤波器发展脉络

① 谱滤波器:在谱域内与非线性可训练滤波器实现卷积的GNN。

参考资料:

  • 《Spectral networks and locally connected networks on graphs》Brunna 2013
  • 《Deep convolutional networks on graph-structured data》Henaff 2015

这种滤波器选择性地缩小或放大图信号的傅里叶系数(节点特征的一个实例),然后将节点特征映射到一个新的空间。

在这里插入图片描述

谱滤波器存在的问题:在实现时需要进行特征分解,计算代价很大。

② 多项式滤波器:为了避免频域昂贵的频谱分解和投影,最先进的GNN将图滤波器作为低阶多项式,直接在节点域学习。

参考资料:

  • 《Convolutional neural networks on graphs with fast localized spectral filtering》Deferrard 2016
  • 《GCN》Kipf 2016
  • 《Variational graph auto-encoders》Kipf 2016

多项式滤波器有一个有限的脉冲响应和执行加权移动平均过滤图信号在局部节点社区,允许快速分布式实现等基于切比雪夫多项式和Lanczos迭代。

多项式滤波器的频率响应(frequency response):

在这里插入图片描述

多项式滤波器实现:

在这里插入图片描述

切比雪夫多项式滤波器实现:

在这里插入图片描述

GCN滤波器实现:

在这里插入图片描述

多项式滤波器存在的问题:建模能力有限,由于其平滑性,不能模拟频率响应的急剧变化。至关重要的是,高阶多项式对于达到高阶邻域是必要的,但它们的计算成本往往更高,过拟合训练数据,使模型对图信号或底层图结构的变化敏感。(总结:低阶多项式不能模拟频率响应的急剧变化,导致过平滑问题;高阶多项式导致过拟合问题。)

③ ARMA(自回归移动平均滤波器):与具有相同参数数量的多项式滤波器相比,ARMA滤波器提供了更多的频率响应,可以解释高阶邻域。

参考资料:

  • 《Design of graph filters and filterbanks》Tremblay 2018
  • 《Signal processing techniques for interpolation in graph structured data》Narang 2013

ARMA滤波器的频率响应:

在这里插入图片描述

ARMA滤波器的实现:

在这里插入图片描述

为了避免求逆,ARMA滤波器的近似实现(迭代方法):

在这里插入图片描述

1阶ARMA滤波器的近似实现:

在这里插入图片描述

1阶ARMA滤波器的频率响应:

k阶ARMA滤波器的近似实现:

3、ARMA层及滤波器实现

1阶ARMA滤波器的实现(本文最核心的公式,而且比较好理解):

在这里插入图片描述

定理1证明了公式14可收敛的条件(可以先忽略proof部分):

在这里插入图片描述

在实际实现ARMA算法过程中存在一些问题和挑战:

  1. 每个GCS堆栈k可能需要不同的、可能是大量的迭代次数 T k T_k Tk来收敛。这使得神经网络的实现变得麻烦,因为计算图是动态的,在训练过程中每次权重矩阵通过梯度下降更新都会变化。
  2. 为了训练参数的反向传播,如果 T k T_k Tk较大,神经网络必须多次展开,引入较高的计算代价和消失梯度问题。

解决方法:

  1. 参数 W k W_k Wk V k V_k Vk随机权重初始化。
  2. 固定收敛次数为常数 T T T

k阶ARMA滤波器的实现:

在这里插入图片描述

ARMA卷积层算法实现:

在这里插入图片描述

4、实验部分

节点分类任务的数据集和超参数:

在这里插入图片描述

节点分类任务实验结果:

5、核心卷积层源码

ARMA4NC类:

class ARMA4NC(nn.Module):
    def __init__(self,
                 in_dim,
                 hid_dim,
                 out_dim,
                 num_stacks,
                 num_layers,
                 activation=None,
                 dropout=0.0):
        super(ARMA4NC, self).__init__()

        self.conv1 = ARMAConv(in_dim=in_dim,
                              out_dim=hid_dim,
                              num_stacks=num_stacks,
                              num_layers=num_layers,
                              activation=activation,
                              dropout=dropout)

        self.conv2 = ARMAConv(in_dim=hid_dim,
                              out_dim=out_dim,
                              num_stacks=num_stacks,
                              num_layers=num_layers,
                              activation=activation,
                              dropout=dropout)
        
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g, feats):
        feats = F.relu(self.conv1(g, feats))
        feats = self.dropout(feats)
        feats = self.conv2(g, feats)
        return feats

ARMA卷积层:

class ARMAConv(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 num_stacks,
                 num_layers,
                 activation=None,
                 dropout=0.0,
                 bias=True):
        super(ARMAConv, self).__init__()
        
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.K = num_stacks
        self.T = num_layers
        self.activation = activation
        self.dropout = nn.Dropout(p=dropout)

        # init weight
        self.w_0 = nn.ModuleDict({
            str(k): nn.Linear(in_dim, out_dim, bias=False) for k in range(self.K)
        })
        # deeper weight
        self.w = nn.ModuleDict({
            str(k): nn.Linear(out_dim, out_dim, bias=False) for k in range(self.K)
        })
        # v
        self.v = nn.ModuleDict({
            str(k): nn.Linear(in_dim, out_dim, bias=False) for k in range(self.K)
        })
        # bias
        if bias:
            self.bias = nn.Parameter(torch.Tensor(self.K, self.T, 1, self.out_dim))
        else:
            self.register_parameter('bias', None)
        
        self.reset_parameters()

    def reset_parameters(self):
        for k in range(self.K):
            glorot(self.w_0[str(k)].weight)
            glorot(self.w[str(k)].weight)
            glorot(self.v[str(k)].weight)
        zeros(self.bias)

    def forward(self, g, feats):
        with g.local_scope():
            init_feats = feats
            # assume that the graphs are undirected and graph.in_degrees() is the same as graph.out_degrees()
            degs = g.in_degrees().float().clamp(min=1)
            norm = torch.pow(degs, -0.5).to(feats.device).unsqueeze(1)
            output = None

            for k in range(self.K):
                feats = init_feats
                for t in range(self.T):
                    feats = feats * norm
                    g.ndata['h'] = feats
                    g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
                    feats = g.ndata.pop('h')
                    feats = feats * norm

                    if t == 0:
                        feats = self.w_0[str(k)](feats)
                    else:
                        feats = self.w[str(k)](feats)
                    
                    feats += self.dropout(self.v[str(k)](init_feats))
                    feats += self.v[str(k)](self.dropout(init_feats))

                    if self.bias is not None:
                        feats += self.bias[k][t]
                    
                    if self.activation is not None:
                        feats = self.activation(feats)
                    
                if output is None:
                    output = feats
                else:
                    output += feats
                
            return output / self.K 

6、心得&代码复现

从本文中可以学到很多东西,收获很大;本文的方法也比较有意思,值得认真研读,虽然理论部分读起来非常硬核,尤其是算法implementation部分。

实验部分我自己复现的结果:

CoraCiteseerPubmedFilmSquirrelChameleonTexasCornellWisconsin
GCN81.90%71.80%79.10%23.36%34.49%51.97%54.05%54.05%64.71%
H2GCN74.20%59.80%76.60%25.72%38.04%52.63%72.97%48.65%74.51%
ARMA(dgl版)80.80%71.20%78.50%33.30%31.50%53.40%70.27%59.46%79.80%

原论文中没有对异配图进行实验,在这里补充了6个异配图的实验结果,ARMA算法的效果提升主要在异配图上。

7、参考资料
  • kimi:https://kimi.moonshot.cn/

http://www.kler.cn/a/302056.html

相关文章:

  • Python习题 251:修改文件名称
  • react 中 memo 模块作用
  • 网络原理-网络层和数据链路层
  • 11. 观光景点组合得分问题 |豆包MarsCode AI刷题
  • 【快捷入门笔记】mysql基本操作大全-SQL表
  • 算法训练(leetcode)二刷第二十六天 | *452. 用最少数量的箭引爆气球、435. 无重叠区间、*763. 划分字母区间
  • 开关电源的占空比与输入输出电压的关系
  • 更改PaddlePaddle的模型默认缓存目录
  • Anaconda下载及安装保姆级教程(详细图文)
  • 基于Java+SpringBoot+Vue+MySQL的西安旅游管理系统网站
  • 办海洋测绘乙级该如何准备才能万无一失办下来
  • 2.3.1 协程设计原理与汇编实现coroutine 2
  • 路径规划 | 基于A*算法的往返式全覆盖路径规划的改进算法(Matlab)
  • 【西电电装实习】6. 手装无人机的蓝牙断连debug
  • 【linux006】目录操作命令篇 - pwd 命令
  • vue3 自定义指令 directive
  • CSS-3
  • 如何让大模型更好地进行场景落地?
  • OpenSSL工具验证RSA证书
  • C# 批量更改文件后缀名称
  • 软件测试中常用模型分析
  • 【编程底层思考】性能监控和优化:JVM参数调优,诊断工具的使用等。JVM 调优和线上问题排查实战经验总结
  • 【C语言从不挂科到高绩点】17-C语言中的宏定义
  • 云服务器 卸载mysql5并安装mysql8(图文)
  • docker-compose 部署 flink
  • 笔试强训day10