当前位置: 首页 > article >正文

DGL库之dgl.function.u_mul_e(代替dgl.function.src_mul_edge)

DGL库之dgl.function.u_mul_e

  • 语法格式
  • 例子

语法格式

dgl.function.u_mul_e代替了dgl.function.src_mul_edge

dgl.function.u_mul_e(lhs_field, rhs_field, out)

一个用于计算消息传递的内置函数,它通过对源节点(u)和边(e)的特征执行逐元素(element-wise)乘法操作来生成消息。如果源节点和边的特征形状相同,则直接进行逐元素乘法;如果形状不同,它会先广播特征到新的形状,然后再进行逐元素操作。(广播和NumPy规则一致)

参数说明:

  • lhs_field (str):源节点 u 的特征字段名称。
  • rhs_field (str):边 e 的特征字段名称。
  • out (str):输出消息字段的名称。

例子

图结构如下:
在这里插入图片描述

import dgl
import torch
import dgl.function as fn

# 创建图,节点和边
g = dgl.graph(([0, 1, 2], [1, 2, 0]))

# 给边设置特征 'e_feat',确保它是浮动类型
g.edata['e_feat'] = torch.tensor([2000, 3000, 4000], dtype=torch.float32)

# 给节点设置特征 'n_feat',确保它是浮动类型
g.ndata['n_feat'] = torch.tensor([20, 21, 22], dtype=torch.float32)

# 使用 dgl.function.u_mul_e 计算消息(逐元素乘法)
g.apply_edges(fn.u_mul_e('n_feat', 'e_feat', 'msg'))

# 查看计算出来的消息
print("消息(msg):")
print(g.edata['msg'])

# 定义消息聚合函数:对消息求和
def reduce_sum(nodes):
    # torch.sum目的是将数据转成列表格式
    return {'n_feat': torch.sum(nodes.mailbox['msg'], dim=1)}

# 使用 send_and_recv 发送消息,并在目标节点聚合
g.send_and_recv(g.edges(), fn.u_mul_e('n_feat', 'e_feat', 'msg'), reduce_sum)

# 查看目标节点的特征(经过消息聚合更新后的节点特征)
print("目标节点更新后的特征(n_feat):")
print(g.ndata['n_feat'])

代码详解:

  • 使用 dgl.function.u_mul_e 计算消息,在每条边上,使用 u_mul_e 函数计算消息,消息是节点特征和边特征的逐元素乘积。‘n_feat’ 是节点特征字段,‘e_feat’ 是边特征字段,计算结果存储在 msg 字段中,这表示每条边上生成的消息是:

    • 节点0和边1的消息:20 * 2000 = 40000
    • 节点1和边2的消息:21 * 3000 = 63000
    • 节点2和边0的消息:22 * 4000 = 88000
  • reduce_sum 是一个聚合函数,将每个节点收到的消息赋值给节点特征 ‘n_feat’。

    • 节点0 接收到的消息是来自边 (2, 0) 的消息 88000,所以节点0的聚合结果是:88000。
    • 节点1 接收到的消息是来自边 (0, 1) 的消息 40000,所以节点1的聚合结果是:40000。
    • 节点2 接收到的消息是来自边 (1, 2) 的消息 63000,所以节点2的聚合结果是:63000。
  • send_and_recv函数发送和接收消息,每个节点的特征通过其入边上的消息来更新。

    • 节点0的更新值是来自边 (2, 0) 的消息 88000。
    • 节点1的更新值是来自边 (0, 1) 的消息 40000。
    • 节点2的更新值是来自边 (1, 2) 的消息 63000。

最终代码执行结果如下:
在这里插入图片描述


http://www.kler.cn/a/382350.html

相关文章:

  • 在Java中,实现数据库连接通常使用JDBC
  • 基于梯度的快速准确头部运动补偿方法在锥束CT中的应用|文献速递-基于深度学习的病灶分割与数据超分辨率
  • 工作中问题
  • C# 独立线程
  • 机器学习—构建一个神经网络
  • 【Java语言】继承和多态(一)
  • 模拟实现strcat函数
  • 线程池核心参数有哪些
  • Vue 组件传递数据-Props(六)
  • Vue+Springboot 前后端分离项目如何部署?
  • 【FPGA】Verilog:理解德摩根第一定律: ( ̅A + ̅B) = ̅A x ̅B
  • 爬虫下载网页文夹
  • 【C++刷题】力扣-#697-数组的度
  • 【人工智能】Transformers之Pipeline(二十二):零样本文本分类(zero-shot-classification)
  • 7.2 设计模式
  • [WSL][桌面][X11]WSL2 Ubuntu22.04 安装Ubuntu桌面并且实现GUI转发(Gnome)
  • 【论文阅读】-- 多元时间序列聚类算法综述
  • Sigrity Power SI 3D-EM Full Wave Extraction模式如何进行S参数提取和观测3D电磁场和远场操作指导(一)
  • “再探构造函数”(2)
  • 解释器模式:有效处理语言的设计模式
  • Redis 权限控制(ACL)|ACL 命令详解、ACL 持久化
  • 【题解】CF2033G
  • ThinkPHP腾讯云国际短信对接
  • W5100S-EVB-Pico2评估板介绍
  • 史上最全盘点:一文告诉你低代码(Low-Code)是什么?为什么要用?
  • 【青牛科技】GC8549替代LV8549/ONSEMI在摇头机、舞台灯、打印机和白色家电等产品上的应用分析