Hands-On Graph Neural Networks (4): Graph Classification with Graph Neural Networks

Graph Classification with Graph Neural Networks: From Theory to Practice

Introduction

In the previous tutorials, we saw how to use graph neural networks (GNNs) for node classification. This tutorial takes a closer look at applying GNNs to graph classification: given a dataset of graphs, the task is to classify each entire graph based on some of its structural properties, rather than classifying the individual nodes within a graph.

Environment Setup

import os
import torch
os.environ['TORCH'] = torch.__version__  # expose the version so the matching PyG wheels are fetched below
print(torch.__version__)

# Colab-style shell commands: install PyG and its extension packages against the current PyTorch version
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
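
Note that on recent releases (PyG 2.3 and later), a plain pip install torch_geometric is usually sufficient: torch-scatter and torch-sparse are optional compiled extensions, and nothing in this tutorial strictly requires them.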

Dataset Overview

The most common graph classification task is molecular property prediction, where molecules are represented as graphs and the task might be, for example, to infer whether a molecule inhibits virus replication.

We use the MUTAG dataset from the TUDatasets collection maintained by TU Dortmund University, accessible in PyTorch Geometric via torch_geometric.datasets.TUDataset.

import torch
from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='data/TUDataset', name='MUTAG')

print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object

print()
print(data)
print('=============================================================')

# Gather statistics about the first graph
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')


The dataset provides 188 different graphs, and the task is to classify each graph into one of two classes. The first graph object has 17 nodes (each with a 7-dimensional feature vector) and 38 edges, giving an average node degree of 2.24. It also carries exactly one graph label and 4-dimensional edge features, but for simplicity we will not use the edge features.
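
A quick sanity check worth running is to count the graph labels (a minimal sketch; the names g and labels are introduced here). On MUTAG this should reveal a roughly two-to-one class split:

labels = torch.cat([g.y for g in dataset])  # each g.y holds a single graph label
print(labels.bincount())                    # number of graphs per class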

We shuffle the dataset and use the first 150 graphs as the training set and the remaining graphs as the test set.

torch.manual_seed(12345)
dataset = dataset.shuffle()

train_dataset = dataset[:150]
test_dataset = dataset[150:]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')


Mini-Batching of Graphs

Since graphs in graph classification datasets are usually small, it is a good idea to batch them before feeding them into a GNN to ensure full GPU utilization. For GNNs, PyTorch Geometric uses a special batching approach: adjacency matrices are stacked in a block-diagonal fashion, while node and target features are simply concatenated along the node dimension.

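The following minimal sketch (the toy graphs g1 and g2 are made up for illustration) shows what this batching looks like with Batch.from_data_list, which is what the DataLoader calls under the hood:

import torch
from torch_geometric.data import Data, Batch

g1 = Data(x=torch.randn(3, 7), edge_index=torch.tensor([[0, 1], [1, 2]]))  # 3 nodes, 2 edges
g2 = Data(x=torch.randn(2, 7), edge_index=torch.tensor([[0], [1]]))        # 2 nodes, 1 edge

batch = Batch.from_data_list([g1, g2])
print(batch.x.shape)     # torch.Size([5, 7]) -> node features simply concatenated
print(batch.edge_index)  # g2's node indices are shifted by 3 (block-diagonal stacking)
print(batch.batch)       # tensor([0, 0, 0, 1, 1]) maps each node to its graph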

from torch_geometric.loader import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()


Here we chose a batch size of 64, resulting in 3 (randomly shuffled) mini-batches that together contain all 150 training graphs. Each Batch object carries a batch vector that maps every node to the graph it belongs to within the batch.

The batched graph data carries several attributes:

  • edge_attr: edge attributes, with shape [number of edges, edge feature dimension].
  • edge_index: edge indices, with shape [2, number of edges], encoding the connectivity of the graph.
  • x: the node feature matrix, with shape [number of nodes, node feature dimension].
  • y: the graph labels (targets); its length equals the number of graphs in the batch.
  • batch: a vector with one entry per node, identifying which graph each node belongs to.
  • ptr: a pointer vector of length equal to the number of graphs plus one, commonly used to index where each graph starts within the batch (see the sketch below).
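
For instance, here is a minimal sketch of slicing out the nodes of a single graph via ptr (assuming data still holds the last Batch produced by the loop above):

i = 0                                            # index of the graph within the batch
start, end = data.ptr[i].item(), data.ptr[i + 1].item()
x_i = data.x[start:end]                          # node features of graph i
# equivalently, via the batch vector: x_i = data.x[data.batch == i]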

Training a Graph Neural Network (GNN)

Training a GNN for graph classification usually follows this simple recipe:

  1. Embed each node via multiple rounds of message passing.
  2. Aggregate the node embeddings into a unified graph embedding (the readout layer).
  3. Train a final classifier on top of the graph embedding.

The most common readout layer simply takes the average of the node embeddings:
\mathbf{x}_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathbf{x}^{(L)}_v

PyTorch Geometric provides this operation via torch_geometric.nn.global_mean_pool.
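
Conceptually, global_mean_pool just computes a per-graph mean over the node embeddings. Here is a small sketch of that equivalence (the toy tensors x and batch are made up for illustration):

import torch
from torch_geometric.nn import global_mean_pool

x = torch.arange(10, dtype=torch.float).view(5, 2)  # 5 nodes with 2-dim embeddings
batch = torch.tensor([0, 0, 0, 1, 1])               # nodes 0-2 form graph 0, nodes 3-4 form graph 1
out = global_mean_pool(x, batch)                    # shape [2, 2]: one row per graph
manual = torch.stack([x[batch == 0].mean(0), x[batch == 1].mean(0)])
assert torch.allclose(out, manual)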

Defining the GCN Model

from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

model = GCN(hidden_channels=64)
print(model)
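
Before training, a quick smoke test confirms the shapes line up (a sketch that simply pushes one mini-batch through the still-untrained model):

data = next(iter(train_loader))                   # grab one mini-batch
out = model(data.x, data.edge_index, data.batch)
print(out.shape)                                  # [num_graphs_in_batch, dataset.num_classes]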


Training and Testing the Model

from IPython.display import Javascript  # Colab only: keep the long training log scrollable
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

model = GCN(hidden_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

def train():
    model.train()

    for data in train_loader:  # iterate over the training dataset in mini-batches
        out = model(data.x, data.edge_index, data.batch)  # perform a single forward pass
        loss = criterion(out, data.y)  # compute the loss
        loss.backward()  # compute the gradients
        optimizer.step()  # update the parameters based on the gradients
        optimizer.zero_grad()  # clear the gradients

def test(loader):
    model.eval()

    correct = 0
    for data in loader:  # iterate over the training/test dataset in mini-batches
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)  # take the class with the highest score
        correct += int((pred == data.y).sum())  # compare against the ground-truth labels
    return correct / len(loader.dataset)  # fraction of correctly classified graphs

for epoch in range(1, 171):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

Epoch: 001, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 002, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 003, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 004, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 005, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 006, Train Acc: 0.6533, Test Acc: 0.7368
Epoch: 007, Train Acc: 0.7467, Test Acc: 0.7632
Epoch: 008, Train Acc: 0.7267, Test Acc: 0.7632
Epoch: 009, Train Acc: 0.7200, Test Acc: 0.7632
Epoch: 010, Train Acc: 0.7133, Test Acc: 0.7895
Epoch: 011, Train Acc: 0.7200, Test Acc: 0.7632
Epoch: 012, Train Acc: 0.7200, Test Acc: 0.7895
Epoch: 013, Train Acc: 0.7200, Test Acc: 0.7895
Epoch: 014, Train Acc: 0.7133, Test Acc: 0.8421
Epoch: 015, Train Acc: 0.7133, Test Acc: 0.8421
Epoch: 016, Train Acc: 0.7533, Test Acc: 0.7368
Epoch: 017, Train Acc: 0.7400, Test Acc: 0.7632
Epoch: 018, Train Acc: 0.7133, Test Acc: 0.8421
Epoch: 019, Train Acc: 0.7400, Test Acc: 0.7895
Epoch: 020, Train Acc: 0.7533, Test Acc: 0.7368
Epoch: 021, Train Acc: 0.7467, Test Acc: 0.7895
Epoch: 022, Train Acc: 0.7467, Test Acc: 0.7895
Epoch: 023, Train Acc: 0.7533, Test Acc: 0.7895
Epoch: 024, Train Acc: 0.7267, Test Acc: 0.8421
Epoch: 025, Train Acc: 0.7533, Test Acc: 0.7632
Epoch: 026, Train Acc: 0.7533, Test Acc: 0.7632
Epoch: 027, Train Acc: 0.7600, Test Acc: 0.8158
Epoch: 028, Train Acc: 0.7533, Test Acc: 0.8421
Epoch: 029, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 030, Train Acc: 0.7600, Test Acc: 0.8158
Epoch: 031, Train Acc: 0.7600, Test Acc: 0.8158
Epoch: 032, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 033, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 034, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 035, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 036, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 037, Train Acc: 0.7400, Test Acc: 0.7632
Epoch: 038, Train Acc: 0.7667, Test Acc: 0.8158
Epoch: 039, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 040, Train Acc: 0.7533, Test Acc: 0.7368
Epoch: 041, Train Acc: 0.7467, Test Acc: 0.7368
Epoch: 042, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 043, Train Acc: 0.7667, Test Acc: 0.8158
Epoch: 044, Train Acc: 0.7533, Test Acc: 0.7632
Epoch: 045, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 046, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 047, Train Acc: 0.7667, Test Acc: 0.8158
Epoch: 048, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 049, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 050, Train Acc: 0.7667, Test Acc: 0.8158
Epoch: 051, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 052, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 053, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 054, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 055, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 056, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 057, Train Acc: 0.7533, Test Acc: 0.7632
Epoch: 058, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 059, Train Acc: 0.7800, Test Acc: 0.7632
Epoch: 060, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 061, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 062, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 063, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 064, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 065, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 066, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 067, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 068, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 069, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 070, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 071, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 072, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 073, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 074, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 075, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 076, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 077, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 078, Train Acc: 0.7733, Test Acc: 0.8421
Epoch: 079, Train Acc: 0.7667, Test Acc: 0.8158
Epoch: 080, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 081, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 082, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 083, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 084, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 085, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 086, Train Acc: 0.7800, Test Acc: 0.8158
Epoch: 087, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 088, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 089, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 090, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 091, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 092, Train Acc: 0.7800, Test Acc: 0.8158
Epoch: 093, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 094, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 095, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 096, Train Acc: 0.7600, Test Acc: 0.7895
Epoch: 097, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 098, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 099, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 100, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 101, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 102, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 103, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 104, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 105, Train Acc: 0.7733, Test Acc: 0.7368
Epoch: 106, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 107, Train Acc: 0.7733, Test Acc: 0.7105
Epoch: 108, Train Acc: 0.8000, Test Acc: 0.7632
Epoch: 109, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 110, Train Acc: 0.7733, Test Acc: 0.7632
Epoch: 111, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 112, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 113, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 114, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 115, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 116, Train Acc: 0.7733, Test Acc: 0.7632
Epoch: 117, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 118, Train Acc: 0.7733, Test Acc: 0.7632
Epoch: 119, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 120, Train Acc: 0.8000, Test Acc: 0.7105
Epoch: 121, Train Acc: 0.7600, Test Acc: 0.7632
Epoch: 122, Train Acc: 0.7667, Test Acc: 0.7105
Epoch: 123, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 124, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 125, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 126, Train Acc: 0.7733, Test Acc: 0.7368
Epoch: 127, Train Acc: 0.7733, Test Acc: 0.7632
Epoch: 128, Train Acc: 0.7733, Test Acc: 0.7632
Epoch: 129, Train Acc: 0.7733, Test Acc: 0.7632
Epoch: 130, Train Acc: 0.7733, Test Acc: 0.7632
Epoch: 131, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 132, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 133, Train Acc: 0.7733, Test Acc: 0.7632
Epoch: 134, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 135, Train Acc: 0.8067, Test Acc: 0.7368
Epoch: 136, Train Acc: 0.7800, Test Acc: 0.7632
Epoch: 137, Train Acc: 0.7733, Test Acc: 0.7632
Epoch: 138, Train Acc: 0.8133, Test Acc: 0.7105
Epoch: 139, Train Acc: 0.7867, Test Acc: 0.7632
Epoch: 140, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 141, Train Acc: 0.8000, Test Acc: 0.6579
Epoch: 142, Train Acc: 0.7733, Test Acc: 0.7632
Epoch: 143, Train Acc: 0.7933, Test Acc: 0.7632
Epoch: 144, Train Acc: 0.7867, Test Acc: 0.7368
Epoch: 145, Train Acc: 0.8267, Test Acc: 0.7368
Epoch: 146, Train Acc: 0.7667, Test Acc: 0.7895
Epoch: 147, Train Acc: 0.7800, Test Acc: 0.7105
Epoch: 148, Train Acc: 0.7933, Test Acc: 0.7895
Epoch: 149, Train Acc: 0.8200, Test Acc: 0.7105
Epoch: 150, Train Acc: 0.7800, Test Acc: 0.7895
Epoch: 151, Train Acc: 0.7800, Test Acc: 0.7632
Epoch: 152, Train Acc: 0.7867, Test Acc: 0.7632
Epoch: 153, Train Acc: 0.8067, Test Acc: 0.7368
Epoch: 154, Train Acc: 0.8067, Test Acc: 0.7368
Epoch: 155, Train Acc: 0.7867, Test Acc: 0.7632
Epoch: 156, Train Acc: 0.7800, Test Acc: 0.7105
Epoch: 157, Train Acc: 0.8000, Test Acc: 0.7368
Epoch: 158, Train Acc: 0.7800, Test Acc: 0.7368
Epoch: 159, Train Acc: 0.7867, Test Acc: 0.7632
Epoch: 160, Train Acc: 0.7867, Test Acc: 0.7632
Epoch: 161, Train Acc: 0.7800, Test Acc: 0.7632
Epoch: 162, Train Acc: 0.7867, Test Acc: 0.7632
Epoch: 163, Train Acc: 0.7867, Test Acc: 0.7632
Epoch: 164, Train Acc: 0.7800, Test Acc: 0.8158
Epoch: 165, Train Acc: 0.7800, Test Acc: 0.8158
Epoch: 166, Train Acc: 0.7733, Test Acc: 0.7632
Epoch: 167, Train Acc: 0.7867, Test Acc: 0.7895
Epoch: 168, Train Acc: 0.7867, Test Acc: 0.7895
Epoch: 169, Train Acc: 0.8000, Test Acc: 0.7632
Epoch: 170, Train Acc: 0.8000, Test Acc: 0.7632

The model reaches a test accuracy of around 76%. The accuracy fluctuates because the dataset is small (only 38 test graphs); results are more stable on larger datasets.

(Optional) Exercise

It has been shown that applying neighborhood normalization reduces the expressive power of a GNN in distinguishing certain graph structures. An alternative is to omit neighborhood normalization entirely and add a simple skip connection to the GNN layer in order to preserve central node information:
\mathbf{x}_v^{(\ell+1)} = \mathbf{W}^{(\ell + 1)}_1 \mathbf{x}_v^{(\ell)} + \mathbf{W}^{(\ell + 1)}_2 \sum_{w \in \mathcal{N}(v)} \mathbf{x}_w^{(\ell)}

This layer is implemented in PyTorch Geometric under the name GraphConv. As an exercise, complete the code below using GraphConv instead of GCNConv; this should bring the test accuracy close to 82%. (One possible completion is sketched after the training loop below.)

from torch_geometric.nn import GraphConv

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GNN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = ...  # TODO
        self.conv2 = ...  # TODO
        self.conv3 = ...  # TODO
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        x = global_mean_pool(x, batch)

        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

model = GNN(hidden_channels=64)
print(model)

from IPython.display import Javascript  # Colab only: keep the long training log scrollable
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(1, 201):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
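
For reference, here is one possible completion of the three TODO lines (a sketch; GraphConv accepts the same (in_channels, out_channels) arguments as GCNConv):

        self.conv1 = GraphConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.conv3 = GraphConv(hidden_channels, hidden_channels)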

In this chapter, we learned how to apply GNNs to graph classification tasks: how to batch graphs to improve GPU utilization, and how to use a readout layer to obtain graph embeddings rather than node embeddings.

