当前位置: 首页 > article >正文

GIN:逼近WL-test的GNN架构

Introduction

在 图卷积网络GCN 中我们已经知道图神经网络在结点分类等任务上的作用,但GIN(图同构神经网络)给出了一个对于图嵌入(graph embedding)更强的公式。

GIN,图同构神经网络,致力于解决使用GNN对图进行分类的问题,我们知道采用GCN等GNN可以得到每个结点的嵌入向量表示(node embedding), 如果我们将一张图的所有节点的向量表示拼接起来,便可以得到整张图的向量表示(graph embedding)

本篇博文将关注利用GNN进行图级别表示的学习。图表征学习要求根据节点属性、边和边的属性(如果有的话)生成一个向量作为图的表征,基于图表征可以做图的预测。基于**图同构网络(Graph Isomorphism Network, GIN)**的图表征网络是当前最经典的图表征学习网络。

本文将以GIN为例,首先将介绍图同构的相关概念,然后介绍图同构测试的经典算法——Weisfeiler-Lehamn算法,接着解释为什么说GNN是WL-test的变体,并分析基于消息传递网络架构设计的GNN模型如何学习图表征,而现有的GNN模型如GCN等为什么达不到WL-test,最后介绍How Powerful are Graph Neural Networks?一文中作者提出的GIN架构,并提供相应模块的代码。

Weisfeiler-Lehamn

2.1 Graph Isomorphism

简单来说,如果图 G1 和 G2 的顶点和边数量相同,且边的连接性相同,则可以称这两个图为同构的。也可以认为, G1 的点是由 G2 中的点映射得到。
在这里插入图片描述
上面两张图是同构的,映射关系为:A-3. B-1. C-2. D-5. E-4.

图同构测试的意义:比如在蛋白质结构、基因网络中,具有相似结构(同构测试或相似度计算)的蛋白质或基因结构可能具有相似的功能特性。又比如两位作者相似的期刊引文网络结构可能表示两位作者的研究内容相似等等。因此度量图的相似度是图学习的一个核心问题。

图同构问题通常被认为是 NP 问题,目前最有效的算法是 Weisfeiler-Lehman 算法,可以在准多项式时间内进行求解。但是Weisfeiler-Lehman 测试(WL-test)是图同构的一个必要但不充分的条件。也就是说,两个图的WL-test结果显示有差异,可认为这两个图是非同构的;但如WL-test结果显示没有差异,只能表述为这两个图可能同构

2.2 1-dimensional WL-test

WL 算法可以是 K-维的,K-维 WL 算法在计算图同构问题时会考虑顶点的 k 元组,如果只考虑顶点的自身特征(如标签、颜色等),那么就是 1-维 WL 算法。

介绍下 1-维 WL 算法,首先给出一个定义:

「多重集」(Multiset):一组可能重复的元素集合。例如:{1,1,2,3}就是一个多重集合。在图神经网络中表示节点邻居的特征向量集
假设节点 v 及其邻居集合 N(v),假设节点 v 的标签是 1 ,其 N(v) 对应的标签是 1、1、2、3、4,可以把邻居集合看成一个 Multiset 。
在这里插入图片描述
WL-test包括四个步骤:聚合邻居节点标签;多重集排序;标签压缩;更新标签。

下图a首先给出了两个标签标记(如果没有节点标签,可以用节点的度作为标签)的图 。(这个例子中两个图显然是不同构的,因为图1中有两个标签为1的节点而图2中只有一个标签为1的节点)
①聚合邻居节点标签(图b)
在第一步中,考虑各个节点的1阶邻居节点,构建集合(如图
中的节点5,其周围节点为4、3、2,则经过步骤1生成的多重集为(5,4-3-2))

②多重集排序(图b)
在各个节点形成的集合内除自身节点按标签升序排序(如经步骤1的图
中节点5,经过步骤2生成的多重集为(5,2-3-4)),排序的原因在于要保证结果不因邻居节点顺序改变而改变,即保证单射性。

③标签压缩/标签散列(图c)
对标签进行压缩映射,将较长的字符串映射到一个简短的标签(如利用hash编码或者继续编号等方式);

④更新标签(图d)
将压缩映射后的新标签更新到各个节点上。

每重复一次以上的过程,就完成一次节点自身标签与邻接节点标签的聚合。不断重复直到每个节点的label稳定不变。稳定后统计各个标签的分布,如果两个图相同标签的出现次数不一致,即可判断两个图不同构

⑤迭代 1 轮后,利用计数函数分别得到两张图的计数特征,如:分别统计图中各类标签出现的次数,存于一个向量,这个向量可以作为图的表征。两个图的这样的向量的内积,即可作为这两个图的相似性的估计。(图e)

直观上来看,WL-test 第 k 次迭代时节点的标号表示的是结点高度为 k 的子树结构:
在这里插入图片描述
以节点 1 为例,右图是节点 1 迭代两次的子树。因此 WL-test 所考虑的图特征本质上是图中以不同节点为根的子树的计数。

上述WL-test的步骤公式表示为在这里插入图片描述
可以看到这个过程和GCN、GraphSAGE等GNN网络的公式是很相似的。都是分为两步:聚合和结合

基于邻域聚合的GNN可以拆分为以下三个模块:

  • Aggregate:聚合一阶邻域特征。
  • Combine:将邻居聚合的特征与当前节点特征合并,以更新当前节点特征。
  • Readout(可选,针对图分类):如果是对graph分类,需要将graph中所有节点特征转变成graph特征。READOUT表示一个置换不变性函数(permutation invariant function),也可以是一个图级pooling函数

目前的GNNs都遵循一个邻居聚合的策略,也就是通过聚合邻居的表示然后迭代地更新自己的表示。在k次迭代聚合后就可以捕获到在k-hop邻居内的结构信息。一个k层的GNNs可以表示为:

在这里插入图片描述
因此,可以看出AGGREGATE聚合函数和COMBINE连接函数是非常重要的。

对于 GraphSAGE 来说:

  • AGGREGATE 被定义为:
    在这里插入图片描述
  • COMBINE 为 :
    在这里插入图片描述

而在 GCN 中,AGGREGATE 和 COMBINE 集成为:
在这里插入图片描述
对于节点分类任务,节点表示 h v ( K ) h^{(K)}_v hv(K) 将作为预测的输入;对于图分类任务,READOUT 函数聚合了最后一次迭代输出的节点表示 h v ( K ) h^{(K)}_v hv(K),并生成图表示 h G h_G hG:在这里插入图片描述
其中:READOUT 函数是具有排列不变性的函数,如:summation。

GNN 是通过递归更新每个节点的特征向量,从而捕获其周围网络结构和节点的特征,而 WL-test 也类似,那么 GNN 和 WL-test 之间又有什么关系呢?

2.3 Upper limit of GNN performance

现有的GNN模型达不到WL-test的原因(eg.GCN、GraphSAGE)
为什么诸多 GNN 变体没有 GIN 能力强呢?主要是因为GNN 变体的聚合和池化过程不满足单射。

单射(injective)是指每一个不同的输入只对应一个不同的输出。如果两个经过单射函数的输出相等,则它们的输入必定相等

WL测试是GNN判断图同构的上界,直观来讲,WL测试每次将一个多重集(一个结点的邻居集合)用一个哈希函数做了一个单射,而GNN聚合邻居信息的时候,虽然使用了各种各样的函数,但显然并不是单射,比如常用的Max(x), Sum(x), Mean(x)这些函数都不是单射,那么当然不可能有WL测试强了。

在图神经网络工作时,会对图中的节点特征按照边的结构进行聚合。如果将节点邻居的特征向量集可以看作是一个MultiSet,则整个图神经网络可以理解为是MultiSet的聚合函数。好的图神经网络应当具有单射函数的特性。即图神经网络必须能够将不同的MultiSet聚合到不同的表示中

聚合和更新操作的选择至关重要:只有多集内射函数才能使其等同于 Weisfeiler-Lehman 算法。一些文献中常用的聚合器选择,例如,最大值或均值,实际上严格来说没有 Weisfeiler-Lehman 强大,并且无法区分非常简单的图结构

基于mean和max的池化函数存在一些问题,以下图为例:
在这里插入图片描述
设节点以颜色为标签,即 r g b。

对于图 a 来说,max 和 mean 池化操作无法区分其结构,聚合邻居节点后最终标签都为 b,而对于 sum 操作来说 可以得到 2b 和 3b;
对于图 b 来说,max 操作无法区分,而 mean 和 sum 可以区分;
对于图 c 来说,max 和 mean 都无法区分,而 sum 依然可以区分。

三种不同的 Aggregate:

  • sum:学习全部的标签以及数量,可以学习精确的结构信息(不仅保存了分布信息,还保存了类别信息);[ 蓝色:4个;红色:2 个 ]
  • mean:学习标签的比例(比如两个图标签比例相同,但是节点有倍数关系),偏向学习分布信息;[ 蓝色:4/6=2/3 的比例;红色:2/6 = 1/3 的比例 ]
  • max:学习最大标签,忽略多样,偏向学习有代表性的元素信息;[ 两类(类内相同),所以各一个 ]

由此可见,max 和 min 并不满足单射性,故其性能会比 sum 差一点。本文主要 基于对 graph分类,证明了 sum 比 mean 、max 效果好,但是不能说明在node 分类上也是这样的效果,另外可能优先场景会更关注邻域特征分布, 或者代表性, 故需要都加入进来实验。

Graph Isomorphism Network

WL-test 是 GNN 的性能上界,那么是否存在一个与 WL-test 性能相当的 GNN 呢?答案是肯定的。

一个强大的GNN不会将两个不同的邻域映射到相同的表示,这意味着它的聚合模式必须是单射的(injective)。

GNN 和 WL-test 的主要区别在于「单射函数」中。作者提出 Theorem 3:**如果 GNN 中 Aggregate、Combine 和 Readout 函数是单射,GNN 可以和 WL test 一样强大。**顺利成章的,作者设计一个满足单射函数的图同构网络(Graph Isomorphism Network,以下简称 GIN)。

文中将一个GNN的聚合方案抽象为一类神经网络可以表示的multisets上函数。
在这里插入图片描述

作者证明了当X为可数时,将aggregate设置为sum, combine 设置为1 + ϵ时,会存在f(x),使h(c,X)为单射:
在这里插入图片描述
h(c, X) is unique for each pair (c, X)

其中,c为节点自身特征, X为邻域特征集。ϵ是一个可学习的无理数参数

进一步推出任意g(c,X)都可以分解成以下f ∘ φ形式,满足单射性
在这里插入图片描述
通过引入多层感知机MLP,去学习φ和f,保证单射性。最终得到基于MLP+SUM的GIN框架
在这里插入图片描述

  • MLP(大于等于2层)可以近似拟合任意函数,故可以学习到单射函数,而graphsage和gcn中使用的单层感知机不能满足。
  • 约束输入特征是one-hot,故第一次迭代sum后还是满足单射性,不需先做MLP的预处理
  • 根据定理4, 迭代一轮得到新特征 h v ( k ) h_v^{(k)} hv(k)是可数的、经过了 转换f(x)(隐),下一轮迭代还是满足单射性条件。

GIN的聚合部分是基于单射injective的,聚合函数使用的是求和函数,它特殊的一点是在中心节点加了一个自连边(自环),之后对自连边进行加权
这样做的好处是即使我们调换了中心节点和邻居节点,得到的聚合结果依旧是不同的。所以带权重的自连边能够保证中心节点和邻居节点可区分。
在这里插入图片描述
在这里插入图片描述

Graph-level Readout of GIN

通过GIN学习的节点embeddings可以用于类似于节点分类、连接预测这样的任务。对于图分类任务,文中提出了一个“readout”函数:READOUT 函数的作用是将图中节点的 Embedding 映射成整张图的 Embedding。

于是作者提出了基于 SUM+CONCAT 的 READOUT 函数,对每次迭代得到的所有节点的特征求和得到该轮迭代的图特征,然后再拼接起每一轮迭代的图特征来得到最终的图特征:
在这里插入图片描述
无论是图卷积模型还是图注意力模型,都是使用递归迭代的方式,对图中的节点特征按照边的结构进行聚合来进行计算的。图同构网络(GIN)模型在此基础之上,对图神经网络提出了一个更高的合理性要求——同构性。即对同构图处理后的图特征应该相同,对非同构图处理后的图特征应该不同

Experiments

Datasets

论文使用 9 个图形分类基准(benchmarks):

4个生物信息学数据集

  • MUTAG
  • PTC
  • NCI1
  • PROTEINS

5个社交网络数据集

  • OLLAB
  • IMDB-BINARY
  • IMDB-MULTI
  • REDDIT-BINARY
  • REDDIT-MULTI5K

重要的是,我目标不是让模型依赖输入节点特征,而是主要从网络结构中学习。
因此,在生物信息图中,节点具有分类输入特征,但在社交网络中,它们没有特征。 对于社交网络,按如下方式创建节点特征:对于 REDDIT 数据集,将所有节点特征向量设置为相同(因此,这里的特征是无信息的); 对于其他社交图,我们使用节点度数的 one-hot 编码。

Baselines

1.拟合能力——train效果

在这里插入图片描述
在训练中,GIN 和WL test 一样,可以拟合所有数据集,这表说了GIN表达能力达到了上限

2.关注泛化能力——test集效果

在这里插入图片描述

GIN-0 比GIN-eps 泛化能力强:可能是因为更简单的缘故
GIN 比 WL test 效果好:因为GIN进一步考虑了结构相似性,即WL test 最终是one-hot输出,而GIN是将WL test映射到低维的embedding
max在 无节点特征的图(用度来表示特征)基本无效

DGL库中GIN的实现

在DGL库中, GIN模型是通过GINConv类来实现的。该类将GIN模型中的全连接网络换成了参数调用形式,在使用时可以向该参数传入任意神经网络,具有更加灵活的扩展性。

具体代码在DGL安装库路径下的\nn\pytorch\conv\ginconv.py中。

class GINConv(nn.Module):
    def __init__(self,
                 apply_func,  # 自定义模型参数
                 aggregator_type,  # 聚合类型
                 init_eps=0,  # 可学习变量的初始值
                 learn_eps=False):  # 是否使用可学习变量
        super(GINConv, self).__init__()
        self.apply_func = apply_func
        self._aggregator_type = aggregator_type
        if aggregator_type == 'sum':
            self._reducer = fn.sum
        elif aggregator_type == 'max':
            self._reducer = fn.max
        elif aggregator_type == 'mean':
            self._reducer = fn.mean
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type))
        # to specify whether eps is trainable or not.
        if learn_eps:  # 是否使用可学习变量
            self.eps = th.nn.Parameter(th.FloatTensor([init_eps]))
        else:
            self.register_buffer('eps', th.FloatTensor([init_eps]))

    def forward(self, graph, feat, edge_weight=None):  # 正向传播
        with graph.local_scope():
            aggregate_fn = fn.copy_src('h', 'm')
            if edge_weight is not None:
                assert edge_weight.shape[0] == graph.number_of_edges()
                graph.edata['_edge_weight'] = edge_weight
                aggregate_fn = fn.u_mul_e('h', '_edge_weight', 'm')

            feat_src, feat_dst = expand_as_pair(feat, graph)
            graph.srcdata['h'] = feat_src
            graph.update_all(aggregate_fn, self._reducer('m', 'neigh'))  # 聚合邻居节点特征
            rst = (1 + self.eps) * feat_dst + graph.dstdata['neigh']  # 将自身特征混合
            if self.apply_func is not None:  # 使用神经网络进行单射
                rst = self.apply_func(rst)
            return rst


http://www.kler.cn/a/389795.html

相关文章:

  • 港湾周评|万科的多重压力
  • HBASE学习(一)
  • 用公网服务器实现内网穿透
  • 构建一个简单的深度学习模型
  • PyTorch 神经协同过滤 (NCF) 推荐系统教程
  • Golang Gin系列-3:Gin Framework的项目结构
  • 分布式数据库:深入探讨架构、挑战与未来趋势
  • 鸿蒙Flutter实战:13-鸿蒙应用打包上架流程
  • 随堂测微信小程序ssm+论文源码调试讲解
  • MongoDB 详解:深入理解与探索
  • IOS开发之MapKit定位国内不准的问题
  • LLaMA-Factory全流程训练模型
  • Flink输出算子
  • Tcp中的流量控制,拥塞控制,超时重传时间的选择,都附带相应例子说明
  • OBOO鸥柏:公司品牌部分户外广告一体机已布局文化传媒市场
  • Spring Boot集成Access DB实现数据导入和解析
  • Rust生成随机值实战应用
  • http的发展史
  • Spring Boot与工程认证:计算机课程管理的现代化
  • 【ET8框架进阶】HybridCLR打包丢失元方法问题MissingMethodException:生成LinkXml增加元方法
  • 车间管理|基于SprinBoot+vue工厂车间管理系统设计与实现(源码+数据库+文档)
  • Chromium 中chrome.contextMenus扩展接口定义c++
  • QT模态对话框和非模态对话框区别以及常用标准对话框
  • 【QT】Qt网络
  • 【CSS】什么是BFC?
  • 2.5_XXE(XML外部实体注入)