当前位置: 首页 > article >正文

【DGL系列】dgl中为graph指定CSR/COO/CSC矩阵格式

转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn]

如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~

只讲几个注意事项:

1、graph.formats() 函数可以查看graph格式,也可以指定graph格式。

g = dgl.graph(([0, 0, 1], [2, 3, 2]))
g.ndata['h'] = torch.ones(4, 1)

# 查看格式
g.formats()
# => {'created': ['coo'], 'not created': ['csr', 'csc']}

# 指定一种格式
csr_g = g.formats('csr')
csr_g.formats()
# => {'created': ['csr'], 'not created': []}

# 指定多种格式
new_g = g.formats(['coo', 'csr'])
new_g .formats()
# => {'created': ['coo', 'csr'], 'not created': []}

2、在调用 formats(['coo', 'csr']) 时,如果当前图的格式与指定格式没有交集,DGL 会按照 coo -> csr -> csc 的顺序选择一种格式创建。因此,如果图在反序列化后没有 CSR 格式,调用 formats(['coo', 'csr']) 可能只会创建 COO 格式。

g = dgl.graph(([0, 0, 1], [2, 3, 2]))
g.ndata['h'] = torch.ones(4, 1)

# 假设只有一种格式
g.formats()
# => {'created': ['coo'], 'not created': ['csc']}

# 交集没有csr,就不会设置成功
new_g = g.formats(['coo', 'csr'])
new_g .formats()
# => {'created': ['coo'], 'not created': []}

3、上述第2点,虽然没有指定格式,但是可以通过graph.create_formats_来显式创建。

g = dgl.graph(([0, 0, 1], [2, 3, 2]))
g.ndata['h'] = torch.ones(4, 1)

# 假设只有一种coo格式
g.formats()
# => {'created': ['coo'], 'not created': ['csc']}

# 交集没有csr,就不会设置成功
new_g = g.formats(['coo', 'csr'])
new_g .formats()
# => {'created': ['coo'], 'not created': ['csr']}

# 显式创建格式
new_g.create_formats_()
print(new_g.formats())
# => {'created': ['coo', 'csr'], 'not created': []}

4、使用 pickle 对 DGL 图对象进行序列化和反序列化后,图的存储格式可能会丢失或被重置为 COO 格式。

import dgl
import pickle

# 创建一个图并设置多种格式
g = dgl.graph(([0, 1, 2], [1, 2, 3]))
g = g.formats(['coo', 'csr', 'csc'])

# 使用 pickle 保存
with open('graph.pkl', 'wb') as f:
    pickle.dump(g, f)

# 使用 pickle 加载
with open('graph.pkl', 'rb') as f:
    loaded_g = pickle.load(f)

# 检查加载后的格式
print(loaded_g.formats())  # 可能会丢失某些格式

5、可以考虑使用 DGL 提供的保存dgl.save_graphs和加载dgl.load_graphs方法,这些方法能够更好地处理图的内部状态,包括稀疏格式。

# 保存图
dgl.save_graphs("graph.bin", [graph])

# 加载图
loaded_graphs, _ = dgl.load_graphs("graph.bin")
graph = loaded_graphs[0]


http://www.kler.cn/a/519324.html

相关文章:

  • 第20篇:Python 开发进阶:使用Django进行Web开发详解
  • 深度学习项目--基于LSTM的糖尿病预测探究(pytorch实现)
  • 深度学习 Pytorch 单层神经网络
  • Windows Defender添加排除项无权限的解决方法
  • 2024年博客之星主题创作|2024年蓝桥杯与数学建模年度总结与心得
  • 浏览器hid 和蓝牙bluetooth技术区别
  • HTB:Support[WriteUP]
  • docker-制作镜像gcc添加jdk运行java程序
  • 2025-1-25 c++学习中关于static,初始化列表,友元函数和友元类的问题
  • 算法:模拟的巧妙演绎
  • 【MySQL】 表的操作
  • 思科交换机telnet配置案例
  • 第23篇:Python开发进阶:详解测试驱动开发(TDD)
  • ubuntu22.04 系统 A100显卡 深度学习环境配置记录
  • 嵌入式知识点总结 ARM体系与架构 专题提升(二)-ARM处理器
  • Smalltalk语言是何物?面向对象鼻祖Simula的诞生?Simula和Smalltalk有什么区别?面向对象设计?
  • 嵌入式C语言:回调函数
  • Java实现经典算法题之模拟双指针用法
  • xss靶场
  • 免费获取Photoshop及其他设计软件的使用权限
  • FastExcel的使用
  • STM32项目分享:智能语音台灯
  • 视频网站服务器为什么需要使用负载均衡?
  • Lsky-Pro在线图片搭建教程(Docker部署方式)
  • 系统思考—动态问题分析
  • AF3 AtomAttentionDecoder类源码解读