Hands-On Graph Neural Networks (4): Graph Classification with GNNs
Graph Classification with Graph Neural Networks: From Theory to Practice
Introduction
In earlier parts of this series, you learned how to use graph neural networks (GNNs) for node classification. This tutorial explores how to apply GNNs to graph classification: given a dataset of graphs, the task is to classify each entire graph based on some structural property, rather than classifying individual nodes within a graph.
Environment Setup
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
The Dataset
The most common graph classification task is molecular property prediction, where molecules are represented as graphs and the task may be to infer whether a molecule inhibits virus replication.
We use the MUTAG dataset from the TUDatasets collection maintained by TU Dortmund University, which is available in PyTorch Geometric via torch_geometric.datasets.TUDataset.
import torch
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='data/TUDataset', name='MUTAG')
print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
data = dataset[0]  # Get the first graph object
print()
print(data)
print('=============================================================')
# Gather statistics for the first graph
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
The dataset provides 188 different graphs, and the task is to classify each graph into one of two classes. The first graph object has 17 nodes (each with a 7-dimensional feature vector) and 38 edges, giving an average node degree of 2.24. It also carries one graph label and 4-dimensional edge features, but for simplicity we will not use the edge features.
We shuffle the dataset and use the first 150 graphs for training and the remaining graphs for testing.
torch.manual_seed(12345)
dataset = dataset.shuffle()
train_dataset = dataset[:150]
test_dataset = dataset[150:]
print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')
Mini-Batching of Graphs
Since graphs in graph classification datasets are usually small, it is a good idea to batch them before feeding them into a GNN, to ensure full GPU utilization. For GNNs, PyTorch Geometric uses a special batching approach: adjacency matrices are stacked in a block-diagonal fashion, and node and target features are simply concatenated along the node dimension.
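The block-diagonal stacking can be illustrated with plain PyTorch on two tiny toy graphs. This is only a sketch of the idea behind what the DataLoader does internally, not PyG's actual implementation:

```python
import torch

# Two tiny toy graphs: a 3-node path and a single edge on 2 nodes.
x1 = torch.ones(3, 1);  ei1 = torch.tensor([[0, 1], [1, 2]])
x2 = torch.zeros(2, 1); ei2 = torch.tensor([[0], [1]])

# Block-diagonal stacking: concatenate node features, and shift the
# second graph's node indices by the first graph's node count.
x = torch.cat([x1, x2], dim=0)                  # shape [5, 1]
edge_index = torch.cat([ei1, ei2 + x1.size(0)], dim=1)
batch = torch.tensor([0, 0, 0, 1, 1])           # maps node -> graph id
print(edge_index.tolist())                      # [[0, 1, 3], [1, 2, 4]]
```

Because the indices of the second graph are offset, a single sparse adjacency structure can hold many graphs at once without any edges crossing between them.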
from torch_geometric.loader import DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()
Here we chose a batch size of 64, resulting in 3 (randomly shuffled) mini-batches that together contain all 150 training graphs. Each Batch object also carries a batch vector that maps every node to its graph within the batch.
A Batch object exposes several attributes:
- edge_attr: edge features, shaped [num_edges, edge_feature_dim].
- edge_index: edge indices, shaped [2, num_edges], encoding the connectivity of the graphs.
- x: node feature matrix, shaped [num_nodes, node_feature_dim].
- y: graph labels or targets; its length equals the number of graphs in the batch.
- batch: a vector of length num_nodes that assigns each node to its graph.
- ptr: a pointer vector of length num_graphs + 1, used to locate where each graph's nodes start within the batch.
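The relationship between batch and ptr can be sketched with toy tensors (an illustration, not PyG's internal code): ptr is simply the cumulative node count per graph, so rows ptr[i]:ptr[i+1] of x belong to graph i.

```python
import torch

# batch assigns each of 5 nodes to one of 2 graphs.
batch = torch.tensor([0, 0, 0, 1, 1])

# Derive ptr from batch: cumulative node counts, prefixed with 0.
counts = torch.bincount(batch)
ptr = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)])
print(ptr.tolist())          # [0, 3, 5]

# Rows ptr[i]:ptr[i+1] of x are the nodes of graph i.
x = torch.arange(10.).view(5, 2)
x_first = x[ptr[0]:ptr[1]]   # node features of the first graph
```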
Training a Graph Neural Network (GNN)
Training a GNN for graph classification usually follows these simple steps:
- Embed each node via several rounds of message passing.
- Aggregate the node embeddings into a single graph embedding (the readout layer).
- Train a final classifier on the graph embedding.
The most common readout layer simply averages the node embeddings:
\mathbf{x}_{\mathcal{G}} = \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \mathbf{x}^{(L)}_v
PyTorch Geometric provides this functionality via torch_geometric.nn.global_mean_pool.
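What this readout computes can be sketched in plain PyTorch: sum the node embeddings per graph (using the batch vector), then divide by each graph's node count. The mean_pool function below is a hypothetical stand-in used only to illustrate the computation, not PyG's implementation:

```python
import torch

def mean_pool(x, batch, num_graphs):
    # Sum the node embeddings of each graph via index_add_,
    # then divide by the number of nodes in that graph.
    out = torch.zeros(num_graphs, x.size(1))
    out.index_add_(0, batch, x)                       # per-graph sums
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    return out / counts.unsqueeze(1).float()

x = torch.tensor([[1.], [3.], [10.]])   # 3 node embeddings
batch = torch.tensor([0, 0, 1])         # first two nodes -> graph 0
print(mean_pool(x, batch, 2).tolist())  # [[2.0], [10.0]]
```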
Defining the GCN Model
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)
        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]
        # 3. Apply the final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        return x
model = GCN(hidden_channels=64)
print(model)
Training and Testing the Model
from IPython.display import Javascript
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))
model = GCN(hidden_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
def train():
    model.train()
    for data in train_loader:  # Iterate over the training set in batches
        out = model(data.x, data.edge_index, data.batch)  # Perform a forward pass
        loss = criterion(out, data.y)  # Compute the loss
        loss.backward()  # Compute gradients
        optimizer.step()  # Update parameters from the gradients
        optimizer.zero_grad()  # Clear the gradients
def test(loader):
    model.eval()
    correct = 0
    for data in loader:  # Iterate over the train/test set in batches
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)  # Use the class with the highest score
        correct += int((pred == data.y).sum())  # Compare against ground-truth labels
    return correct / len(loader.dataset)  # Fraction of correct predictions
for epoch in range(1, 171):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
Epoch: 001, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 002, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 003, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 004, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 005, Train Acc: 0.6467, Test Acc: 0.7368
...
Epoch: 170, Train Acc: 0.8000, Test Acc: 0.7632
The model reaches a test accuracy of about 76%. Because the dataset is small (only 38 test graphs), the accuracy fluctuates between epochs; on larger datasets it would be more stable.
(Optional) Exercise
Research has shown that applying neighborhood normalization reduces a GNN's expressiveness in distinguishing certain graph structures. An alternative is to omit neighborhood normalization entirely and add a simple skip connection to the GNN layer in order to preserve central node information:
\mathbf{x}_v^{(\ell+1)} = \mathbf{W}^{(\ell + 1)}_1 \mathbf{x}_v^{(\ell)} + \mathbf{W}^{(\ell + 1)}_2 \sum_{w \in \mathcal{N}(v)} \mathbf{x}_w^{(\ell)}
This layer is implemented in PyTorch Geometric under the name GraphConv. As an exercise, complete the code below using GraphConv instead of GCNConv; this should bring the test accuracy close to 82%.
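The skip-connection update above can be sketched for a single layer in plain PyTorch. MiniGraphConv is a hypothetical, simplified stand-in written only to illustrate the formula (two weight matrices and an unnormalized neighbor sum), not PyG's GraphConv implementation:

```python
import torch
from torch import nn

class MiniGraphConv(nn.Module):
    # Sketch of x_v' = W1 x_v + W2 * sum_{w in N(v)} x_w
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lin_self = nn.Linear(in_channels, out_channels, bias=False)
        self.lin_nbr = nn.Linear(in_channels, out_channels, bias=False)

    def forward(self, x, edge_index):
        src, dst = edge_index
        agg = torch.zeros_like(x)
        agg.index_add_(0, dst, x[src])  # unnormalized neighbor sum
        return self.lin_self(x) + self.lin_nbr(agg)

conv = MiniGraphConv(1, 1)
with torch.no_grad():  # fix both weights to 1 for a readable example
    conv.lin_self.weight.fill_(1.0)
    conv.lin_nbr.weight.fill_(1.0)
x = torch.tensor([[1.], [2.], [3.]])
edge_index = torch.tensor([[0, 1], [1, 0]])  # one undirected edge 0-1
print(conv(x, edge_index).tolist())  # [[3.0], [3.0], [3.0]]
```

Note that the isolated node 2 keeps a nonzero output through the W1 skip term, which is exactly the central-node information that normalization-free aggregation preserves.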
from torch_geometric.nn import GraphConv
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GNN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = ...  # TODO
        self.conv2 = ...  # TODO
        self.conv3 = ...  # TODO
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)
        x = global_mean_pool(x, batch)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        return x
model = GNN(hidden_channels=64)
print(model)
from IPython.display import Javascript
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))
model = GNN(hidden_channels=64)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(1, 201):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
In this chapter, you learned how to apply GNNs to graph classification: how to batch graphs to improve GPU utilization, and how to use a readout layer to obtain graph embeddings instead of node embeddings.