当前位置: 首页 > article >正文

LightRAG源码:NetworkXStorage测试(1)

目录

    • 测试代码
    • 代码运行逻辑解释
      • 1. 导入依赖库
      • 2. 定义工作目录
      • 3. 定义 `setup_teardown` 夹具
      • 4. 定义 `mock_embedding` 函数
      • 5. 定义 `networkx_storage` 夹具
      • 6. 测试函数
        • 6.1 `test_upsert_and_get_node`
        • 6.2 `test_upsert_and_get_edge`
        • 6.3 `test_node_degree`
        • 6.4 `test_edge_degree`
        • 6.5 `test_get_node_edges`
      • 7. 总结

测试代码

nano-graphrag跟lightrag代码类似,测试案例来自与nano-graphrag

import os
import shutil
import pytest
import networkx as nx
import numpy as np
import asyncio

import json
from lightrag import LightRAG

from lightrag.storage import NetworkXStorage
from lightrag.utils import wrap_embedding_func_with_attrs

WORKING_DIR = "./tests/nano_graphrag_cache_networkx_storage_test"


@pytest.fixture(scope="function")
def setup_teardown():
    if os.path.exists(WORKING_DIR):
        shutil.rmtree(WORKING_DIR)
    os.mkdir(WORKING_DIR)

    yield

    shutil.rmtree(WORKING_DIR)


@wrap_embedding_func_with_attrs(embedding_dim=384, max_token_size=8192)
async def mock_embedding(texts: list[str]) -> np.ndarray:
    return np.random.rand(len(texts), 384)


@pytest.fixture
def networkx_storage(setup_teardown):
    rag = LightRAG(working_dir=WORKING_DIR, embedding_func=mock_embedding)
    return NetworkXStorage(
        namespace="test",
        global_config=rag.__dict__,
    )


@pytest.mark.asyncio
async def test_upsert_and_get_node(networkx_storage):
    node_id = "node1"
    node_data = {"attr1": "value1", "attr2": "value2"}

    await networkx_storage.upsert_node(node_id, node_data)

    result = await networkx_storage.get_node(node_id)
    assert result == node_data

    has_node = await networkx_storage.has_node(node_id)
    assert has_node is True


@pytest.mark.asyncio
async def test_upsert_and_get_edge(networkx_storage):
    source_id = "node1"
    target_id = "node2"
    edge_data = {"weight": 1.0, "type": "connection"}

    await networkx_storage.upsert_node(source_id, {})
    await networkx_storage.upsert_node(target_id, {})
    await networkx_storage.upsert_edge(source_id, target_id, edge_data)

    result = await networkx_storage.get_edge(source_id, target_id)
    assert result == edge_data

    has_edge = await networkx_storage.has_edge(source_id, target_id)
    assert has_edge is True


@pytest.mark.asyncio
async def test_node_degree(networkx_storage):
    node_id = "center"
    await networkx_storage.upsert_node(node_id, {})

    num_neighbors = 5
    for i in range(num_neighbors):
        neighbor_id = f"neighbor{i}"
        await networkx_storage.upsert_node(neighbor_id, {})
        await networkx_storage.upsert_edge(node_id, neighbor_id, {})

    degree = await networkx_storage.node_degree(node_id)
    assert degree == num_neighbors


@pytest.mark.asyncio
async def test_edge_degree(networkx_storage):
    source_id = "node1"
    target_id = "node2"

    await networkx_storage.upsert_node(source_id, {})
    await networkx_storage.upsert_node(target_id, {})
    await networkx_storage.upsert_edge(source_id, target_id, {})

    num_source_neighbors = 3
    for i in range(num_source_neighbors):
        neighbor_id = f"neighbor{i}"
        await networkx_storage.upsert_node(neighbor_id, {})
        await networkx_storage.upsert_edge(source_id, neighbor_id, {})

    num_target_neighbors = 2
    for i in range(num_target_neighbors):
        neighbor_id = f"target_neighbor{i}"
        await networkx_storage.upsert_node(neighbor_id, {})
        await networkx_storage.upsert_edge(target_id, neighbor_id, {})

    expected_edge_degree = (num_source_neighbors + 1) + (num_target_neighbors + 1)
    edge_degree = await networkx_storage.edge_degree(source_id, target_id)
    assert edge_degree == expected_edge_degree


@pytest.mark.asyncio
async def test_get_node_edges(networkx_storage):
    center_id = "center"
    await networkx_storage.upsert_node(center_id, {})

    expected_edges = []
    for i in range(3):
        neighbor_id = f"neighbor{i}"
        await networkx_storage.upsert_node(neighbor_id, {})
        await networkx_storage.upsert_edge(center_id, neighbor_id, {})
        expected_edges.append((center_id, neighbor_id))

    result = await networkx_storage.get_node_edges(center_id)
    assert set(result) == set(expected_edges)

代码运行逻辑解释

1. 导入依赖库

  • osshutil:用于文件和目录操作。
  • pytest:用于编写和运行测试。
  • networkx as nx:用于创建和操作图结构。
  • numpy as np:用于生成随机数。
  • asyncio:用于异步编程。
  • json:用于处理JSON数据。
  • lightrag 相关模块:用于实现RAG(Retrieval-Augmented Generation)模型的相关功能。

2. 定义工作目录

  • WORKING_DIR:定义了一个工作目录路径,用于存储测试过程中生成的文件。

3. 定义 setup_teardown 夹具

  • 该夹具在每个测试函数运行前后执行。
    • setup:在测试开始前,检查并删除工作目录(如果存在),然后重新创建该目录。
    • teardown:在测试结束后,删除工作目录及其内容。

4. 定义 mock_embedding 函数

  • 这是一个模拟的嵌入函数,用于生成随机嵌入向量。
    • 使用 @wrap_embedding_func_with_attrs 装饰器,指定嵌入向量的维度和最大token大小。
    • 返回一个随机生成的 numpy 数组,形状为 (len(texts), 384)

5. 定义 networkx_storage 夹具

  • 该夹具用于创建 NetworkXStorage 实例。
    • 使用 LightRAG 类初始化 NetworkXStorage,并传入工作目录和模拟的嵌入函数。
    • 返回一个 NetworkXStorage 实例,用于后续测试。

6. 测试函数

6.1 test_upsert_and_get_node
  • 功能:测试节点的插入和获取。
    • 插入一个节点 node1,并为其添加一些属性。
    • 使用 get_node 方法获取该节点,并验证返回的数据是否正确。
    • 使用 has_node 方法检查节点是否存在。
6.2 test_upsert_and_get_edge
  • 功能:测试边的插入和获取。
    • 插入两个节点 node1node2
    • 插入一条边,并为其添加一些属性。
    • 使用 get_edge 方法获取该边,并验证返回的数据是否正确。
    • 使用 has_edge 方法检查边是否存在。
6.3 test_node_degree
  • 功能:测试节点的度数(即节点的邻居数量)。
    • 插入一个中心节点 center
    • 插入多个邻居节点,并将它们与中心节点连接。
    • 使用 node_degree 方法获取中心节点的度数,并验证其是否正确。
6.4 test_edge_degree
  • 功能:测试边的度数(即与边相连的节点总数)。
    • 插入两个节点 node1node2,并将它们连接。
    • 插入多个邻居节点,并将它们分别与 node1node2 连接。
    • 使用 edge_degree 方法获取边的度数,并验证其是否正确。
6.5 test_get_node_edges
  • 功能:测试获取节点的所有边。
    • 插入一个中心节点 center
    • 插入多个邻居节点,并将它们与中心节点连接。
    • 使用 get_node_edges 方法获取中心节点的所有边,并验证返回的边是否正确。

7. 总结

  • 这些测试函数主要用于验证 NetworkXStorage 类的功能,包括节点的插入、获取、边的插入、获取、节点和边的度数计算等。
  • 通过 pytestasyncio 的结合,实现了异步测试,确保代码在异步环境下的正确性。

这些测试代码的主要目的是验证 NetworkXStorage 类的功能,确保其能够正确地处理节点和边的插入、获取、度数计算等操作。


http://www.kler.cn/a/514536.html

相关文章:

  • simulink入门学习01
  • Chrome远程桌面无法连接怎么解决?
  • StarRocks 3.4 发布--AI 场景新支点,Lakehouse 能力再升级
  • 《从入门到精通:蓝桥杯编程大赛知识点全攻略》(五)-数的三次方根、机器人跳跃问题、四平方和
  • linux如何并行执行命令
  • 在k8s中部署一个可外部访问的Redis Sentinel
  • vscode如何选用不同的python的解释器
  • Yii框架中的队列:如何实现异步操作
  • MySQL(1)概述
  • # [Unity] [游戏开发]基础协程应用与实现详解
  • 基于quartz,刷新定时器的cron表达式
  • R语言学习笔记之开发环境配置
  • Spring Boot 邂逅Netty:构建高性能网络应用的奇妙之旅
  • iOS 权限管理:同时请求相机和麦克风权限的最佳实践
  • 工业网关边缘计算:智能制造的强劲引擎
  • python学习笔记4-字符串和字节转换
  • 14_音乐播放服务_字典缓存避免重复加载
  • Dart语言的云计算
  • Linux 执行 fdisk -l 出现 GPT PMBR 大小不符 解决方法
  • 一部手机如何配置内网电脑同时访问内外网
  • 【面试题】Java 多线程编程基础知识
  • 分析一个深度学习项目并设计算法和用PyTorch实现的方法和步骤
  • 五、华为 RSTP
  • React 中hooks之useSyncExternalStore使用总结
  • NS3网络模拟器中如何利用Gnuplot工具像MATLAB一样绘制各类图形?
  • Vue - ref( ) 和 reactive( ) 响应式数据的使用