【模型学习之路】手写+分析GAT
从GNN,到GCN,再到GAT
目录
文章目录
前言
GNN
GCN
GAT
公式
注意力实现
公式对比
多头注意力实现
测试&可视化
前言
读本文前,可以先过一遍【GNN图神经网络】入门到实战完整40讲!同济大佬用大白话的方式从零到一讲解原理基础及代码复现,主打一个通俗易懂!_哔哩哔哩_bilibili的1-12集。
GNN
GNN(Graph Neural Network,图神经网络)是一种专门用于处理图结构数据的深度学习模型,它的核心思想是通过聚合邻居节点的信息来更新每个节点的表示。这种更新过程可以捕捉到节点的局部邻域结构,从而学习到节点、边甚至整个图的高级特征表示。
GCN(Graph Convolutional Network,图卷积网络)是GNN的一种,它通过图卷积操作来更新节点的特征表示。
GCN
稍微来总结一下视频中的内容。
(m, f)是第 层的节点特征矩阵,m是节点个数,f表示每个节点的特征数。
(f, f’) 是第 层的可学习权重矩阵,也就是我们要训练的参数。
(m, m)是无自环邻接矩阵(即邻接矩阵,但是对角线全为0),(m, m)是度矩阵。首先,为了消息的自我传播,我们给它们加上单位矩阵。
节点更新函数为:
其中我们可以令一个变量:
式中是激活函数。
左乘和右乘时为了分别对列和行做标准化。
代码很简单,就是公式的堆叠,不做赘述。
import torch
import torch.nn as nn
import torch.nn.functional as F
def normalized_adjacency(adj):
"""输入A, 返回A^ """
d = torch.diag(torch.sum(adj, dim=1))
a = adj + torch.eye(adj.shape[0])
d = d + torch.eye(adj.shape[0])
d_inv_sqrt = torch.pow(d, -0.5)
a_norm = d_inv_sqrt @ a @ d_inv_sqrt
return a_norm
class GraphConvolution(nn.Module):
def __init__(self, in_features, out_features):
super(GraphConvolution, self).__init__()
self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
self.reset_parameters()
def reset_parameters(self):
stdv = 1. / (self.weight.size(1) ** 0.5)
self.weight.data.uniform_(-stdv, stdv)
def forward(self, x, adj):
adj = normalized_adjacency(adj)
return adj @ (x @ self.weight)
class GCN(nn.Module):
def __init__(self, in_features, hidden_features, out_features,
n_layers, dropout=0.5):
super(GCN, self).__init__()
self.layers = nn.ModuleList()
# 输入层到第一隐藏层
self.layers.append(GraphConvolution(in_features, hidden_features))
# 隐藏层到其他隐藏层
for _ in range(n_layers - 2):
self.layers.append(GraphConvolution(hidden_features, hidden_features))
# 最后一个隐藏层到输出层
self.layers.append(GraphConvolution(hidden_features, out_features))
self.dropout = dropout
def forward(self, x, adj):
for layer in self.layers[:-1]:
x = F.relu(layer(x, adj))
x = F.dropout(x, self.dropout)
x = self.layers[-1](x, adj)
return F.log_softmax(x, dim=1)
GAT
公式
GAT(Graph Attention Networks,图注意力模型)。
本质上就是GCN应用了注意力机制,即A 成为了一个可变的、要更新的东西。
根据更新方式的不同,主要分为两种:
Global graph attention,就是任意两个点都要进行attention运算。
Mask graph attention,注意力机制的运算只在邻居顶点上进行。
根据实现中矩阵表示方法的不同,又分为密集矩阵(就是平时的矩阵)和稀释矩阵表示。
此外,又有很多种计算注意力系数的方法。
这里采用当年GAT论文中的做法,使用密集矩阵,计算注意力系数的方法与论文中保持一致。
在图注意力网络(GAT)中, 和 的计算公式以及节点更新公式如下:
计算注意力系数 :
这里, 表示节点 hi 相对于节点 的注意力值, 和 是共享的可学习参数,||表示向量拼接操作。 是一种激活函数。
(1, f)(f, f’)的维度是(1, f’),两个拼接就会得到(1, 2f’), 是一个列向量,维度是(2f’,1),两者相乘得到一个标量。
对行求softmax。
节点更新:
, 指的是与i节点相邻的所有节点形成的集合。
从单头到多头, 是注意力头的数量。
无非就是多个注意力头的结构然后全都concat起来,有时也会采用多个注意力头球均值的方法:
注意力实现
上代码
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class GraphAttentionLayer(nn.Module):
def __init__(self, in_features, out_features, dropout, alpha, concat=True):
super(GraphAttentionLayer, self).__init__()
self.dropout = dropout
self.in_features = in_features
self.out_features = out_features
self.alpha = alpha
self.concat = concat
self.W = nn.Parameter(torch.zeros(size=(in_features, out_features))) # [f, f']
nn.init.xavier_uniform_(self.W.data, gain=1.414)
self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1))) # [2f', 1]
nn.init.xavier_uniform_(self.a.data, gain=1.414)
self.leakyrelu = nn.LeakyReLU(self.alpha)
def forward(self, input_h, adj): # input_h: [B, m, f]
h = input_h @ self.W # [B, m, f']
m = h.size()[1]
a_input = torch.cat(
[h.repeat(1, 1, m).view(-1, m * m, self.out_features),
h.repeat(1, m, 1)], dim=-1).view(-1, m, m, 2 * self.out_features) # [B, m, m, 2f']
e = self.leakyrelu((a_input @ self.a).squeeze(3)) # [B, m, m, 1] -> [B, m, m]
# 如果邻接(adj_ij=1),就用e_ij
# 如果不邻接(adj_ij=0),就用-9e15, 之后会被softmax“屏蔽”掉
# 这样就只留用了邻接的Global graph attention -> Mask graph attention
zero_vec = -9e15*torch.ones_like(e)
attention = torch.where(adj > 0, e, zero_vec) # [B, m, m]
attention = F.softmax(attention, dim=-1)
attention = F.dropout(attention, self.dropout)
h_prime = attention @ h # [B, m, m] @ [B, m, f'] -> [B, m, f']
if self.concat:
return F.elu(h_prime)
else:
return h_prime
先解释一下这一行:
a_input = torch.cat(
[h.repeat(1, 1, m).view(-1, m * m, self.out_features),
h.repeat(1, m, 1)], dim=-1).view(-1, m, m, 2 * self.out_features) # [B, m, m, 2f']
这一行十分抽象。在这个代码中中h维度是(B, m, f’)。这里先不看batch_size,我们先假设h的维度是(m, f’),h由m个节点组成,每个节点有f个特征。我们可以先把h写成:
(m, f’) 显然,这里m=3
第一个量:
h.repeat(1,m),横着重复:
(m, mf’)
h.view(mm, f’):
第二个量:
h.repeat(m,1),竖着重复:
(mm, f’)
将两者拼起来:
(mm, 2f’)
最后展开:
View(m,m,2f’):
(m, m, 2f’)
加了batch_size之后一个道理。
芜湖,这样处理之后,就有:
则:
解释完毕
公式对比
对比一下,GAT的公式描述是这样的:
在代码里面我们是这样实现GAT的:
先
再得到a_input,之后有:
写成矩阵形式:
然后用A来掩盖掉e中不相邻的值。这一步不写公式了。(因为我也不知道数学上怎么写)。
更新节点
回顾一下GCN的节点更新公式:
发现其实GAT本质就是用了一个新的权重计算方式。
多头注意力实现
之后多头注意力
class MultiHeadGATLayer(nn.Module):
def __init__(self, nfeat, nhid, nout, dropout, alpha, nheads):
"""Dense version of GAT."""
super(MultiHeadGATLayer, self).__init__()
self.dropout = dropout
self.attentions = nn.ModuleList(
[GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
for _ in range(nheads)]
)
self.out = GraphAttentionLayer(nhid * nheads, nout, dropout=dropout, alpha=alpha, concat=False)
def forward(self, x, adj):
x = F.dropout(x, self.dropout)
x = torch.cat([att(x, adj) for att in self.attentions], dim=2) # [B, m, n_hid*n_heads]
x = F.dropout(x, self.dropout)
x = F.elu(self.out(x, adj))
return x # [B, m, n_out]
简单组装一下:
class GAT(nn.Module):
def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads, nlayers):
"""Dense version of GAT."""
super(GAT, self).__init__()
self.dropout = dropout
# 由输入到隐藏
self.gats = nn.ModuleList(
[MultiHeadGATLayer(nfeat, nhid, nfeat, dropout=dropout, alpha=alpha, nheads=nheads)
for _ in range(nlayers)]
) # [B, m, f] -> [B, m, h] -> [B, m, f]
self.out = GraphAttentionLayer(nfeat, nclass, dropout=dropout, alpha=alpha, concat=False) # [B, m, n_class]
def forward(self, x, adj):
for gat in self.gats:
x = gat(x, adj)
x = F.dropout(x, self.dropout)
x = self.out(x, adj)
return F.log_softmax(x, dim=-1) # [B, m, n_class]
测试&可视化
# 生成测试样例
x = torch.randn(1, 3, 2) # [batch_size, m, f]
adj = torch.ones(1, 3, 3) # [batch_size, m, m]
model = GAT(2, 4, 2, 0.6, 0.6, 4, 3)
print(model(x, adj).shape) # [B, m, n_class] [1, 3, 2]
modelData = "./demo.pth"
torch.onnx.export(model, (x, adj), modelData)
netron.start(modelData)
感觉还行