当前位置: 首页 > article >正文

【模型学习之路】手写+分析GAT

从GNN,到GCN,再到GAT

目录

文章目录

前言

GNN

GCN

GAT

公式

注意力实现

公式对比

多头注意力实现

测试&可视化


前言

读本文前,可以先过一遍【GNN图神经网络】入门到实战完整40讲!同济大佬用大白话的方式从零到一讲解原理基础及代码复现,主打一个通俗易懂!_哔哩哔哩_bilibili的1-12集。

GNN

GNN(Graph Neural Network,图神经网络)是一种专门用于处理图结构数据的深度学习模型,它的核心思想是通过聚合邻居节点的信息来更新每个节点的表示。这种更新过程可以捕捉到节点的局部邻域结构,从而学习到节点、边甚至整个图的高级特征表示。

GCN(Graph Convolutional Network,图卷积网络)是GNN的一种,它通过图卷积操作来更新节点的特征表示。

GCN

稍微来总结一下视频中的内容。

 (m, f)是第层的节点特征矩阵,m是节点个数,f表示每个节点的特征数。

 (f, f’) 是第层的可学习权重矩阵,也就是我们要训练的参数。

(m, m)是无自环邻接矩阵(即邻接矩阵,但是对角线全为0),(m, m)是度矩阵。首先,为了消息的自我传播,我们给它们加上单位矩阵。

节点更新函数为:

其中我们可以令一个变量:

式中是激活函数。

左乘和右乘时为了分别对列和行做标准化。

代码很简单,就是公式的堆叠,不做赘述。

import torch
import torch.nn as nn
import torch.nn.functional as F


def normalized_adjacency(adj):
    """输入A, 返回A^ """
    d = torch.diag(torch.sum(adj, dim=1))
    a = adj + torch.eye(adj.shape[0])
    d = d + torch.eye(adj.shape[0])
    d_inv_sqrt = torch.pow(d, -0.5)
    a_norm = d_inv_sqrt @ a @ d_inv_sqrt
    return a_norm


class GraphConvolution(nn.Module):
    def __init__(self, in_features, out_features):
        super(GraphConvolution, self).__init__()
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / (self.weight.size(1) ** 0.5)
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, x, adj):
        adj = normalized_adjacency(adj)
        return adj @ (x @ self.weight)


class GCN(nn.Module):
    def __init__(self, in_features, hidden_features, out_features,
                 n_layers, dropout=0.5):
        super(GCN, self).__init__()
        self.layers = nn.ModuleList()

        # 输入层到第一隐藏层
        self.layers.append(GraphConvolution(in_features, hidden_features))

        # 隐藏层到其他隐藏层
        for _ in range(n_layers - 2):
            self.layers.append(GraphConvolution(hidden_features, hidden_features))

        # 最后一个隐藏层到输出层
        self.layers.append(GraphConvolution(hidden_features, out_features))
        self.dropout = dropout

    def forward(self, x, adj):
        for layer in self.layers[:-1]:
            x = F.relu(layer(x, adj))
            x = F.dropout(x, self.dropout)
        x = self.layers[-1](x, adj)
        return F.log_softmax(x, dim=1)

GAT

公式

GAT(Graph Attention Networks,图注意力模型)。

本质上就是GCN应用了注意力机制,即A  成为了一个可变的、要更新的东西。

根据更新方式的不同,主要分为两种:

Global graph attention,就是任意两个点都要进行attention运算。

Mask graph attention,注意力机制的运算只在邻居顶点上进行。

根据实现中矩阵表示方法的不同,又分为密集矩阵(就是平时的矩阵)和稀释矩阵表示。

此外,又有很多种计算注意力系数的方法。

这里采用当年GAT论文中的做法,使用密集矩阵,计算注意力系数的方法与论文中保持一致。

在图注意力网络(GAT)中, 的计算公式以及节点更新公式如下:

计算注意力系

这里,  表示节点 hi  相对于节点  的注意力值,  是共享的可学习参数,||表示向量拼接操作。  是一种激活函数。

(1, f)(f, f’)的维度是(1, f’),两个拼接就会得到(1, 2f’) 是一个列向量,维度是(2f’,1),两者相乘得到一个标量。

对行求softmax。

    节点更新:

指的是与i节点相邻的所有节点形成的集合。

从单头到多头,  是注意力头的数量。

无非就是多个注意力头的结构然后全都concat起来,有时也会采用多个注意力头球均值的方法:

注意力实现

上代码

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class GraphAttentionLayer(nn.Module):
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))  # [f, f']
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))  # [2f', 1]
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input_h, adj):  # input_h: [B, m, f]
        h = input_h @ self.W  # [B, m, f']
        m = h.size()[1]

        a_input = torch.cat(
            [h.repeat(1, 1, m).view(-1, m * m, self.out_features),
             h.repeat(1, m, 1)], dim=-1).view(-1, m, m, 2 * self.out_features)  # [B, m, m, 2f']
        e = self.leakyrelu((a_input @ self.a).squeeze(3))  # [B, m, m, 1] -> [B, m, m]

        # 如果邻接(adj_ij=1),就用e_ij
        # 如果不邻接(adj_ij=0),就用-9e15, 之后会被softmax“屏蔽”掉
        # 这样就只留用了邻接的Global graph attention -> Mask graph attention
        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)  # [B, m, m]

        attention = F.softmax(attention, dim=-1)
        attention = F.dropout(attention, self.dropout)
        h_prime = attention @ h  # [B, m, m] @ [B, m, f'] -> [B, m, f']

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

先解释一下这一行:

a_input = torch.cat(
    [h.repeat(1, 1, m).view(-1, m * m, self.out_features),
     h.repeat(1, m, 1)], dim=-1).view(-1, m, m, 2 * self.out_features)  # [B, m, m, 2f']

这一行十分抽象。在这个代码中中h维度是(B, m, f’)。这里先不看batch_size,我们先假设h的维度是(m, f’),h由m个节点组成,每个节点有f个特征。我们可以先把h写成:

 (m, f’) 显然,这里m=3

第一个量:

h.repeat(1,m),横着重复:

   (m, mf’)

h.view(mm, f’):

​​​​​​​

第二个量:

h.repeat(m,1),竖着重复:

 (mm, f’)

将两者拼起来:

 (mm, 2f’)

最后展开:

View(m,m,2f’):

(m, m, 2f’)

加了batch_size之后一个道理。

芜湖,这样处理之后,就有:

则:

解释完毕

公式对比

 

对比一下,GAT的公式描述是这样的:

在代码里面我们是这样实现GAT的:

再得到a_input,之后有:

写成矩阵形式:

然后用A来掩盖掉e中不相邻的值。这一步不写公式了。(因为我也不知道数学上怎么写)。

更新节点

回顾一下GCN的节点更新公式:

发现其实GAT本质就是用了一个新的权重计算方式。

多头注意力实现

之后多头注意力

class MultiHeadGATLayer(nn.Module):
    def __init__(self, nfeat, nhid, nout, dropout, alpha, nheads):
        """Dense version of GAT."""
        super(MultiHeadGATLayer, self).__init__()
        self.dropout = dropout
        self.attentions = nn.ModuleList(
            [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
             for _ in range(nheads)]
        )
        self.out = GraphAttentionLayer(nhid * nheads, nout, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        x = F.dropout(x, self.dropout)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=2)  # [B, m, n_hid*n_heads]
        x = F.dropout(x, self.dropout)
        x = F.elu(self.out(x, adj))
        return x  # [B, m, n_out]

简单组装一下:

class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads, nlayers):
        """Dense version of GAT."""
        super(GAT, self).__init__()
        self.dropout = dropout
        # 由输入到隐藏
        self.gats = nn.ModuleList(
            [MultiHeadGATLayer(nfeat, nhid, nfeat, dropout=dropout, alpha=alpha, nheads=nheads)
             for _ in range(nlayers)]
        )  # [B, m, f] -> [B, m, h] -> [B, m, f]
        self.out = GraphAttentionLayer(nfeat, nclass, dropout=dropout, alpha=alpha, concat=False)  # [B, m, n_class]

    def forward(self, x, adj):
        for gat in self.gats:
            x = gat(x, adj)
        x = F.dropout(x, self.dropout)

        x = self.out(x, adj)
        return F.log_softmax(x, dim=-1)  # [B, m, n_class]

测试&可视化

# 生成测试样例
x = torch.randn(1, 3, 2)  # [batch_size, m, f]
adj = torch.ones(1, 3, 3)  # [batch_size, m, m]
model = GAT(2, 4, 2, 0.6, 0.6, 4, 3)
print(model(x, adj).shape)  # [B, m, n_class] [1, 3, 2]

modelData = "./demo.pth"
torch.onnx.export(model, (x, adj), modelData)
netron.start(modelData)

感觉还行 


http://www.kler.cn/a/388066.html

相关文章:

  • 【数据结构与算法】第11课—数据结构之选择排序和交换排序
  • 车载空气净化器语音芯片方案
  • 【C++】new操作符的使用说明
  • 并发基础:(淘宝笔试题)三个线程分别打印 A,B,C,要求这三个线程一起运行,打印 n 次,输出形如“ABCABCABC....”的字符串【举一反三】
  • 鸿蒙next版开发:相机开发-元数据(ArkTS)
  • 基于yolov8、yolov5的番茄成熟度检测识别系统(含UI界面、训练好的模型、Python代码、数据集)
  • 前端 Flex 布局语法详解
  • Python接口自动化测试自学指南(项目实战)
  • 海外云手机在出海业务中的优势有哪些?
  • Elasticsearch实战使用
  • u盘怎么重装电脑系统_u盘重装电脑系统步骤和详细教程【新手宝典】
  • Hive中查看字段中是否包含某些字符串的函数
  • Git 入门篇(三)
  • 发布 VectorTraits v3.0(支持 X86架构的Avx512系列指令集,支持 Wasm架构及PackedSimd指令集等)
  • 从0开始深度学习(24)——填充和步幅
  • 通过 SSH 连接远程 Ubuntu 服务器
  • 24下半年教资面试资源(幼儿+小学+初中+高中+各科)逐字稿
  • Redis集群——针对实习面试
  • JDK8主要特性
  • React 中 `key` 属性的警告及其解决方案
  • C++设计模式精选面试题及参考答案
  • 如何找到系统中bert-base-uncased默认安装位置
  • 数据结构和算法-贪心算法01- 认识贪心
  • 如何利用 Python 的爬虫技术获取淘宝天猫商品的价格信息?
  • 手写线程池c
  • 前端基础面试题·第四篇——Vue(其三)