DGL库之dgl.function.u_mul_e(代替dgl.function.src_mul_edge)
DGL库之dgl.function.u_mul_e
- 语法格式
- 例子
语法格式
dgl.function.u_mul_e代替了dgl.function.src_mul_edge
dgl.function.u_mul_e(lhs_field, rhs_field, out)
一个用于计算消息传递的内置函数,它通过对源节点(u)和边(e)的特征执行逐元素(element-wise)乘法操作来生成消息。如果源节点和边的特征形状相同,则直接进行逐元素乘法;如果形状不同,它会先广播特征到新的形状,然后再进行逐元素操作。(广播和NumPy规则一致)
参数说明:
- lhs_field (str):源节点 u 的特征字段名称。
- rhs_field (str):边 e 的特征字段名称。
- out (str):输出消息字段的名称。
例子
图结构如下:
import dgl
import torch
import dgl.function as fn
# 创建图,节点和边
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
# 给边设置特征 'e_feat',确保它是浮动类型
g.edata['e_feat'] = torch.tensor([2000, 3000, 4000], dtype=torch.float32)
# 给节点设置特征 'n_feat',确保它是浮动类型
g.ndata['n_feat'] = torch.tensor([20, 21, 22], dtype=torch.float32)
# 使用 dgl.function.u_mul_e 计算消息(逐元素乘法)
g.apply_edges(fn.u_mul_e('n_feat', 'e_feat', 'msg'))
# 查看计算出来的消息
print("消息(msg):")
print(g.edata['msg'])
# 定义消息聚合函数:对消息求和
def reduce_sum(nodes):
# torch.sum目的是将数据转成列表格式
return {'n_feat': torch.sum(nodes.mailbox['msg'], dim=1)}
# 使用 send_and_recv 发送消息,并在目标节点聚合
g.send_and_recv(g.edges(), fn.u_mul_e('n_feat', 'e_feat', 'msg'), reduce_sum)
# 查看目标节点的特征(经过消息聚合更新后的节点特征)
print("目标节点更新后的特征(n_feat):")
print(g.ndata['n_feat'])
代码详解:
-
使用 dgl.function.u_mul_e 计算消息,在每条边上,使用 u_mul_e 函数计算消息,消息是节点特征和边特征的逐元素乘积。‘n_feat’ 是节点特征字段,‘e_feat’ 是边特征字段,计算结果存储在 msg 字段中,这表示每条边上生成的消息是:
- 节点0和边1的消息:20 * 2000 = 40000
- 节点1和边2的消息:21 * 3000 = 63000
- 节点2和边0的消息:22 * 4000 = 88000
-
reduce_sum 是一个聚合函数,将每个节点收到的消息赋值给节点特征 ‘n_feat’。
- 节点0 接收到的消息是来自边 (2, 0) 的消息 88000,所以节点0的聚合结果是:88000。
- 节点1 接收到的消息是来自边 (0, 1) 的消息 40000,所以节点1的聚合结果是:40000。
- 节点2 接收到的消息是来自边 (1, 2) 的消息 63000,所以节点2的聚合结果是:63000。
-
send_and_recv函数发送和接收消息,每个节点的特征通过其入边上的消息来更新。
- 节点0的更新值是来自边 (2, 0) 的消息 88000。
- 节点1的更新值是来自边 (0, 1) 的消息 40000。
- 节点2的更新值是来自边 (1, 2) 的消息 63000。
最终代码执行结果如下: