torch_geometric使用手册-Creating Message Passing Networks(专题二)
创建消息传递网络 (Message Passing Networks)
在图神经网络中,将卷积操作推广到不规则域通常表现为一种邻域聚合 (neighborhood aggregation) 或 消息传递 (message passing) 机制。
这一机制通过聚合节点的邻居信息,更新每个节点的特征。
以下公式描述了消息传递机制的基本形式:
公式解释
x i ( k ) = γ ( k ) ( x i ( k − 1 ) , ⨁ j ∈ N ( i ) ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \bigoplus_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right) xi(k)=γ(k) xi(k−1),j∈N(i)⨁ϕ(k)(xi(k−1),xj(k−1),ej,i)
- x i ( k − 1 ) \mathbf{x}_i^{(k-1)} xi(k−1): 表示第 k − 1 k-1 k−1 层时节点 i i i 的特征。
- e j , i \mathbf{e}_{j,i} ej,i: 表示从节点 j j j 到节点 i i i 的边特征(可选)。
- N ( i ) \mathcal{N}(i) N(i): 节点 i i i 的邻居节点集合。
-
ϕ
(
k
)
\phi^{(k)}
ϕ(k): 消息函数 (
message function
),生成从邻居节点 j j j 到节点 i i i 的消息。 -
⨁
\bigoplus
⨁: 聚合函数 (
aggregation function
),例如加和 (sum)、均值 (mean) 或最大值 (max)。 -
γ
(
k
)
\gamma^{(k)}
γ(k): 更新函数 (
update function
),结合节点本身的特征与聚合后的消息。
PyTorch Geometric (PyG
) 提供了一个名为 MessagePassing
的基类,专门用于实现基于消息传递机制的图神经网络(GNN)。这个类封装了消息传递中的许多细节,开发者只需要定义核心函数,例如消息构造(message
)、特征更新(update
),以及选择合适的聚合方式(aggr
),即可实现复杂的 GNN 算法。
核心概念与方法解析
1. 构造 MessagePassing
基类
MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)
功能
-
定义消息传递的聚合方式:
aggr
: 表示如何将来自邻居节点的消息聚合到目标节点。add
(加和):计算邻居节点消息的加权和。mean
(平均):取邻居节点消息的加权平均值。max
(最大值):选择邻居节点消息的最大值。
-
定义消息的传递方向:
flow
:"source_to_target"
:从源节点传递消息到目标节点。"target_to_source"
:从目标节点向源节点传递消息。
-
node_dim
:- 指定在哪一维度上传递节点特征。通常是倒数第二维(默认为 -2),适配节点特征张量。
2. 消息传递的入口:propagate
方法
MessagePassing.propagate(edge_index, size=None, **kwargs)
功能
- 触发消息传递过程,从边索引和输入特征开始,依次执行:
- 消息构造(
message
):生成从邻居节点传来的消息。 - 消息聚合(
aggregate
,自动完成):将邻居节点的消息聚合到目标节点。 - 特征更新(
update
):更新目标节点的最终特征。
- 消息构造(
注意: 这是入口函数,类似forward的操作,会调用
message
、aggregate
、update
函数.
参数
-
edge_index
:- 图的边索引,形状为
[2, num_edges]
。 - 第一行表示源节点,第二行表示目标节点。
- 图的边索引,形状为
-
size
:- 图中节点的数量或维度。
- 对于普通图,默认为
[num_nodes, num_nodes]
;对于二分图(bipartite graph),可以传递(N, M)
,分别表示源节点和目标节点数量。
-
kwargs
:- 其他参数,如节点特征
x
,边特征edge_attr
等。
- 其他参数,如节点特征
3. 消息生成:message
方法
MessagePassing.message(...)
功能
- 根据每条边的两端节点特征(源节点和目标节点)以及边特征,生成要传递的消息。
参数
- 默认情况下:
x_j
: 源节点的特征。x_i
: 目标节点的特征。edge_attr
: 边的特征(如果存在)。
自动变量映射
在 propagate
内部,会根据 edge_index
自动将输入特征分为:
x_j
:从源节点出发的特征。x_i
:传递到目标节点的特征。
4. 特征更新:update
方法
MessagePassing.update(aggr_out, ...)
功能
- 根据聚合后的结果
aggr_out
,计算目标节点的最终特征。
参数
aggr_out
: 聚合后的邻居节点消息。- 可以使用其他参数,例如目标节点本身的初始特征。
5. 应用流程总结
-
消息生成:
- 根据边和节点特征,生成从邻居节点传递的消息(通过
message
方法)。
- 根据边和节点特征,生成从邻居节点传递的消息(通过
-
消息聚合:
- 使用选定的聚合方式(
aggr
参数,如加和或平均),将消息聚合到目标节点。
- 使用选定的聚合方式(
-
特征更新:
- 在目标节点上应用更新规则,生成最终的节点特征(通过
update
方法)。
- 在目标节点上应用更新规则,生成最终的节点特征(通过
示例:实现经典的 GCN 和 EdgeConv
实现 GCN 层(Graph Convolutional Layer)
GCN 层的数学定义如下:
x i ( k ) = ∑ j ∈ N ( i ) ∪ { i } 1 deg ( i ) ⋅ deg ( j ) ⋅ ( W ⊤ ⋅ x j ( k − 1 ) ) + b \mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{W}^{\top} \cdot \mathbf{x}_j^{(k-1)} \right) + \mathbf{b} xi(k)=j∈N(i)∪{i}∑deg(i)⋅deg(j)1⋅(W⊤⋅xj(k−1))+b
- 邻居节点特征通过一个权重矩阵 W \mathbf{W} W 进行变换。
- 然后,按照节点度进行归一化。
- 最后,对邻居节点特征进行聚合并添加偏置项 b \mathbf{b} b。
这个公式可以拆解为以下几个步骤:
- 为邻接矩阵添加自环(self-loops)。
- 对节点特征矩阵进行线性变换。
- 计算归一化系数。
- 对特征进行归一化处理。
- 聚合邻居节点特征(使用"加和"操作,
"add"
聚合)。 - 对聚合结果加上最终的偏置项。
在实现过程中:
- 步骤 1-3 通常在消息传递(message passing)前完成。
- 步骤 4-5 使用
MessagePassing
基类轻松实现。
以下是完整的 GCN 层实现代码:
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add') # "加和"聚合 (Step 5)
self.lin = Linear(in_channels, out_channels, bias=False)
self.bias = Parameter(torch.empty(out_channels))
self.reset_parameters()
def reset_parameters(self):
self.lin.reset_parameters()
self.bias.data.zero_()
def forward(self, x, edge_index):
# Step 1: 添加自环到邻接矩阵
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: 对节点特征进行线性变换
x = self.lin(x)
# Step 3: 计算归一化系数
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4-5: 开始消息传递
out = self.propagate(edge_index, x=x, norm=norm)
# Step 6: 添加最终的偏置项
out = out + self.bias
return out
def message(self, x_j, norm):
# 对节点特征进行归一化 (Step 4)
return norm.view(-1, 1) * x_j
实现 EdgeConv(边卷积)
边卷积用于处理图结构或点云,其数学定义为:
x i ( k ) = max j ∈ N ( i ) h Θ ( x i ( k − 1 ) , x j ( k − 1 ) − x i ( k − 1 ) ) \mathbf{x}_i^{(k)} = \max_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}} \left( \mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)} - \mathbf{x}_i^{(k-1)} \right) xi(k)=j∈N(i)maxhΘ(xi(k−1),xj(k−1)−xi(k−1))
其中,
h
Θ
h_{\mathbf{\Theta}}
hΘ 是一个多层感知机(MLP)。
与 GCN 类似,EdgeConv
层也基于 MessagePassing
实现,但使用的是 "max"
聚合方式。
以下是 EdgeConv
的实现代码:
import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing
class EdgeConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super().__init__(aggr='max') # "最大值" 聚合
self.mlp = Seq(Linear(2 * in_channels, out_channels),
ReLU(),
Linear(out_channels, out_channels))
def forward(self, x, edge_index):
return self.propagate(edge_index, x=x)
def message(self, x_i, x_j):
# 计算相对特征并输入到 MLP
tmp = torch.cat([x_i, x_j - x_i], dim=1)
return self.mlp(tmp)
EdgeConv
实际上是一个动态卷积,每一层都在特征空间中根据最近邻重新计算图。
PyG 提供了一个 GPU 加速的 k-NN 图生成方法 knn_graph
:
from torch_geometric.nn import knn_graph
class DynamicEdgeConv(EdgeConv):
def __init__(self, in_channels, out_channels, k=6):
super().__init__(in_channels, out_channels)
self.k = k
def forward(self, x, batch=None):
edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)
return super().forward(x, edge_index)
DynamicEdgeConv
动态生成 k-NN 图,然后调用 EdgeConv
的 forward
方法。
练习题翻译
关于 GCNConv:
row
和col
包含什么信息?degree
方法的作用是什么?- 为什么用
degree(col, ...)
而不是degree(row, ...)
? deg_inv_sqrt[col]
和deg_inv_sqrt[row]
的作用是什么?- 在
message
方法中,x_j
包含什么信息?如果self.lin
是恒等函数,x_j
的内容具体是什么? - 添加一个
update
方法,使其将变换后的中心节点特征添加到聚合输出中。
关于 EdgeConv:
x_i
和x_j - x_i
是什么?torch.cat([x_i, x_j - x_i], dim=1)
的作用是什么?为什么是dim=1
?