图卷积网络GCN与图注意力网络GAT原理详解
文章目录
- 1. Why 图卷积网络GCN?
- 2. GCN的原理
- 2.1 GCN的输入
- 2.2 GCN的核心公式
- 2.3 GCN 的核心公式推导的直观理解
- 3. Why 图注意力网络GAT?
- 4. GAT的原理
- 4.1 GAT的输入
- 4.2 GAT的流程及核心公式
- References
1. Why 图卷积网络GCN?
GCN(Graph Convolutional Networks) 首次提出于ICLR2017【https://openreview.net/forum?id=SJU4ayYgl】
简单来说,在图神经网络出现之前,一般的神经网络只能对常规的欧式数据进行处理,其特点就是节点有固定的排列规则和顺序,如二维网格和一维序列。
在GCN之前,CNN、RNN等神经网络模型在CV、NLP等领域都取得了优异的效果,但是其都无法很好地解决图结构的数据问题。
【e.g. RNN 一维序列数据】 当对象是自然语言这样的序列信息时,是一个一维的结构,此时RNN系列被提出,通过各种门的操作,使得序列前后的信息相互影响,从而很好地捕捉序列的特征。关于RNN系列的介绍,可参考从RNN讲起(RNN、LSTM、GRU、BiGRU)——序列数据处理网络。
【e.g. CNN 二维图像数据】 对于图像数据,是一个二维的结构,于是有了CNN模型来提取图片的特征。CNN的核心在于它的kernel(卷积核),简单来说,kernel是一个个小窗口,在图片上平移,通过卷积的方式来提取特征。这里的关键在于图片结构的平移不变形:是指系统在输入图像平移后,输出结果保持不变。这种特性使得CNN在处理图像时,无论目标对象在图像中的位置如何变化,都能保持一致的识别效果。因此CNN可以实现参数共享,这就是CNN的精髓所在。
上面所提到的图片或者自然语言,都是属于欧式空间的数据,因此才有维度的概念,欧式空间数据的特点就是数据结构很规则。然而,在现实生活中,其实有很多不规则的数据结构,典型的就是图结构(也称为拓扑结构)。
图结构一般是不规则的,可以认为是无限维的一种数据,所以它没有平移不变形。每个节点周围的结构可能都是独一无二的,因此,像CNN、RNN这种模型就不能很好地发挥作用了。因此,涌现出了很多处理图结构的方法,例如GNN、GAT等。
GCN,图卷积神经网络,实际上跟CNN的作用一样,就是一个特征提取器,只不过它的对象是图数据。 GCN精妙地设计了一种从图数据中提取特征的方法,从而让我们可以使用这些特征去对图数据进行节点分类(node classification)、图分类(graph classification)、边预测(link prediction),还可以顺便得到图的嵌入表示(graph embedding) 等等。
2. GCN的原理
2.1 GCN的输入
对于一个图结构数据
G
G
G,其中有
N
N
N 个节点,每个节点都有自己的一个特征表示。
假设这些节点的特征组成了一个特征矩阵
X
∈
R
N
×
D
X \in \mathbb R^{N \times D}
X∈RN×D,其中
D
D
D 为特征维数。
此外每个节点之间的关系也可以用一个关系矩阵
A
∈
R
N
×
N
A \in \mathbb R^{N \times N}
A∈RN×N 表示,这个矩阵也被称为邻接矩阵(adjaceney matrix)。
GCN模型输入就是:节点特征 X X X 和 邻接矩阵 A A A (表示图结构,即节点之间的关系).
2.2 GCN的核心公式
首先,先给出GCN的核心公式:
H
(
l
+
1
)
=
σ
(
D
~
−
1
2
A
~
D
~
−
1
2
H
(
l
)
W
(
l
)
)
H^{(l+1)}=\sigma(\tilde D^{-\frac{1}{2}} \tilde A \tilde D^{-\frac{1}{2}} H^{(l)} W^{(l)})
H(l+1)=σ(D~−21A~D~−21H(l)W(l))
其中,
-
A
~
=
A
+
I
N
\tilde A = A +I_N
A~=A+IN,
I
N
I_N
IN是一个
N
N
N维单位阵。
- 之所以要加上一个单位阵,是因为邻接矩阵 A A A 的对角线上都是0,在更新节点特征表示时会忽略节点自身的特征。
- D ~ \tilde D D~ 是 A ~ \tilde A A~ 的度矩阵, D ~ = ∑ j A ~ i j \tilde D = \sum_j \tilde A_{ij} D~=∑jA~ij。
- W ( l ) W^{(l)} W(l) 是第 l l l 层的权重矩阵。
- H ( l ) ∈ R N × D H^{(l)} \in \mathbb R^{N \times D} H(l)∈RN×D 为第 l l l 层的特征矩阵, H ( 0 ) = X H^{(0)}=X H(0)=X, H ( L ) = Z H^{(L)}=Z H(L)=Z。
- σ \sigma σ 表示非线性激活函数,例如ReLU等。
上图中的GCN输入一个图,通过若干层GCN每个node的特征从
X
X
X变成了
Z
Z
Z,但是,无论中间有多少层,node之间的连接关系,即
A
A
A,都是共享的。
2.3 GCN 的核心公式推导的直观理解
首先,我们给出一个简单的图以及其对应的邻接矩阵 A A A、节点特征 X X X,这里我们用一个值表示一个节点特征。
那么GCN所做的事情就是:融合邻居节点来更新某一个节点。
以2号节点为例,那么最直接的融合方式就是求和求平均。
【融合邻居节点信息更新2号节点的值:step1 求和】
将2号节点的邻居求和: A g g r e g a t e ( X 2 ) = X 0 + X 1 + X 3 + X 4 Aggregate(X_2) = X_0 + X_1 + X_3 + X_4 Aggregate(X2)=X0+X1+X3+X4
写成所有节点特征与系数相乘的形式: A g g r e g a t e ( X 2 ) = 1 ⋅ X 0 + 1 ⋅ X 1 + 0 ⋅ X 2 + 1 ⋅ X 3 + 1 ⋅ X 4 + 0 ⋅ X 5 Aggregate(X_2) = 1\cdot X_0 + 1\cdot X_1 +0\cdot X_2 + 1\cdot X_3 +1\cdot X_4 + 0\cdot X_5 Aggregate(X2)=1⋅X0+1⋅X1+0⋅X2+1⋅X3+1⋅X4+0⋅X5
可以发现,这个特征前面的系数,就是邻接矩阵的2号节点对应的那一行数据 A 2 = [ 1 , 1 , 0 , 1 , 1 , 0 ] A_2 = [1,1,0,1,1,0] A2=[1,1,0,1,1,0]
因此,融合表示为邻接矩阵对应元素与特征矩阵对应元素相乘:
A
g
g
r
e
g
a
t
e
(
X
2
)
=
1
⋅
X
0
+
1
⋅
X
1
+
0
⋅
X
2
+
1
⋅
X
3
+
1
⋅
X
4
+
0
⋅
X
5
=
A
2
,
0
X
0
+
A
2
,
1
X
1
+
A
2
,
2
X
2
+
A
2
,
3
X
3
+
A
2
,
4
X
4
+
A
2
,
5
X
5
=
∑
j
=
0
N
A
2
j
X
j
=
A
2
X
\begin{aligned} Aggregate(X_2) &= 1\cdot X_0 + 1\cdot X_1 +0\cdot X_2 + 1\cdot X_3 +1\cdot X_4 + 0\cdot X_5 \\ &= A_{2,0} X_0 + A_{2,1} X_1 + A_{2,2} X_2 + A_{2,3} X_3 + A_{2,4} X_4 + A_{2,5} X_5 \\ &= \sum_{j=0}^N A_{2j}X_j \\ &= A_2 X \end{aligned}
Aggregate(X2)=1⋅X0+1⋅X1+0⋅X2+1⋅X3+1⋅X4+0⋅X5=A2,0X0+A2,1X1+A2,2X2+A2,3X3+A2,4X4+A2,5X5=j=0∑NA2jXj=A2X
对于所有的节点,都做聚合操作的话:
存在问题:没有考虑自身节点特征
在上面这个步骤,我们聚合2号节点的邻居节点信息的时候 ( A g g r e g a t e ( X 2 ) = X 0 + X 1 + X 3 + X 4 Aggregate(X_2) = X_0 + X_1 + X_3 + X_4 Aggregate(X2)=X0+X1+X3+X4),并没有考虑自己本身的节点信息,这显然是不合理的,应该为 A g g r e g a t e ( X 2 ) = X 0 + X 1 + X 2 + X 3 + X 4 Aggregate(X_2) = X_0 + X_1 + X_2 + X_3 + X_4 Aggregate(X2)=X0+X1+X2+X3+X4。
因此,在原有邻接矩阵的基础上,加上了一个单位阵 I I I,记为 A ~ = A + I \tilde A = A+I A~=A+I
那么融合结果表示为 A g g r e g a t e ( X 2 ) = X 0 + X 1 + X 2 + X 3 + X 4 = A ~ 2 X Aggregate(X_2) = X_0 + X_1 + X_2 + X_3 + X_4 = \tilde A_2 X Aggregate(X2)=X0+X1+X2+X3+X4=A~2X
同理,对所有节点表示进行聚合表示为
A
g
g
r
e
g
a
t
e
(
X
)
=
A
~
X
Aggregate(X) = \tilde A X
Aggregate(X)=A~X
【融合邻居节点信息更新2号节点的值:step2 求均值】
还是以2号节点为例,其求和表示为 A g g r e g a t e ( X 2 ) = 1 ⋅ X 0 + 1 ⋅ X 1 + 1 ⋅ X 2 + 1 ⋅ X 3 + 1 ⋅ X 4 + 0 ⋅ X 5 = A ~ 2 , 0 X 0 + A ~ 2 , 1 X 1 + A ~ 2 , 2 X 2 + A ~ 2 , 3 X 3 + A ~ 2 , 4 X 4 + A ~ 2 , 5 X 5 = ∑ j = 0 N A ~ 2 j X j \begin{aligned} Aggregate(X_2) &= 1\cdot X_0 + 1\cdot X_1 +1\cdot X_2 + 1\cdot X_3 +1\cdot X_4 + 0\cdot X_5 \\ &= \tilde A_{2,0} X_0 + \tilde A_{2,1} X_1 + \tilde A_{2,2} X_2 + \tilde A_{2,3} X_3 + \tilde A_{2,4} X_4 + \tilde A_{2,5} X_5 \\ &= \sum_{j=0}^N \tilde A_{2j}X_j \end{aligned} Aggregate(X2)=1⋅X0+1⋅X1+1⋅X2+1⋅X3+1⋅X4+0⋅X5=A~2,0X0+A~2,1X1+A~2,2X2+A~2,3X3+A~2,4X4+A~2,5X5=j=0∑NA~2jXj
那么自然的求均值首先想到的就是用和除以节点数量。对于2号节点,与其相临的节点有四个(
X
0
X_0
X0,
X
1
X_1
X1,
X
3
X_3
X3,
X
4
X_4
X4),再加上自身,共有5个节点.。这个数量其实就是每个节点的邻居个数(包括自身),可以通过计算邻接矩阵对应的行中1的个数获得。
因此,这里引入了度矩阵
D
D
D,度矩阵是一个对角阵,对角线上的元素为各个顶点的度,定义为
D
~
=
∑
j
A
~
i
j
\tilde D = \sum_j \tilde A_{ij}
D~=∑jA~ij。
所以,融合邻居节点更新2号节点的值: A g g r e g a t e ( X 2 ) = ∑ j = 0 N A ~ 2 j X j 5 = ∑ j = 0 N A ~ 2 j X j D ~ 2 , 2 \begin{aligned} Aggregate(X_2) &= \sum_{j=0}^N \frac{\tilde A_{2j} X_j}{5} \\ &=\sum_{j=0}^N \frac{\tilde A_{2j} X_j}{\tilde D_{2,2}} \end{aligned} Aggregate(X2)=j=0∑N5A~2jXj=j=0∑ND~2,2A~2jXj
存在问题:平均时未考虑被聚合节点的信息
以更新3号节点为例,融合邻居节点更新3号节点的值: A g g r e g a t e ( X 3 ) = ∑ j = 0 N A ~ 3 j X j D ~ 3 , 3 = A ~ 3 , 2 X 2 + A ~ 3 , 3 X 3 + A ~ 3 , 5 X 5 D ~ 3 , 3 = A ~ 3 , 2 X 2 D ~ 3 , 3 + A ~ 3 , 3 X 3 D ~ 3 , 3 + A ~ 3 , 5 X 5 D ~ 3 , 3 = A ~ 3 , 2 X 2 D ~ 3 , 3 D ~ 3 , 3 + A ~ 3 , 3 X 3 D ~ 3 , 3 D ~ 3 , 3 + A ~ 3 , 5 X 5 D ~ 3 , 3 D ~ 3 , 3 \begin{aligned} Aggregate(X_3) &=\sum_{j=0}^N \frac{\tilde A_{3j} X_j}{\tilde D_{3,3}} \\ &=\frac{\tilde A_{3,2} X_2 + \tilde A_{3,3} X_3 + \tilde A_{3,5} X_5}{\tilde D_{3,3}} \\ &= \frac{\tilde A_{3,2} X_2}{\tilde D_{3,3}} + \frac{\tilde A_{3,3} X_3}{\tilde D_{3,3}}+ \frac{\tilde A_{3,5} X_5}{\tilde D_{3,3}} \\ &= \frac{\tilde A_{3,2} X_2}{\sqrt{\tilde D_{3,3}\tilde D_{3,3}}} + \frac{\tilde A_{3,3} X_3}{\sqrt{\tilde D_{3,3}\tilde D_{3,3}}}+ \frac{\tilde A_{3,5} X_5}{\sqrt{\tilde D_{3,3}\tilde D_{3,3}}} \end{aligned} Aggregate(X3)=j=0∑ND~3,3A~3jXj=D~3,3A~3,2X2+A~3,3X3+A~3,5X5=D~3,3A~3,2X2+D~3,3A~3,3X3+D~3,3A~3,5X5=D~3,3D~3,3A~3,2X2+D~3,3D~3,3A~3,3X3+D~3,3D~3,3A~3,5X5
通过公式我们可以看出,在分子上考虑了自身节点和被融合节点,而分母只考虑了3号节点本身。
因此,分母加入对被融合节点的考虑,优化更新操作:
A
g
g
r
e
g
a
t
e
(
X
3
)
=
A
~
3
,
2
X
2
D
~
3
,
3
D
~
2
,
2
+
A
~
3
,
3
X
3
D
~
3
,
3
D
~
3
,
3
+
A
~
3
,
5
X
5
D
~
3
,
3
D
~
5
,
5
=
∑
j
=
0
N
A
~
3
j
X
j
D
~
3
,
3
D
~
j
,
j
\begin{aligned} Aggregate(X_3) &=\frac{\tilde A_{3,2} X_2}{\sqrt{\tilde D_{3,3}\tilde D_{2,2}}} + \frac{\tilde A_{3,3} X_3}{\sqrt{\tilde D_{3,3}\tilde D_{3,3}}}+ \frac{\tilde A_{3,5} X_5}{\sqrt{\tilde D_{3,3}\tilde D_{5,5}}} \\ &=\sum_{j=0}^N \frac{\tilde A_{3j}X_j}{\sqrt{\tilde D_{3,3}\tilde D_{j,j}}} \end{aligned}
Aggregate(X3)=D~3,3D~2,2A~3,2X2+D~3,3D~3,3A~3,3X3+D~3,3D~5,5A~3,5X5=j=0∑ND~3,3D~j,jA~3jXj
推广到融合邻居节点的值来更新
i
i
i 号节点的值:
A
g
g
r
e
g
a
t
e
(
X
i
)
=
∑
j
=
0
N
A
~
i
j
X
j
D
~
i
i
D
~
j
j
=
∑
j
=
0
N
D
~
i
i
−
1
2
A
~
i
j
D
~
j
j
−
1
2
X
j
\begin{aligned} Aggregate(X_i) &= \sum_{j=0}^N \frac{\tilde A_{ij}X_j}{\sqrt{\tilde D_{ii}\tilde D_{jj}}}\\ &=\sum_{j=0}^N \tilde D_{ii}^{-\frac{1}{2}} \tilde A_{ij} \tilde D_{jj}^{-\frac{1}{2}} X_j \end{aligned}
Aggregate(Xi)=j=0∑ND~iiD~jjA~ijXj=j=0∑ND~ii−21A~ijD~jj−21Xj
矩阵计算:
A
g
g
r
e
g
a
t
e
(
X
)
=
∑
i
=
0
N
A
g
g
r
e
g
a
t
e
(
X
i
)
=
∑
i
=
0
N
∑
j
=
0
N
D
~
i
i
−
1
2
A
~
i
j
D
~
j
j
−
1
2
X
j
=
∑
i
=
0
N
D
~
i
i
−
1
2
A
~
i
D
~
−
1
2
X
=
D
~
−
1
2
A
~
D
~
−
1
2
X
\begin{aligned} Aggregate(X) &= \sum_{i=0}^N Aggregate(X_i)\\ &=\sum_{i=0}^N \sum_{j=0}^N \tilde D_{ii}^{-\frac{1}{2}} \tilde A_{ij} \tilde D_{jj}^{-\frac{1}{2}} X_j \\ &=\sum_{i=0}^N \tilde D_{ii}^{-\frac{1}{2}} \tilde A_{i} \tilde D^{-\frac{1}{2}} X \\ &=\tilde D^{-\frac{1}{2}} \tilde A \tilde D^{-\frac{1}{2}} X \end{aligned}
Aggregate(X)=i=0∑NAggregate(Xi)=i=0∑Nj=0∑ND~ii−21A~ijD~jj−21Xj=i=0∑ND~ii−21A~iD~−21X=D~−21A~D~−21X
至此,我们就可以得出GCN的核心公式啦。
3. Why 图注意力网络GAT?
由于之前的GCN网络平等的对待每一个节点,于是就有人想到把attention引入,来给每个邻居节点分配不同的权重,从而能够识别出更加重要的邻居,于是就有了GAT。
GAT的核心思想是在每个节点上计算注意力系数,以确定节点与其邻居节点之间的重要性。这种注意力机制使得模型能够对不同节点之间的关系赋予不同的权重,从而更好地捕捉图数据中的局部结构和全局信息。
4. GAT的原理
GAT网络由堆叠简单的图注意力层(Graph Attention Layer)来实现。
4.1 GAT的输入
GAT的输入是每个节点的向量表示 h = { h 1 , h 2 , … , h N } , h i ∈ R F \mathbf{h} = \{\mathbf{h}_1,\mathbf{h}_2,\ldots,\mathbf{h}_N\},\mathbf{h}_i \in \mathbb{R}^F h={h1,h2,…,hN},hi∈RF,其中 N N N 是节点的数量。
4.2 GAT的流程及核心公式
【Step1 特征维度变换】
为了获得足够的表达能力,利用(至少一个可学习的)线性变换将原始特征映射为一个 F ′ F' F′ 维的特征向量 h ′ = { h 1 ′ , h 2 ′ , … , h N ′ } , h i ′ ∈ R F ′ \mathbf{h'} = \{\mathbf{h}'_1,\mathbf{h}'_2,\ldots,\mathbf{h}'_N\}, \mathbf{h}'_i \in \mathbb{R}^{F'} h′={h1′,h2′,…,hN′},hi′∈RF′.
h i ′ = W h i \mathbf{h}'_i=\mathbf{W} \mathbf{h}_i hi′=Whi
其中,权重矩阵
W
∈
R
F
′
×
F
\mathbf{W} \in \mathbb{R}^{F' \times F}
W∈RF′×F 是所有节点共享的。
【Step2 计算注意力系数并归一化处理】
在节点上执行自注意力(共享注意机制)计算注意力系数 e i j e_{ij} eij。
所谓注意力机制就是一种映射 a : R F ′ × R F ′ → R a : \mathbb R^{F'} \times \mathbb R^{F'} \to \mathbb R a:RF′×RF′→R,那么注意力系数:
e i j = a ( W h i , W h j ) = LeakyReLU ( a T [ W h i ∥ W h j ] ) e_{ij} = a(\mathbf{W} \mathbf{h}_i, \mathbf{W} \mathbf{h}_j)=\text{LeakyReLU}(\mathbf{a}^T[\mathbf{W} \mathbf{h}_i \| \mathbf{W} \mathbf{h}_j]) eij=a(Whi,Whj)=LeakyReLU(aT[Whi∥Whj])
表示节点 j j j 的特征对节点 i i i 的重要程度。
为了不同系数之间便于比较,对系数进行 归一化(normalization) 处理:
α i j = softmax j ( e i j ) = exp ( e i j ) ∑ k ∈ N i exp ( e i k ) \alpha_{ij} = \text{softmax}_{j}(e_{ij}) = \frac{\exp (e_{ij})}{\sum_{k \in \mathcal{N}_i} \exp (e_{ik})} αij=softmaxj(eij)=∑k∈Niexp(eik)exp(eij)
【Step 3 计算节点的向量表示输出】
得到归一化的权重系数后,按照注意力机制加权求和的思路,最终节点 i i i 新的特征向量为: h i ′ = σ ( ∑ j ∈ N i α i j W h j ) \mathbf{h}'_i = \sigma(\sum_{j \in \mathcal{N}_i} \alpha_{ij} \mathbf{W} \mathbf{h}_j) hi′=σ(j∈Ni∑αijWhj)
为了稳定自注意力的学习过程,引入了多头自注意力。用
K
K
K 个独立注意机制执行上述变换,然后将它们的特征拼接起来(或计算平均值),从而得到如下两种特征表示:
h
i
′
=
∥
k
=
1
K
σ
(
∑
j
∈
N
i
α
i
j
k
W
k
h
j
)
h
i
′
=
σ
(
1
K
∑
k
=
1
K
∑
j
∈
N
i
α
i
j
k
W
k
h
j
)
\mathbf{h}'_i = \overset{K}{\underset{k=1}{\parallel}} \sigma(\sum_{j \in \mathcal{N}_i} \alpha_{ij}^k \mathbf{W}^k \mathbf{h}_j) \\ \mathbf{h}'_i = \sigma(\frac{1}{K} \sum_{k=1}^K\sum_{j \in \mathcal{N}_i} \alpha_{ij}^k \mathbf{W}^k \mathbf{h}_j)
hi′=k=1∥Kσ(j∈Ni∑αijkWkhj)hi′=σ(K1k=1∑Kj∈Ni∑αijkWkhj)
其中 ∥ \parallel ∥ 表示向量拼接操作。
References
- 图神经网络系列讲解及代码实现
- 何时能懂你的心——图卷积神经网络(GCN)
- 图卷积网络(Graph Convolution Network,GCN)
- 图神经网络GAT(Graph Attention Network)理解