使用Pytorch Geometric建立异构图HeteroData数据集
from torch_geometric.data import HeteroData
import torch
# 创建一个 HeteroData 对象
data = HeteroData()
# 添加类型为 'type1' 的节点,这些节点有2个特征
data['user'].point = [1,3]
data['comp'].point = torch.randn(1, 2) # 假设有1个这样的节点
data['process'].point = torch.randn(1, 2)
# 添加从 'type1' 到 'type2' 的边,边的类型为 'logon'
data['user', 'logon', 'comp'].edge_index = [2, 4] # 假设有1条边,其特征为[1, 4]
# 添加从 'type2' 到 'type3' 的边,边的类型为 'create'
data['comp', 'create', 'process'].edge_index = torch.randint(200, (2, 1)) # 假设有1条边,其特征为随机值生成的tensor
print()
类似字典,.point会使得data内部的keys增一个'point',key叫什么可以自己定义
data['user']会使得data内部node_types增加'user'
点和边对应的值(如[1,3] 或 torch.randn(1,2))会各自存在stores和edges_stores下,一般存储的类型是tensor,这里为了方便对比学习,在这用了一个数组