【DGL系列】dgl中为graph指定CSR/COO/CSC矩阵格式
转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn]
如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~
只讲几个注意事项:
1、graph.formats() 函数可以查看graph格式,也可以指定graph格式。
g = dgl.graph(([0, 0, 1], [2, 3, 2]))
g.ndata['h'] = torch.ones(4, 1)
# 查看格式
g.formats()
# => {'created': ['coo'], 'not created': ['csr', 'csc']}
# 指定一种格式
csr_g = g.formats('csr')
csr_g.formats()
# => {'created': ['csr'], 'not created': []}
# 指定多种格式
new_g = g.formats(['coo', 'csr'])
new_g .formats()
# => {'created': ['coo', 'csr'], 'not created': []}
2、在调用 formats(['coo', 'csr'])
时,如果当前图的格式与指定格式没有交集,DGL 会按照 coo -> csr -> csc
的顺序选择一种格式创建。因此,如果图在反序列化后没有 CSR 格式,调用 formats(['coo', 'csr'])
可能只会创建 COO 格式。
g = dgl.graph(([0, 0, 1], [2, 3, 2]))
g.ndata['h'] = torch.ones(4, 1)
# 假设只有一种格式
g.formats()
# => {'created': ['coo'], 'not created': ['csc']}
# 交集没有csr,就不会设置成功
new_g = g.formats(['coo', 'csr'])
new_g .formats()
# => {'created': ['coo'], 'not created': []}
3、上述第2点,虽然没有指定格式,但是可以通过graph.create_formats_来显式创建。
g = dgl.graph(([0, 0, 1], [2, 3, 2]))
g.ndata['h'] = torch.ones(4, 1)
# 假设只有一种coo格式
g.formats()
# => {'created': ['coo'], 'not created': ['csc']}
# 交集没有csr,就不会设置成功
new_g = g.formats(['coo', 'csr'])
new_g .formats()
# => {'created': ['coo'], 'not created': ['csr']}
# 显式创建格式
new_g.create_formats_()
print(new_g.formats())
# => {'created': ['coo', 'csr'], 'not created': []}
4、使用 pickle
对 DGL 图对象进行序列化和反序列化后,图的存储格式可能会丢失或被重置为 COO 格式。
import dgl
import pickle
# 创建一个图并设置多种格式
g = dgl.graph(([0, 1, 2], [1, 2, 3]))
g = g.formats(['coo', 'csr', 'csc'])
# 使用 pickle 保存
with open('graph.pkl', 'wb') as f:
pickle.dump(g, f)
# 使用 pickle 加载
with open('graph.pkl', 'rb') as f:
loaded_g = pickle.load(f)
# 检查加载后的格式
print(loaded_g.formats()) # 可能会丢失某些格式
5、可以考虑使用 DGL 提供的保存dgl.save_graphs和加载dgl.load_graphs方法,这些方法能够更好地处理图的内部状态,包括稀疏格式。
# 保存图
dgl.save_graphs("graph.bin", [graph])
# 加载图
loaded_graphs, _ = dgl.load_graphs("graph.bin")
graph = loaded_graphs[0]