当前位置: 首页 > article >正文

torch_geometric使用手册-Creating Message Passing Networks(专题二)

创建消息传递网络 (Message Passing Networks)

在图神经网络中,将卷积操作推广到不规则域通常表现为一种邻域聚合 (neighborhood aggregation)消息传递 (message passing) 机制。
这一机制通过聚合节点的邻居信息,更新每个节点的特征。

以下公式描述了消息传递机制的基本形式:

公式解释

x i ( k ) = γ ( k ) ( x i ( k − 1 ) , ⨁ j ∈ N ( i )   ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \bigoplus_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right) xi(k)=γ(k) xi(k1),jN(i)ϕ(k)(xi(k1),xj(k1),ej,i)

  • x i ( k − 1 ) \mathbf{x}_i^{(k-1)} xi(k1): 表示第 k − 1 k-1 k1 层时节点 i i i 的特征。
  • e j , i \mathbf{e}_{j,i} ej,i: 表示从节点 j j j 到节点 i i i 的边特征(可选)。
  • N ( i ) \mathcal{N}(i) N(i): 节点 i i i 的邻居节点集合。
  • ϕ ( k ) \phi^{(k)} ϕ(k): 消息函数 (message function),生成从邻居节点 j j j 到节点 i i i 的消息。
  • ⨁ \bigoplus : 聚合函数 (aggregation function),例如加和 (sum)、均值 (mean) 或最大值 (max)。
  • γ ( k ) \gamma^{(k)} γ(k): 更新函数 (update function),结合节点本身的特征与聚合后的消息。

PyTorch Geometric (PyG) 提供了一个名为 MessagePassing 的基类,专门用于实现基于消息传递机制的图神经网络(GNN)。这个类封装了消息传递中的许多细节,开发者只需要定义核心函数,例如消息构造(message)、特征更新(update),以及选择合适的聚合方式(aggr),即可实现复杂的 GNN 算法。


核心概念与方法解析

1. 构造 MessagePassing 基类

MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)
功能
  • 定义消息传递的聚合方式:

    • aggr: 表示如何将来自邻居节点的消息聚合到目标节点。
      • add(加和):计算邻居节点消息的加权和。
      • mean(平均):取邻居节点消息的加权平均值。
      • max(最大值):选择邻居节点消息的最大值。
  • 定义消息的传递方向:

    • flow:
      • "source_to_target":从源节点传递消息到目标节点。
      • "target_to_source":从目标节点向源节点传递消息。
  • node_dim:

    • 指定在哪一维度上传递节点特征。通常是倒数第二维(默认为 -2),适配节点特征张量。

2. 消息传递的入口:propagate 方法

MessagePassing.propagate(edge_index, size=None, **kwargs)
功能
  • 触发消息传递过程,从边索引和输入特征开始,依次执行:
    1. 消息构造message):生成从邻居节点传来的消息。
    2. 消息聚合aggregate,自动完成):将邻居节点的消息聚合到目标节点。
    3. 特征更新update):更新目标节点的最终特征。

注意: 这是入口函数,类似forward的操作,会调用messageaggregateupdate函数.

参数
  • edge_index:

    • 图的边索引,形状为 [2, num_edges]
    • 第一行表示源节点,第二行表示目标节点。
  • size:

    • 图中节点的数量或维度。
    • 对于普通图,默认为 [num_nodes, num_nodes];对于二分图(bipartite graph),可以传递 (N, M),分别表示源节点和目标节点数量。
  • kwargs:

    • 其他参数,如节点特征 x,边特征 edge_attr 等。

3. 消息生成:message 方法

MessagePassing.message(...)
功能
  • 根据每条边的两端节点特征(源节点和目标节点)以及边特征,生成要传递的消息。
参数
  • 默认情况下:
    • x_j: 源节点的特征。
    • x_i: 目标节点的特征。
    • edge_attr: 边的特征(如果存在)。
自动变量映射

propagate 内部,会根据 edge_index 自动将输入特征分为:

  • x_j:从源节点出发的特征。
  • x_i:传递到目标节点的特征。

4. 特征更新:update 方法

MessagePassing.update(aggr_out, ...)
功能
  • 根据聚合后的结果 aggr_out,计算目标节点的最终特征。
参数
  • aggr_out: 聚合后的邻居节点消息。
  • 可以使用其他参数,例如目标节点本身的初始特征。

5. 应用流程总结

  1. 消息生成

    • 根据边和节点特征,生成从邻居节点传递的消息(通过 message 方法)。
  2. 消息聚合

    • 使用选定的聚合方式(aggr 参数,如加和或平均),将消息聚合到目标节点。
  3. 特征更新

    • 在目标节点上应用更新规则,生成最终的节点特征(通过 update 方法)。

示例:实现经典的 GCN 和 EdgeConv

实现 GCN 层(Graph Convolutional Layer)

GCN 层的数学定义如下:

x i ( k ) = ∑ j ∈ N ( i ) ∪ { i } 1 deg ⁡ ( i ) ⋅ deg ⁡ ( j ) ⋅ ( W ⊤ ⋅ x j ( k − 1 ) ) + b \mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{W}^{\top} \cdot \mathbf{x}_j^{(k-1)} \right) + \mathbf{b} xi(k)=jN(i){i}deg(i) deg(j) 1(Wxj(k1))+b

  • 邻居节点特征通过一个权重矩阵 W \mathbf{W} W 进行变换。
  • 然后,按照节点度进行归一化。
  • 最后,对邻居节点特征进行聚合并添加偏置项 b \mathbf{b} b

这个公式可以拆解为以下几个步骤:

  1. 为邻接矩阵添加自环(self-loops)
  2. 对节点特征矩阵进行线性变换
  3. 计算归一化系数
  4. 对特征进行归一化处理
  5. 聚合邻居节点特征(使用"加和"操作,"add" 聚合)。
  6. 对聚合结果加上最终的偏置项

在实现过程中:

  • 步骤 1-3 通常在消息传递(message passing)前完成。
  • 步骤 4-5 使用 MessagePassing 基类轻松实现。

以下是完整的 GCN 层实现代码:

import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "加和"聚合 (Step 5)
        self.lin = Linear(in_channels, out_channels, bias=False)
        self.bias = Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # Step 1: 添加自环到邻接矩阵
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: 对节点特征进行线性变换
        x = self.lin(x)

        # Step 3: 计算归一化系数
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: 开始消息传递
        out = self.propagate(edge_index, x=x, norm=norm)

        # Step 6: 添加最终的偏置项
        out = out + self.bias

        return out

    def message(self, x_j, norm):
        # 对节点特征进行归一化 (Step 4)
        return norm.view(-1, 1) * x_j

实现 EdgeConv(边卷积)

边卷积用于处理图结构或点云,其数学定义为:

x i ( k ) = max ⁡ j ∈ N ( i ) h Θ ( x i ( k − 1 ) , x j ( k − 1 ) − x i ( k − 1 ) ) \mathbf{x}_i^{(k)} = \max_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}} \left( \mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)} - \mathbf{x}_i^{(k-1)} \right) xi(k)=jN(i)maxhΘ(xi(k1),xj(k1)xi(k1))

其中, h Θ h_{\mathbf{\Theta}} hΘ 是一个多层感知机(MLP)。
与 GCN 类似,EdgeConv 层也基于 MessagePassing 实现,但使用的是 "max" 聚合方式。

以下是 EdgeConv 的实现代码:

import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing

class EdgeConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='max')  # "最大值" 聚合
        self.mlp = Seq(Linear(2 * in_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels))

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        # 计算相对特征并输入到 MLP
        tmp = torch.cat([x_i, x_j - x_i], dim=1)
        return self.mlp(tmp)

EdgeConv 实际上是一个动态卷积,每一层都在特征空间中根据最近邻重新计算图。
PyG 提供了一个 GPU 加速的 k-NN 图生成方法 knn_graph

from torch_geometric.nn import knn_graph

class DynamicEdgeConv(EdgeConv):
    def __init__(self, in_channels, out_channels, k=6):
        super().__init__(in_channels, out_channels)
        self.k = k

    def forward(self, x, batch=None):
        edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)
        return super().forward(x, edge_index)

DynamicEdgeConv 动态生成 k-NN 图,然后调用 EdgeConvforward 方法。


练习题翻译

关于 GCNConv:

  1. rowcol 包含什么信息?
  2. degree 方法的作用是什么?
  3. 为什么用 degree(col, ...) 而不是 degree(row, ...)
  4. deg_inv_sqrt[col]deg_inv_sqrt[row] 的作用是什么?
  5. message 方法中,x_j 包含什么信息?如果 self.lin 是恒等函数,x_j 的内容具体是什么?
  6. 添加一个 update 方法,使其将变换后的中心节点特征添加到聚合输出中。

关于 EdgeConv:

  1. x_ix_j - x_i 是什么?
  2. torch.cat([x_i, x_j - x_i], dim=1) 的作用是什么?为什么是 dim=1

http://www.kler.cn/a/407730.html

相关文章:

  • width设置100vh但出现横向滚动条的问题
  • LeetCode 力扣 热题 100道(九)反转链表(C++)
  • 后端开发详细学习框架与路线
  • Thymeleaf模板引擎生成的html字符串转换成pdf
  • Python + 深度学习从 0 到 1(00 / 99)
  • 安宝特分享 | 如何利用AR技术革新医疗实践:从远程急救到多学科协作
  • Docker 配置 HTTP 和 HTTPS 网络代理
  • 【MATLAB蓝牙定位代码】三维平面定位设计,通过N个蓝牙锚点实现对未知位置的精准定位
  • (STM32)ADC驱动配置
  • [RabbitMQ] 重试机制+TTL+死信队列
  • vue3---watch监听
  • 什么是沙箱(Sandbox)技术
  • 图像处理-简单的图像操作
  • # linux 清理指定目录下,指定时间的历史文件
  • ssm旅游推荐系统的设计与开发
  • Oracle SQL优化③——表的连接方式
  • 【数据结构-队列】力扣225. 用队列实现栈
  • 人工智能之机器学习5-回归算法1【培训机构学习笔记】
  • 【STM32】启动配置和自动串口下载
  • 性能监控利器:Ubuntu 22.04 上的 Zabbix 安装与配置指南
  • windows实现VNC连接ubuntu22.04服务器
  • Java 基础知识 (集合框架 + 并发编程 + JVM 原理 + 数据结构与算法)
  • 2023年下半年信息安全工程师《案例分析》真题答案(2)
  • 移远通信推出全新5G RedCap模组RG255AA系列,以更高性价比加速5G轻量化大规模商用
  • 中小企业人事管理自动化:SpringBoot实践
  • Oracle分析表和索引