论文《Graph Neural Networks with convolutional ARMA filters》笔记
【ARMA 2021 PAMI】本文介绍了一种新型的基于**自回归移动平均(Auto-Regression Moving Average,ARMA)**滤波器的图卷积层。与多项式滤波器相比,ARMA滤波器提供了更灵活的频率响应,对噪声更鲁棒,能更好地捕获全局图结构。
本文发表在2021年PAMI期刊上,第一作者学校:UiT the Arctic University of Norway,引用量:494。
PAMI期刊简介:全称IEEE Transactions on Pattern Analysis and Machine Intelligence(IEEE模式分析与机器学习智能汇刊),影响因子很高,计算机视觉顶刊,CCF A。
查询会议:
-
会伴:https://www.myhuiban.com/
-
CCF deadline:https://ccfddl.github.io/
原文和开源代码链接:
- paper原文:https://arxiv.org/abs/1901.01343
- 开源代码:https://github.com/xnuohz/ARMA-dgl
0、核心内容
本文介绍了一种新型的基于**自回归移动平均(Auto-Regression Moving Average,ARMA)**滤波器的图卷积层。
研究背景与动机
- 传统的图神经网络在图上实施卷积操作,通常基于多项式谱滤波器。
- 作者指出多项式滤波器的局限性,如对噪声的敏感性、不能很好地捕捉全局图结构等。
ARMA滤波器的卷积层
- 提出了一种基于ARMA滤波器的新型图卷积层,与多项式滤波器相比,ARMA滤波器提供了更灵活的频率响应,对噪声更鲁棒,能更好地捕获全局图结构。
- 介绍了ARMA滤波器的图神经网络实现,包括递归和分布式公式,使得卷积层训练高效、节点空间局部化,并能在测试时迁移到新图上。
理论分析
- 通过谱分析研究了所提出的ARMA层的滤波效果。
- 论文还讨论了ARMA滤波器与多项式滤波器的数学性质和性能对比。
实验验证
- 在四个下游任务上进行了实验:半监督节点分类、图信号分类、图分类和图回归。
- 实验结果显示,基于ARMA层的图神经网络在各项任务重相较于基于多项式滤波器的网络有显著的性能提升。
方法细节
- 论文详细介绍了ARMA滤波器的理论基础,包括其在图信号处理中的应用。
- 讨论了ARMA滤波器的非线性和可训练性,以及如何通过优化任务相关的损失函数来学习参数。
实验设置与结果
- 论文描述了实验的设置,包括数据集、模型架构、超参数选择等。
- 提供了详细的实验结果,证明了ARMA层在不同任务上的有效性。
结论
- 论文总结了ARMA层的主要贡献,并强调了其在图机器学习任务中的优越性能。
1、先验知识
① 什么是加权移动平均(weighted moving average)?
在图信号处理和GNN的上下文中,加权移动平均是一种对图信号进行平滑处理的方法。
移动平均(Moving Average,MA)
- 移动平均是一种常用的信号处理技术,用于减少数据的波动性,突出其趋势。在图信号的背景下,它涉及对图中每个节点的特征值进行局部平均,以获得更平滑的信号表示。
- 在最简单的形式中,每个节点的新特征值是其自身和其邻居节点特征值的算术平均。
加权移动平均(Weighted Moving Average,WMA)
- 加权移动平均是移动平均的一种扩展,其中每个邻居节点的贡献被一个权重系数所调整。这些权重反映了节点间连接的重要性或节点特征的相关性。
- 在图信号处理中,这意味着每个节点的新特征值使其自身和邻居节点特征值的加权和。权重通常基于图的拓扑结构(例如,边的强度或节点的相似性)或其他学习到的参数。
2、从谱滤波器到多项式滤波器发展脉络
① 谱滤波器:在谱域内与非线性可训练滤波器实现卷积的GNN。
参考资料:
- 《Spectral networks and locally connected networks on graphs》Brunna 2013
- 《Deep convolutional networks on graph-structured data》Henaff 2015
这种滤波器选择性地缩小或放大图信号的傅里叶系数(节点特征的一个实例),然后将节点特征映射到一个新的空间。
谱滤波器存在的问题:在实现时需要进行特征分解,计算代价很大。
② 多项式滤波器:为了避免频域昂贵的频谱分解和投影,最先进的GNN将图滤波器作为低阶多项式,直接在节点域学习。
参考资料:
- 《Convolutional neural networks on graphs with fast localized spectral filtering》Deferrard 2016
- 《GCN》Kipf 2016
- 《Variational graph auto-encoders》Kipf 2016
多项式滤波器有一个有限的脉冲响应和执行加权移动平均过滤图信号在局部节点社区,允许快速分布式实现等基于切比雪夫多项式和Lanczos迭代。
多项式滤波器的频率响应(frequency response):
多项式滤波器实现:
切比雪夫多项式滤波器实现:
GCN滤波器实现:
多项式滤波器存在的问题:建模能力有限,由于其平滑性,不能模拟频率响应的急剧变化。至关重要的是,高阶多项式对于达到高阶邻域是必要的,但它们的计算成本往往更高,过拟合训练数据,使模型对图信号或底层图结构的变化敏感。(总结:低阶多项式不能模拟频率响应的急剧变化,导致过平滑问题;高阶多项式导致过拟合问题。)
③ ARMA(自回归移动平均滤波器):与具有相同参数数量的多项式滤波器相比,ARMA滤波器提供了更多的频率响应,可以解释高阶邻域。
参考资料:
- 《Design of graph filters and filterbanks》Tremblay 2018
- 《Signal processing techniques for interpolation in graph structured data》Narang 2013
ARMA滤波器的频率响应:
ARMA滤波器的实现:
为了避免求逆,ARMA滤波器的近似实现(迭代方法):
1阶ARMA滤波器的近似实现:
1阶ARMA滤波器的频率响应:
k阶ARMA滤波器的近似实现:
3、ARMA层及滤波器实现
1阶ARMA滤波器的实现(本文最核心的公式,而且比较好理解):
定理1证明了公式14可收敛的条件(可以先忽略proof部分):
在实际实现ARMA算法过程中存在一些问题和挑战:
- 每个GCS堆栈k可能需要不同的、可能是大量的迭代次数 T k T_k Tk来收敛。这使得神经网络的实现变得麻烦,因为计算图是动态的,在训练过程中每次权重矩阵通过梯度下降更新都会变化。
- 为了训练参数的反向传播,如果 T k T_k Tk较大,神经网络必须多次展开,引入较高的计算代价和消失梯度问题。
解决方法:
- 参数 W k W_k Wk和 V k V_k Vk随机权重初始化。
- 固定收敛次数为常数 T T T。
k阶ARMA滤波器的实现:
ARMA卷积层算法实现:
4、实验部分
节点分类任务的数据集和超参数:
节点分类任务实验结果:
5、核心卷积层源码
ARMA4NC类:
class ARMA4NC(nn.Module):
def __init__(self,
in_dim,
hid_dim,
out_dim,
num_stacks,
num_layers,
activation=None,
dropout=0.0):
super(ARMA4NC, self).__init__()
self.conv1 = ARMAConv(in_dim=in_dim,
out_dim=hid_dim,
num_stacks=num_stacks,
num_layers=num_layers,
activation=activation,
dropout=dropout)
self.conv2 = ARMAConv(in_dim=hid_dim,
out_dim=out_dim,
num_stacks=num_stacks,
num_layers=num_layers,
activation=activation,
dropout=dropout)
self.dropout = nn.Dropout(p=dropout)
def forward(self, g, feats):
feats = F.relu(self.conv1(g, feats))
feats = self.dropout(feats)
feats = self.conv2(g, feats)
return feats
ARMA卷积层:
class ARMAConv(nn.Module):
def __init__(self,
in_dim,
out_dim,
num_stacks,
num_layers,
activation=None,
dropout=0.0,
bias=True):
super(ARMAConv, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.K = num_stacks
self.T = num_layers
self.activation = activation
self.dropout = nn.Dropout(p=dropout)
# init weight
self.w_0 = nn.ModuleDict({
str(k): nn.Linear(in_dim, out_dim, bias=False) for k in range(self.K)
})
# deeper weight
self.w = nn.ModuleDict({
str(k): nn.Linear(out_dim, out_dim, bias=False) for k in range(self.K)
})
# v
self.v = nn.ModuleDict({
str(k): nn.Linear(in_dim, out_dim, bias=False) for k in range(self.K)
})
# bias
if bias:
self.bias = nn.Parameter(torch.Tensor(self.K, self.T, 1, self.out_dim))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
for k in range(self.K):
glorot(self.w_0[str(k)].weight)
glorot(self.w[str(k)].weight)
glorot(self.v[str(k)].weight)
zeros(self.bias)
def forward(self, g, feats):
with g.local_scope():
init_feats = feats
# assume that the graphs are undirected and graph.in_degrees() is the same as graph.out_degrees()
degs = g.in_degrees().float().clamp(min=1)
norm = torch.pow(degs, -0.5).to(feats.device).unsqueeze(1)
output = None
for k in range(self.K):
feats = init_feats
for t in range(self.T):
feats = feats * norm
g.ndata['h'] = feats
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
feats = g.ndata.pop('h')
feats = feats * norm
if t == 0:
feats = self.w_0[str(k)](feats)
else:
feats = self.w[str(k)](feats)
feats += self.dropout(self.v[str(k)](init_feats))
feats += self.v[str(k)](self.dropout(init_feats))
if self.bias is not None:
feats += self.bias[k][t]
if self.activation is not None:
feats = self.activation(feats)
if output is None:
output = feats
else:
output += feats
return output / self.K
6、心得&代码复现
从本文中可以学到很多东西,收获很大;本文的方法也比较有意思,值得认真研读,虽然理论部分读起来非常硬核,尤其是算法implementation部分。
实验部分我自己复现的结果:
Cora | Citeseer | Pubmed | Film | Squirrel | Chameleon | Texas | Cornell | Wisconsin | |
---|---|---|---|---|---|---|---|---|---|
GCN | 81.90% | 71.80% | 79.10% | 23.36% | 34.49% | 51.97% | 54.05% | 54.05% | 64.71% |
H2GCN | 74.20% | 59.80% | 76.60% | 25.72% | 38.04% | 52.63% | 72.97% | 48.65% | 74.51% |
ARMA(dgl版) | 80.80% | 71.20% | 78.50% | 33.30% | 31.50% | 53.40% | 70.27% | 59.46% | 79.80% |
原论文中没有对异配图进行实验,在这里补充了6个异配图的实验结果,ARMA算法的效果提升主要在异配图上。
7、参考资料
- kimi:https://kimi.moonshot.cn/