LightRAG源码:NetworkXStorage测试(1)
目录
- 测试代码
- 代码运行逻辑解释
- 1. 导入依赖库
- 2. 定义工作目录
- 3. 定义 `setup_teardown` 夹具
- 4. 定义 `mock_embedding` 函数
- 5. 定义 `networkx_storage` 夹具
- 6. 测试函数
- 6.1 `test_upsert_and_get_node`
- 6.2 `test_upsert_and_get_edge`
- 6.3 `test_node_degree`
- 6.4 `test_edge_degree`
- 6.5 `test_get_node_edges`
- 7. 总结
测试代码
nano-graphrag跟lightrag代码类似,测试案例来自与nano-graphrag
import os
import shutil
import pytest
import networkx as nx
import numpy as np
import asyncio
import json
from lightrag import LightRAG
from lightrag.storage import NetworkXStorage
from lightrag.utils import wrap_embedding_func_with_attrs
WORKING_DIR = "./tests/nano_graphrag_cache_networkx_storage_test"
@pytest.fixture(scope="function")
def setup_teardown():
if os.path.exists(WORKING_DIR):
shutil.rmtree(WORKING_DIR)
os.mkdir(WORKING_DIR)
yield
shutil.rmtree(WORKING_DIR)
@wrap_embedding_func_with_attrs(embedding_dim=384, max_token_size=8192)
async def mock_embedding(texts: list[str]) -> np.ndarray:
return np.random.rand(len(texts), 384)
@pytest.fixture
def networkx_storage(setup_teardown):
rag = LightRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding)
return NetworkXStorage(
namespace="test",
global_config=rag.__dict__,
)
@pytest.mark.asyncio
async def test_upsert_and_get_node(networkx_storage):
node_id = "node1"
node_data = {"attr1": "value1", "attr2": "value2"}
await networkx_storage.upsert_node(node_id, node_data)
result = await networkx_storage.get_node(node_id)
assert result == node_data
has_node = await networkx_storage.has_node(node_id)
assert has_node is True
@pytest.mark.asyncio
async def test_upsert_and_get_edge(networkx_storage):
source_id = "node1"
target_id = "node2"
edge_data = {"weight": 1.0, "type": "connection"}
await networkx_storage.upsert_node(source_id, {})
await networkx_storage.upsert_node(target_id, {})
await networkx_storage.upsert_edge(source_id, target_id, edge_data)
result = await networkx_storage.get_edge(source_id, target_id)
assert result == edge_data
has_edge = await networkx_storage.has_edge(source_id, target_id)
assert has_edge is True
@pytest.mark.asyncio
async def test_node_degree(networkx_storage):
node_id = "center"
await networkx_storage.upsert_node(node_id, {})
num_neighbors = 5
for i in range(num_neighbors):
neighbor_id = f"neighbor{i}"
await networkx_storage.upsert_node(neighbor_id, {})
await networkx_storage.upsert_edge(node_id, neighbor_id, {})
degree = await networkx_storage.node_degree(node_id)
assert degree == num_neighbors
@pytest.mark.asyncio
async def test_edge_degree(networkx_storage):
source_id = "node1"
target_id = "node2"
await networkx_storage.upsert_node(source_id, {})
await networkx_storage.upsert_node(target_id, {})
await networkx_storage.upsert_edge(source_id, target_id, {})
num_source_neighbors = 3
for i in range(num_source_neighbors):
neighbor_id = f"neighbor{i}"
await networkx_storage.upsert_node(neighbor_id, {})
await networkx_storage.upsert_edge(source_id, neighbor_id, {})
num_target_neighbors = 2
for i in range(num_target_neighbors):
neighbor_id = f"target_neighbor{i}"
await networkx_storage.upsert_node(neighbor_id, {})
await networkx_storage.upsert_edge(target_id, neighbor_id, {})
expected_edge_degree = (num_source_neighbors + 1) + (num_target_neighbors + 1)
edge_degree = await networkx_storage.edge_degree(source_id, target_id)
assert edge_degree == expected_edge_degree
@pytest.mark.asyncio
async def test_get_node_edges(networkx_storage):
center_id = "center"
await networkx_storage.upsert_node(center_id, {})
expected_edges = []
for i in range(3):
neighbor_id = f"neighbor{i}"
await networkx_storage.upsert_node(neighbor_id, {})
await networkx_storage.upsert_edge(center_id, neighbor_id, {})
expected_edges.append((center_id, neighbor_id))
result = await networkx_storage.get_node_edges(center_id)
assert set(result) == set(expected_edges)
代码运行逻辑解释
1. 导入依赖库
os
和shutil
:用于文件和目录操作。pytest
:用于编写和运行测试。networkx as nx
:用于创建和操作图结构。numpy as np
:用于生成随机数。asyncio
:用于异步编程。json
:用于处理JSON数据。lightrag
相关模块:用于实现RAG(Retrieval-Augmented Generation)模型的相关功能。
2. 定义工作目录
WORKING_DIR
:定义了一个工作目录路径,用于存储测试过程中生成的文件。
3. 定义 setup_teardown
夹具
- 该夹具在每个测试函数运行前后执行。
setup
:在测试开始前,检查并删除工作目录(如果存在),然后重新创建该目录。teardown
:在测试结束后,删除工作目录及其内容。
4. 定义 mock_embedding
函数
- 这是一个模拟的嵌入函数,用于生成随机嵌入向量。
- 使用
@wrap_embedding_func_with_attrs
装饰器,指定嵌入向量的维度和最大token大小。 - 返回一个随机生成的
numpy
数组,形状为(len(texts), 384)
。
- 使用
5. 定义 networkx_storage
夹具
- 该夹具用于创建
NetworkXStorage
实例。- 使用
LightRAG
类初始化NetworkXStorage
,并传入工作目录和模拟的嵌入函数。 - 返回一个
NetworkXStorage
实例,用于后续测试。
- 使用
6. 测试函数
6.1 test_upsert_and_get_node
- 功能:测试节点的插入和获取。
- 插入一个节点
node1
,并为其添加一些属性。 - 使用
get_node
方法获取该节点,并验证返回的数据是否正确。 - 使用
has_node
方法检查节点是否存在。
- 插入一个节点
6.2 test_upsert_and_get_edge
- 功能:测试边的插入和获取。
- 插入两个节点
node1
和node2
。 - 插入一条边,并为其添加一些属性。
- 使用
get_edge
方法获取该边,并验证返回的数据是否正确。 - 使用
has_edge
方法检查边是否存在。
- 插入两个节点
6.3 test_node_degree
- 功能:测试节点的度数(即节点的邻居数量)。
- 插入一个中心节点
center
。 - 插入多个邻居节点,并将它们与中心节点连接。
- 使用
node_degree
方法获取中心节点的度数,并验证其是否正确。
- 插入一个中心节点
6.4 test_edge_degree
- 功能:测试边的度数(即与边相连的节点总数)。
- 插入两个节点
node1
和node2
,并将它们连接。 - 插入多个邻居节点,并将它们分别与
node1
和node2
连接。 - 使用
edge_degree
方法获取边的度数,并验证其是否正确。
- 插入两个节点
6.5 test_get_node_edges
- 功能:测试获取节点的所有边。
- 插入一个中心节点
center
。 - 插入多个邻居节点,并将它们与中心节点连接。
- 使用
get_node_edges
方法获取中心节点的所有边,并验证返回的边是否正确。
- 插入一个中心节点
7. 总结
- 这些测试函数主要用于验证
NetworkXStorage
类的功能,包括节点的插入、获取、边的插入、获取、节点和边的度数计算等。 - 通过
pytest
和asyncio
的结合,实现了异步测试,确保代码在异步环境下的正确性。
这些测试代码的主要目的是验证 NetworkXStorage
类的功能,确保其能够正确地处理节点和边的插入、获取、度数计算等操作。