当前位置: 首页 > article >正文

【Ragflow】2. rag检索原理和效率解析

概述

本文是ragflow内容解析系列的第二篇。本文将详细解析Ragflow是进行一轮信息检索的过程,并通过实验对比 Elasticsearch和 Infinity在检索效率上的差异。

1. 前文错误更正

首先更正一下前文Ragflow技术栈分析及二次开发指南中存在的一处错误,文中错误将ragflow-infinity写做为前端系统,实际上,infinity是ragflow官方开发的高效向量数据库和搜索引擎,和elasticsearch功能类似。

2. Elasticsearch简介与可视化

2.1 Elasticvue可视化

Elasticsearch 是一个开源的分布式搜索和分析引擎。

前文中,并没有对Elasticsearch进行可视化,在此节中,使用比Kibana更轻量化的Elasticvue可视化工具对Elasticsearch内容进行查看。

Elasticvue下载链接:https://elasticvue.com/installation

docker中,启动Elasticsearch容器。

在这里插入图片描述

使用Elasticvue连接本地1200端口:

默认用户名:elastic,默认密码:infini_rag_flow

进入管理首页,可以看到Elasticsearch中,存在1个节点(nodes),2个分片(shards),1个索引(indices)

在这里插入图片描述

2.2 集群

在Elasticsearch中,集群(Cluster)是最大的单位,集群是由一个或多个节点(Node)组成的分布式系统。集群可以自动进行负载均衡,将搜索请求和索引请求分配到各个节点上,以实现数据的均衡存储和处理[1]。

连接的1200端口,就是默认的一个集群。

在这里插入图片描述

2.3 节点

集群的下一级单位是节点(Node),每个节点都是一个独立的工作单元,负责存储数据、参与数据处理(如索引、搜索、聚合等)。

节点主要有以下类型[1][2]:

  • 主节点(Master Node):负责集群范围内的元数据管理和变更,如索引创建、删除、分片分配等。
  • 主节点候选(Master-eligible Node):每一个节点启动后,默认就是一个主节点候选, 可以参加选主流程,成为主节点。当第一个节点启动时候,它会将自己选举成主节点。
  • 数据节点(Data Node):负责保存分片上存储的所有数据,当集群无法保存现有数据的时候,可以通过增加数据节点来解决存储上的问题。
  • 协调节点(Coordinating Node):负责接收 Client 的请求,将请求分发到合适的节点,最终把结果汇集到一起返回给客户端,每个节点默认都起到了协调节点的职责。
  • 冷热节点(Hot & Warm Node) :热节点(Hot Node)就是配置高的节点,可以有更好的磁盘吞吐量和更好的 CPU,冷节点(Warm Node)存储一些比较久的节点,这些节点的机器配置会比较低。
  • 预处理节点(Ingest Node):预处理操作允许在索引文档之前,即写入数据之前,通过事先定义好的一系列的 processors(处理器)和 pipeline(管道),对数据进行某种转换。

在此项目中,只有一个节点,可以看到,它既是主节点,也是数据节点和预处理节点。

在这里插入图片描述

2.4 索引

每个节点可以包含多个索引,索引相当于关系型数据库中的数据表,在 ES 中,索引是一类文档的集合。

在此项目中,只有一个索引,里面有8个分段(ES将数据分成多个段,提升搜索性能),57个文档(类似于数据库的记录),这里的文档记录是我上传了一个文件,然后切分成了57个chunk。

在这里插入图片描述

索引中有很多字段,比较主要的有以下几个字段:

  • title_tks:文档标题的分词结果,用于标题的精确匹配和相关性排序
  • title_sm_tks:文档标题的细粒度分词结果,捕获标题中的更细粒度语义单元,提高召回率
  • content_with_weight:原始内容文本,存储原始文本内容,用于结果展示和后处理
  • content_ltks:文档内容的标准分词结果,用于全文检索的主要字段
  • content_sm_ltks:文档内容的细粒度分词结果,提高内容检索的召回率
  • page_num_int:文档的页码信息,在搜索结果中展示页码信息
  • position_int:文档或段落在集合中的位置信息,在搜索结果中作为返回字段之一
  • top_int:文档的优先级,用于自定义排序
  • q_1024_vec:1024维的密集向量表示,存储文档的语义向量表示,用于向量相似度搜索

2.5 分片

不同于前面几个概念,分片更多是物理层面上的优化。由于单台机器无法存储大量数据,ES 可以将一个索引中的数据切分为多个分片(Shard),分布在多台服务器上存储。有了分片就可以横向扩展,存储更多数据,让搜索和分析等操作分布到多台服务器上去执行,提升吞吐量和性能[2]。

分片可分成主分片(Primary Shard)和副本分片(Replica Shard),主分片用于将数据进行扩展,副本分片是主分片的拷贝,当主分片故障时,保证数据不会丢失。

在此项目中,es的相关设定在conf/mapping.json文件中进行设置,默认只指定了2个主分片,未设定副本分片。

"index": {
  "number_of_shards": 2,
  "number_of_replicas": 0,
}

3. 检索过程解析

3.1 检索过程分解

在RAGFlow中,检索内容主要分以下几个步骤:

  1. 用户查询处理 :用户的问题首先被处理成关键词和向量表示
  2. 检索执行 :使用处理后的查询在知识库中检索相关内容
  3. 结果处理 :对检索结果进行处理和重排序
  4. 提示词构建 :将检索到的内容构建成提示词

这部分的核心代码为rag/nlp/search.py,其主要包括以下几个函数:

  • get_vector
    这个方法用于获取文本的向量表示,并创建一个 MatchDenseExpr 对象用于向量搜索。它接收文本、嵌入模型、topk和相似度阈值作为参数。

  • search
    这是核心搜索方法,处理搜索请求并返回搜索结果。它支持纯关键词搜索和混合搜索(关键词+向量)。

  • rerank
    对搜索结果进行重排序,结合关键词相似度和向量相似度。

  • rerank_by_model
    使用专门的重排序模型对搜索结果进行重排序。

  • retrieval
    这是一个高级搜索方法,整合了搜索、重排序和结果处理的完整流程。

3.2 检索过程模拟

为了更好地说清楚整个检索流程,我选取了部分核心代码,重构了单独的检索py脚本。

尽管在此搜索过程中,并未用到 infinity,但在相关依赖初始化中,import infinity,对于此依赖,直接通过pip uninstall infinity安装会出现版本问题,正确版本安装方式为:

pip install infinity-sdk==0.6.0.dev3

为了方便环境构建,后面我会将使用的环境依赖打包出一个requirements.txt,放在本文附录,方便读者用pip install -r requirements.txt进行安装。

脚本内容如下:

import sys
import os
import logging
import json
import numpy as np
from timeit import default_timer as timer
from collections import defaultdict

# 添加项目根目录到路径
sys.path.append(os.path.abspath("."))

# 导入必要的模块
from rag.utils.doc_store_conn import MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr
from rag.utils.es_conn import ESConnection
from rag.nlp.query import FulltextQueryer
from rag.nlp import rag_tokenizer
from rag.llm.embedding_model import DefaultEmbedding
from rag import settings


# NLTK资源下载
import nltk
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab')

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('retrieval_test')

class RetrievalProcess:
    """复刻 RAGFlow 中的检索过程"""
    
    def __init__(self, tenant_id="default", kb_ids=None, embedding_model_name=None):
        """初始化检索过程"""
        self.tenant_id = tenant_id
        self.kb_ids = kb_ids or []
        self.es_conn = ESConnection()
        self.queryer = FulltextQueryer()
        
        # 获取嵌入模型
        if embedding_model_name:
            # 修改嵌入模型的初始化方式
            self.embd_mdl = DefaultEmbedding(key=None, model_name=embedding_model_name)
        else:
            # 使用默认嵌入模型
            self.embd_mdl = None
            
        # 向量相似度权重 (1-keywords_similarity_weight)
        self.vector_similarity_weight = 0.3 
        # 关键词相似度权重
        self.keywords_similarity_weight = 0.7 
        # 相似度阈值
        self.similarity_threshold = 0.2
        # 检索结果数量
        self.top_n = 5
        
    def process_query(self, question):
        """处理用户查询"""
        logger.info(f"处理用户查询: {question}")
        
        # 1. 文本处理和分词
        start = timer()
        match_expr, keywords = self.queryer.question(question)
        end = timer()
        logger.info(f"查询处理耗时: {(end - start) * 1000:.2f}ms")
        logger.info(f"提取的关键词: {keywords}")
        
        # 2. 生成向量表示
        if self.embd_mdl:
            start = timer()
            query_vector, _ = self.embd_mdl.encode([question])
            query_vector = query_vector[0]
            end = timer()
            logger.info(f"向量编码耗时: {(end - start) * 1000:.2f}ms")
        else:
            query_vector = None
            logger.warning("未配置嵌入模型,将仅使用关键词检索")
        
        return match_expr, keywords, query_vector
    
    def build_search_query(self, match_expr, query_vector=None):
        """构建搜索查询"""
        # 准备查询条件
        select_fields = ["id", "title", "content", "content_ltks", "content_sm_ltks", "kb_id", "doc_id", "docnm_kwd"]
        highlight_fields = ["title", "content"]
        condition = {"kb_id": self.kb_ids} if self.kb_ids else {}
        
        # 构建匹配表达式列表
        match_exprs = []
        
        # 添加文本匹配表达式
        if match_expr:
            match_exprs.append(match_expr)
        
        # 添加向量匹配表达式
        if query_vector is not None:
            # 确定向量字段名称
            vector_field = f"q_{len(query_vector)}_vec"
            # 添加向量字段到选择字段
            select_fields.append(vector_field)
            
            # 创建向量匹配表达式
            vector_expr = MatchDenseExpr(
                vector_column_name=vector_field,
                embedding_data=query_vector.tolist(),
                embedding_data_type='float',
                distance_type='cosine',
                topn=self.top_n,
                extra_options={"similarity": 0.1}  
            )
            match_exprs.append(vector_expr)
            
            # 添加融合表达式
            fusion_expr = FusionExpr(
                "weighted_sum", 
                self.top_n, 
                {"weights": f"{self.keywords_similarity_weight}, {self.vector_similarity_weight}"}
            )
            match_exprs.append(fusion_expr)
        
        return select_fields, highlight_fields, condition, match_exprs
    
    def search(self, question, index_names):
        """执行搜索"""
        # 处理查询
        match_expr, keywords, query_vector = self.process_query(question)
        
        # 构建搜索查询
        select_fields, highlight_fields, condition, match_exprs = self.build_search_query(match_expr, query_vector)
        
        # 执行搜索
        start = timer()
        search_response = self.es_conn.search(
            selectFields=select_fields,
            highlightFields=highlight_fields,
            condition=condition,
            matchExprs=match_exprs,
            orderBy=None,
            offset=0,
            limit=self.top_n,
            indexNames=index_names,
            knowledgebaseIds=self.kb_ids
        )
        
        # 从响应中提取实际的结果列表
        results = search_response.get("hits", {}).get("hits", [])
        end = timer()
        logger.info(f"搜索耗时: {(end - start) * 1000:.2f}ms")
        logger.info(f"搜索结果数量: {len(results)}")
        
        # 如果结果为空且使用了向量搜索,尝试降低匹配阈值重新搜索
        if len(results) == 0 and query_vector is not None:
            logger.info("搜索结果为空,尝试降低匹配阈值重新搜索")
            match_expr, _ = self.queryer.question(question, min_match=0.1)  # 降低min_match
            select_fields, highlight_fields, condition, match_exprs = self.build_search_query(match_expr, query_vector)
            
            # 修改向量匹配表达式的相似度阈值
            for expr in match_exprs:
                if isinstance(expr, MatchDenseExpr):
                    expr.extra_options["similarity"] = 0.17  # 提高similarity阈值
            
            # 重新执行搜索
            start = timer()
            search_response = self.es_conn.search(
                selectFields=select_fields,
                highlightFields=highlight_fields,
                condition=condition,
                matchExprs=match_exprs,
                orderBy=None,
                offset=0,
                limit=self.top_n,
                indexNames=index_names,
                knowledgebaseIds=self.kb_ids
            )
            results = search_response.get("hits", {}).get("hits", [])
            end = timer()
            logger.info(f"重新搜索耗时: {(end - start) * 1000:.2f}ms")
            logger.info(f"重新搜索结果数量: {len(results)}")
        
        # 如果有向量和分词,计算混合相似度并重新排序
        if query_vector is not None and len(results) > 0:
            results = self.rerank_results(results, query_vector, keywords)
        
        return results
    
    def rerank_results(self, results, query_vector, query_tokens):
        """重新排序搜索结果"""
        start = timer()
        
        # 提取文档向量和分词
        doc_vectors = []
        doc_tokens = []
        
        for result in results:
            # 提取向量
            doc_vectors.append(np.array(result["_source"]["q_1024_vec"]))
            # 提取分词
            doc_tokens.append(result["_source"]["content_ltks"])
     
        # 计算混合相似度
        query_tokens_str = " ".join(query_tokens)
        hybrid_scores, token_scores, vector_scores = self.queryer.hybrid_similarity(
            query_vector, doc_vectors, query_tokens_str, doc_tokens,
            tkweight=self.keywords_similarity_weight,
            vtweight=self.vector_similarity_weight
        )
        
        # 为结果添加分数
        for i, result in enumerate(results):
            result["hybrid_score"] = float(hybrid_scores[i])
            result["token_score"] = float(token_scores[i])
            result["vector_score"] = float(vector_scores[i])

        # 查看各分数信息    
        logger.info(f"混合分数: {hybrid_scores}")
        logger.info(f"关键词分数: {token_scores}")
        logger.info(f"向量分数: {vector_scores}")
            

        # 按混合分数重新排序
        results.sort(key=lambda x: x["hybrid_score"], reverse=True)
        
        # 过滤低于阈值的结果
        results[:] = [r for r in results if r["hybrid_score"] >= self.similarity_threshold]
        
        end = timer()
        logger.info(f"重排序耗时: {(end - start) * 1000:.2f}ms")
        logger.info(f"重排序后结果数量: {len(results)}")
        
        return results
    
    def build_llm_prompt(self, question, results, system_prompt=None):
        """构建输入到LLM的提示"""
        if system_prompt is None:
            system_prompt = """你是一个智能助手。请基于提供的上下文信息回答用户的问题。
        如果上下文中没有足够的信息来回答问题,请说明你无法回答,不要编造信息。
        回答时请引用相关的上下文信息,并标明引用的来源。"""
        
        # 按文档组织内容
        doc2chunks = defaultdict(lambda: {"chunks": []})
        
        # 调试信息:打印结果结构
        # logger.info(f"搜索结果结构示例: {json.dumps(results[0] if results else {}, ensure_ascii=False, indent=2)[:500]}...")
        
        for i, result in enumerate(results):
            # 获取文档名称 - 从_source字段中提取
            source = result.get("_source", {})
            doc_name = source.get("docnm_kwd", f"文档{i+1}")
            
            # 提取内容
            content = source.get("content_with_weight", "")
            
            # 调试信息
            # logger.info(f"文档 {doc_name} 内容长度: {len(content)} 字符")
            # if content:
            #     logger.info(f"内容预览: {content[:100]}...")
            # else:
            #     logger.info("警告: 内容为空!")
            
            # 只有当内容不为空时才添加到对应文档的chunks中
            if content:
                doc2chunks[doc_name]["chunks"].append(content)
        
        # 构建上下文信息
        context_parts = []
        
        for doc_name, doc_info in doc2chunks.items():
            if not doc_info["chunks"]:  # 跳过没有内容的文档
                continue
                
            txt = f"Document: {doc_name}\n"
            txt += "Relevant fragments as following:\n"
            
            for i, chunk in enumerate(doc_info["chunks"], 1):
                txt += f"{i}. {chunk}\n"
            
            context_parts.append(txt)
        
        # 合并上下文
        context = "\n\n".join(context_parts)
        
        # 如果没有有效的上下文,添加提示信息
        if not context_parts:
            context = "未找到与问题相关的文档内容。"
            logger.warning("警告: 没有找到有效的文档内容!")
        
        # 构建完整提示
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"我需要回答以下问题:\n\n{question}\n\n以下是相关的上下文信息:\n\n{context}"}
        ]
        
        return messages
    
    def format_llm_prompt(self, messages):
        """格式化LLM提示以便于查看"""
        formatted = []
        
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            
            # 格式化不同角色的消息
            if role == "system":
                formatted.append(f"### 系统指令\n{content}")
            elif role == "user":
                formatted.append(f"### 用户输入\n{content}")
            elif role == "assistant":
                formatted.append(f"### 助手回复\n{content}")
            else:
                formatted.append(f"### {role}\n{content}")
        
        return "\n\n".join(formatted)

def list_indices():
    """列出所有索引"""
    es_conn = ESConnection()
    indices = es_conn.es.indices.get_alias(index="*")
    return list(indices.keys())

def run_test():
    """运行测试"""
    logger.info("开始检索测试...")
    
    # 列出所有索引
    indices = list_indices()
    logger.info(f"可用索引: {indices}")
    
    if not indices:
        logger.error("没有找到任何索引,请确保 Elasticsearch 中有数据")
        return
    
    # 选择第一个索引进行测试
    selected_index = indices[0]
    logger.info(f"选择索引 {selected_index} 进行测试")
    
    # 创建检索过程实例
    # retrieval = RetrievalProcess()
    # 创建检索过程实例时添加嵌入模型
    retrieval = RetrievalProcess(embedding_model_name="BAAI/bge-large-zh-v1.5")
    
    # 测试一些查询
    test_queries = [
        "事件相机是什么?"
    ]
    
    for query in test_queries:
        logger.info(f"\n测试查询: {query}")
        results = retrieval.search(query, selected_index)
        
        # 构建并打印LLM输入提示
        if results:
            logger.info("\n=== 输入到LLM的完整内容 ===")
            messages = retrieval.build_llm_prompt(query, results)
            formatted_prompt = retrieval.format_llm_prompt(messages)
            logger.info(formatted_prompt)
            
        else:
            logger.info("没有找到相关结果,无法构建LLM输入")

if __name__ == "__main__":
    run_test()

第一次运行时,需要先自动下载punkt_tab,下载的默认路径是C:\Users\UserName\AppData\Roaming\nltk_data\tokenizers,punkt_tab是一个句子分割器模型,用于将文本分割成句子。它是一个预训练的无监督模型,能够识别句子边界,特别是能够区分句号、问号和感叹号是作为句子结束符还是作为缩写、数字等的一部分。

其次是需要下载BAAI/bge-large-zh-v1.5,下载的默认路径是C:\Users\UserName\.ragflow\bge-large-zh-v1.5,这是ragflow默认采用的embedding模型,用于将中文文本向量化。

首先,对于用户输入的question,ragflow会先进行分词,分词的目的是可以将分词结果和知识库中的内容分词进行关键词匹配。

之后,如果添加了embedding模型,ragflow会将question进行embedding成特征向量query_vector。

分词结果和向量结果之后,进一步构造搜索模板,调用rag/utils/es_conn.pysearch方法进行搜索查询,对于分词结果,es使用词项匹配和相关性评分的方式进行查询;对于向量结果,es查询采用的优化版的knn算法。

最终,在得到这两个结果之后,进行重排序(rerank)过程,在不设定rerank模型的情况下,ragflow会将分词查询结果和向量查询结果进行加权求和,并将结果重新排序,选取top-N个文本整合进后面的prompt,默认权重为关键词查询0.7,向量查询0.3。

运行上述代码输出结果示例如下:

2025-03-16 21:53:29,715 - retrieval_test - INFO - 处理用户查询: 事件相机是什么?
2025-03-16 21:53:29,718 - retrieval_test - INFO - 查询处理耗时: 3.36ms
2025-03-16 21:53:29,719 - retrieval_test - INFO - 提取的关键词: ['事件相机', '事件', '相机']
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
2025-03-16 21:53:29,979 - retrieval_test - INFO - 向量编码耗时: 260.08ms
2025-03-16 21:53:30,038 - elastic_transport.transport - INFO - POST http://localhost:1200/ragflow_d3de1596fb6911efa0f40242ac120006/_search [status:200 duration:0.052s]
2025-03-16 21:53:30,045 - retrieval_test - INFO - 搜索耗时: 65.94ms
2025-03-16 21:53:30,045 - retrieval_test - INFO - 搜索结果数量: 5
2025-03-16 21:53:30,058 - retrieval_test - INFO - 混合分数: [0.66530286 0.6507986  0.63286258 0.63326174 0.5953948 ]
2025-03-16 21:53:30,058 - retrieval_test - INFO - 关键词分数: [0.6695036911658396, 0.6695036911658396, 0.6695036911658396, 0.6695036911658396, 0.6695036911658396]
2025-03-16 21:53:30,059 - retrieval_test - INFO - 向量分数: [0.65550092 0.60715338 0.54736667 0.54869719 0.42247404]
2025-03-16 21:53:30,059 - retrieval_test - INFO - 重排序耗时: 12.69ms
2025-03-16 21:53:30,059 - retrieval_test - INFO - 重排序后结果数量: 5
2025-03-16 21:53:30,059 - retrieval_test - INFO -
=== 输入到LLM的完整内容 ===
2025-03-16 21:53:30,059 - retrieval_test - INFO - ### 系统指令
你是一个智能助手。请基于提供的上下文信息回答用户的问题。
        如果上下文中没有足够的信息来回答问题,请说明你无法回答,不要编造信息。
        回答时请引用相关的上下文信息,并标明引用的来源。

### 用户输入
我需要回答以下问题:

事件相机是什么?

以下是相关的上下文信息:

Document: 基于事件相机的可视化及降噪算法.pdf
Relevant fragments as following:
1. 事件相机在实际应用上还存在问题。一方面事件相机由于其本身结构对环境亮度变化十分敏感,在输出的异步事件流中包 
含大量噪声干扰。噪声可能来源于数字信号传输时的脉冲噪声以及光电二极管所引起的高斯噪声等,对于进一步的事件相机的
特性使其吸引了很多领域研究人//员的关注,比如在图像信息处理方面,关注的是图:到的图像存在很多冗余信息,增加了计 
算量和对
2./:中图分类号:V249.32+5;TN957.52SNp 文献标志码:A 文章编号:10015965(202
1)0203420920b  事件相机(EventCamera)是一种新兴的生物硬件水平的要求,事件相机可以直接
输出稀疏的amerah 视觉传感器,传统帧相机(FramebasedC)通运动边缘信息,简化计算。此外,事件相 
机在视觉b 过固定的曝光时间以一定帧率采集图像,事件相nc 导航定位方向有应用潜力,在实际应用场景中,由. 于剧烈运 
动、光照条件变化、平台功耗限制等影//:在亮度变化超过设定阈值时异步输出像素地址事u 响,传统帧相机存在运动模糊、 
过曝欠曝等限件流数据。传统帧相机受限于软硬件条件,帧率d 制[2],运动剧烈时检测的有效特征点减少,造成视觉信息
不可靠,事件相机则可以弥补这些不p t t一般为15~200fps(fps为帧/s),在高速运动的场景中会产生运动 
模糊。与此同时,在高动态的. 足。事件相机可应用于无人车、无人机、自主机器工作场景下会产生过曝与欠曝的现象,丢失
场景人以及增强现实(AugmentedReality,AR)和同时定位及地图构建(SimultaneousLo
calizationand部分细节信息。事件相机则具有高时间分辨率、b 高动态范围的特点[1],由于每一像素的异
步输出Mapping,SLAM)技术,相比于传统帧相机具有低. 特性,没有帧率的概念其响应时间可达到微秒级延时、
高动态、抗运动模糊、低运算量的特点,可以别,由于其检测光强对数的变化,其动态范围可达提高导航系统的鲁棒性[6]
。140dB[2]。
3. 有效性。事件流的可视化及降噪算法2 //:p 事件相机输出的异步事件流代表像素点亮度t t变化超出阈值,单个事件可 
以表示为event=(x,iiy,t,p),i为序号,(x,y)为事件发生的像素位置iiih 坐标,t为事件发生
时间,p∈{0,1}表示事件极性(0代表亮度变暗,1代表亮度变亮)。假设事件相机零延时且无噪声干扰,则t时刻物 
体边缘运动所引起的事件集合可以表示为
4. cameraoutput
Fig.2 Comparisonofframebasedcameraandevent
图2 传统帧相机和事件相机输出对比
传统帧机机
事件机机
5. http:∥bhxb.buaa.edu.cn  jbuaa@buaa.edu.cnDOI:10.13700/
j.bh.10015965.2020.0192u 基于事件相机的可及降视化算法

3.3 检索压力测试

下面再考虑一个问题:如果需要es里面有1000份文件,那么检索起来需要花费多少时间?

为此,进一步进行压力测试,考虑到一个文件大约可以分切分成100个文档,那么1000份文件约为100000个文档。下面进一步进行压力测试,模拟两个索引,像各索引内部临时插入5000份随机生成的文档,再进行搜索查询,代码如下:

import sys
import os
import logging
import json
import numpy as np
from timeit import default_timer as timer
from collections import defaultdict

# 添加项目根目录到路径
sys.path.append(os.path.abspath("."))

# 导入必要的模块
from rag.utils.doc_store_conn import MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr
from rag.utils.es_conn import ESConnection
from rag.nlp.query import FulltextQueryer
from rag.nlp import rag_tokenizer
from rag.llm.embedding_model import DefaultEmbedding
from rag import settings


# NLTK资源下载
import nltk
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab')

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('retrieval_test')

class RetrievalProcess:
    """复刻 RAGFlow 中的检索过程"""
    
    def __init__(self, tenant_id="default", kb_ids=None, embedding_model_name=None):
        """初始化检索过程"""
        self.tenant_id = tenant_id
        self.kb_ids = kb_ids or []
        self.es_conn = ESConnection()
        self.queryer = FulltextQueryer()
        
        # 获取嵌入模型
        if embedding_model_name:
            # 修改嵌入模型的初始化方式
            self.embd_mdl = DefaultEmbedding(key=None, model_name=embedding_model_name)
        else:
            # 使用默认嵌入模型
            self.embd_mdl = None
            
        # 向量相似度权重 (1-keywords_similarity_weight)
        self.vector_similarity_weight = 0.3 
        # 关键词相似度权重
        self.keywords_similarity_weight = 0.7 
        # 相似度阈值
        self.similarity_threshold = 0.2
        # 检索结果数量
        self.top_n = 5
        
    def process_query(self, question):
        """处理用户查询"""
        logger.info(f"处理用户查询: {question}")
        
        # 1. 文本处理和分词
        start = timer()
        match_expr, keywords = self.queryer.question(question)
        end = timer()
        logger.info(f"查询处理耗时: {(end - start) * 1000:.2f}ms")
        logger.info(f"提取的关键词: {keywords}")
        
        # 2. 生成向量表示
        if self.embd_mdl:
            start = timer()
            query_vector, _ = self.embd_mdl.encode([question])
            query_vector = query_vector[0]
            end = timer()
            logger.info(f"向量编码耗时: {(end - start) * 1000:.2f}ms")
        else:
            query_vector = None
            logger.warning("未配置嵌入模型,将仅使用关键词检索")
        
        return match_expr, keywords, query_vector
    
    def build_search_query(self, match_expr, query_vector=None):
        """构建搜索查询"""
        # 准备查询条件
        select_fields = ["id", "title", "content", "content_ltks", "content_sm_ltks", "kb_id", "doc_id", "docnm_kwd"]
        highlight_fields = ["title", "content"]
        condition = {"kb_id": self.kb_ids} if self.kb_ids else {}
        
        # 构建匹配表达式列表
        match_exprs = []
        
        # 添加文本匹配表达式
        if match_expr:
            match_exprs.append(match_expr)
        
        # 添加向量匹配表达式
        if query_vector is not None:
            # 确定向量字段名称
            vector_field = f"q_{len(query_vector)}_vec"
            # 添加向量字段到选择字段
            select_fields.append(vector_field)
            
            # 创建向量匹配表达式
            vector_expr = MatchDenseExpr(
                vector_column_name=vector_field,
                embedding_data=query_vector.tolist(),
                embedding_data_type='float',
                distance_type='cosine',
                topn=self.top_n,
                extra_options={"similarity": 0.1}  
            )
            match_exprs.append(vector_expr)
            
            # 添加融合表达式
            fusion_expr = FusionExpr(
                "weighted_sum", 
                self.top_n, 
                {"weights": f"{self.keywords_similarity_weight}, {self.vector_similarity_weight}"}
            )
            match_exprs.append(fusion_expr)
        
        return select_fields, highlight_fields, condition, match_exprs
    
    def search(self, question, index_names):
        """执行搜索"""
        # 处理查询
        match_expr, keywords, query_vector = self.process_query(question)
        
        # 构建搜索查询
        select_fields, highlight_fields, condition, match_exprs = self.build_search_query(match_expr, query_vector)
        
        # 执行搜索
        start = timer()
        search_response = self.es_conn.search(
            selectFields=select_fields,
            highlightFields=highlight_fields,
            condition=condition,
            matchExprs=match_exprs,
            orderBy=None,
            offset=0,
            limit=self.top_n,
            indexNames=index_names,
            knowledgebaseIds=self.kb_ids
        )
        
        # 从响应中提取实际的结果列表
        results = search_response.get("hits", {}).get("hits", [])
        end = timer()
        logger.info(f"搜索耗时: {(end - start) * 1000:.2f}ms")
        logger.info(f"搜索结果数量: {len(results)}")
        
        # 如果结果为空且使用了向量搜索,尝试降低匹配阈值重新搜索
        if len(results) == 0 and query_vector is not None:
            logger.info("搜索结果为空,尝试降低匹配阈值重新搜索")
            match_expr, _ = self.queryer.question(question, min_match=0.1)  # 降低min_match
            select_fields, highlight_fields, condition, match_exprs = self.build_search_query(match_expr, query_vector)
            
            # 修改向量匹配表达式的相似度阈值
            for expr in match_exprs:
                if isinstance(expr, MatchDenseExpr):
                    expr.extra_options["similarity"] = 0.17  # 提高similarity阈值
            
            # 重新执行搜索
            start = timer()
            search_response = self.es_conn.search(
                selectFields=select_fields,
                highlightFields=highlight_fields,
                condition=condition,
                matchExprs=match_exprs,
                orderBy=None,
                offset=0,
                limit=self.top_n,
                indexNames=index_names,
                knowledgebaseIds=self.kb_ids
            )
            results = search_response.get("hits", {}).get("hits", [])
            end = timer()
            logger.info(f"重新搜索耗时: {(end - start) * 1000:.2f}ms")
            logger.info(f"重新搜索结果数量: {len(results)}")
        
        # 如果有向量和分词,计算混合相似度并重新排序
        if query_vector is not None and len(results) > 0:
            results = self.rerank_results(results, query_vector, keywords)
        
        return results
    
    def rerank_results(self, results, query_vector, query_tokens):
        """重新排序搜索结果"""
        start = timer()
        
        # 提取文档向量和分词
        doc_vectors = []
        doc_tokens = []
        
        for result in results:
            # 提取向量
            vector = result["_source"]["q_1024_vec"]
            # 确保向量是一维数组
            if isinstance(vector, list) and len(vector) > 0:
                # 如果是嵌套列表,则取第一个元素
                if isinstance(vector[0], list):
                    vector = vector[0]
                doc_vectors.append(np.array(vector))
            else:
                # 如果向量格式不正确,使用零向量
                logger.warning(f"向量格式不正确: {type(vector)}")
                doc_vectors.append(np.zeros(1024))
                
            # 提取分词
            doc_tokens.append(result["_source"]["content_ltks"])
     
        # 计算混合相似度
        query_tokens_str = " ".join(query_tokens)
        
        # 确保查询向量是一维的
        if len(query_vector.shape) > 1:
            query_vector = query_vector.flatten()
        
        # 确保所有文档向量都是二维的 [n_samples, n_features]
        doc_vectors_array = np.vstack([v.flatten() for v in doc_vectors])
        
        try:
            hybrid_scores, token_scores, vector_scores = self.queryer.hybrid_similarity(
                query_vector, doc_vectors_array, query_tokens_str, doc_tokens,
                tkweight=self.keywords_similarity_weight,
                vtweight=self.vector_similarity_weight
            )
        except Exception as e:
            # 如果混合相似度计算失败,使用备用方法
            logger.error(f"混合相似度计算失败: {e}")
            # 计算向量相似度
            vector_scores = []
            for doc_vec in doc_vectors_array:
                # 计算余弦相似度
                sim = np.dot(query_vector, doc_vec) / (np.linalg.norm(query_vector) * np.linalg.norm(doc_vec) + 1e-8)
                vector_scores.append(float(sim))
            
            # 使用向量相似度作为混合分数
            hybrid_scores = vector_scores
            token_scores = [0.5] * len(vector_scores)  # 默认值
        
        # 为结果添加分数
        for i, result in enumerate(results):
            if i < len(hybrid_scores):
                result["hybrid_score"] = float(hybrid_scores[i])
                result["token_score"] = float(token_scores[i]) if i < len(token_scores) else 0.5
                result["vector_score"] = float(vector_scores[i]) if i < len(vector_scores) else 0.5
    
        # 查看各分数信息    
        # logger.info(f"混合分数: {hybrid_scores[:5]}...")
        # logger.info(f"关键词分数: {token_scores[:5]}...")
        # logger.info(f"向量分数: {vector_scores[:5]}...")
            
        # 按混合分数重新排序
        results.sort(key=lambda x: x["hybrid_score"], reverse=True)
        
        # 过滤低于阈值的结果
        results[:] = [r for r in results if r["hybrid_score"] >= self.similarity_threshold]
        
        end = timer()
        logger.info(f"重排序耗时: {(end - start) * 1000:.2f}ms")
        logger.info(f"重排序后结果数量: {len(results)}")
        
        return results
    
    def build_llm_prompt(self, question, results, system_prompt=None):
        """构建输入到LLM的提示"""
        if system_prompt is None:
            system_prompt = """你是一个智能助手。请基于提供的上下文信息回答用户的问题。
        如果上下文中没有足够的信息来回答问题,请说明你无法回答,不要编造信息。
        回答时请引用相关的上下文信息,并标明引用的来源。"""
        
        # 按文档组织内容
        doc2chunks = defaultdict(lambda: {"chunks": []})
        
        # 调试信息:打印结果结构
        # logger.info(f"搜索结果结构示例: {json.dumps(results[0] if results else {}, ensure_ascii=False, indent=2)[:500]}...")
        
        for i, result in enumerate(results):
            # 获取文档名称 - 从_source字段中提取
            source = result.get("_source", {})
            doc_name = source.get("docnm_kwd", f"文档{i+1}")
            
            # 提取内容
            content = source.get("content_with_weight", "")
            
            # 调试信息
            # logger.info(f"文档 {doc_name} 内容长度: {len(content)} 字符")
            # if content:
            #     logger.info(f"内容预览: {content[:100]}...")
            # else:
            #     logger.info("警告: 内容为空!")
            
            # 只有当内容不为空时才添加到对应文档的chunks中
            if content:
                doc2chunks[doc_name]["chunks"].append(content)
        
        # 构建上下文信息
        context_parts = []
        
        for doc_name, doc_info in doc2chunks.items():
            if not doc_info["chunks"]:  # 跳过没有内容的文档
                continue
                
            txt = f"Document: {doc_name}\n"
            txt += "Relevant fragments as following:\n"
            
            for i, chunk in enumerate(doc_info["chunks"], 1):
                txt += f"{i}. {chunk}\n"
            
            context_parts.append(txt)
        
        # 合并上下文
        context = "\n\n".join(context_parts)
        
        # 如果没有有效的上下文,添加提示信息
        if not context_parts:
            context = "未找到与问题相关的文档内容。"
            logger.warning("警告: 没有找到有效的文档内容!")
        
        # 构建完整提示
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"我需要回答以下问题:\n\n{question}\n\n以下是相关的上下文信息:\n\n{context}"}
        ]
        
        return messages
    
    def format_llm_prompt(self, messages):
        """格式化LLM提示以便于查看"""
        formatted = []
        
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            
            # 格式化不同角色的消息
            if role == "system":
                formatted.append(f"### 系统指令\n{content}")
            elif role == "user":
                formatted.append(f"### 用户输入\n{content}")
            elif role == "assistant":
                formatted.append(f"### 助手回复\n{content}")
            else:
                formatted.append(f"### {role}\n{content}")
        
        return "\n\n".join(formatted)

def list_indices():
    """列出所有索引"""
    es_conn = ESConnection()
    indices = es_conn.es.indices.get_alias(index="*")
    return list(indices.keys())

def run_test():
    """运行测试"""
    logger.info("开始检索测试...")
    
    # 列出所有索引
    indices = list_indices()
    logger.info(f"可用索引: {indices}")
    
    if not indices:
        logger.error("没有找到任何索引,请确保 Elasticsearch 中有数据")
        return
    
    # 选择第一个索引进行测试
    selected_index = indices[0]
    logger.info(f"选择索引 {selected_index} 进行测试")
    
    # 创建检索过程实例
    # retrieval = RetrievalProcess()
    # 创建检索过程实例时添加嵌入模型
    retrieval = RetrievalProcess(embedding_model_name="BAAI/bge-large-zh-v1.5")
    
    # 测试一些查询
    test_queries = [
        "事件相机是什么?"
    ]
    
    for query in test_queries:
        logger.info(f"\n测试查询: {query}")
        results = retrieval.search(query, selected_index)
        
        # 构建并打印LLM输入提示
        if results:
            logger.info("\n=== 输入到LLM的完整内容 ===")
            messages = retrieval.build_llm_prompt(query, results)
            formatted_prompt = retrieval.format_llm_prompt(messages)
            logger.info(formatted_prompt)
            
        else:
            logger.info("没有找到相关结果,无法构建LLM输入")

def run_stress_test(num_results=100, num_iterations=5):
    """
    运行压力测试,模拟更多的搜索结果
    
    Args:
        num_results: 模拟的结果数量
        num_iterations: 重复测试的次数
    """
    logger.info(f"开始压力测试 (模拟 {num_results} 个结果, {num_iterations} 次迭代)...")
    
    # 列出所有索引
    indices = list_indices()
    if not indices:
        logger.error("没有找到任何索引,请确保 Elasticsearch 中有数据")
        return
    
    # 选择第一个索引进行测试
    selected_index = indices[0]
    logger.info(f"选择索引 {selected_index} 进行测试")
    
    # 创建检索过程实例
    retrieval = RetrievalProcess(embedding_model_name="BAAI/bge-large-zh-v1.5")
    
    # 测试查询
    test_query = "事件相机是什么?"
    
    # 获取一个真实的搜索结果作为模板
    logger.info(f"获取模板结果...")
    template_results = retrieval.search(test_query, selected_index)
    
    if not template_results:
        logger.error("无法获取模板结果,请确保搜索能够返回至少一个结果")
        return
    
    # 创建模拟结果
    logger.info(f"创建 {num_results} 个模拟结果...")
    mock_results = []
    template = template_results[0]
    
    # 确保模板中有向量
    if "_source" not in template or "q_1024_vec" not in template["_source"]:
        logger.error("模板结果中没有向量字段,无法创建模拟结果")
        return
    
    # 获取向量维度
    vector_dim = len(template["_source"]["q_1024_vec"])
    logger.info(f"向量维度: {vector_dim}")
    
    for i in range(num_results):
        # 深拷贝模板结果并修改ID和分数
        mock_result = json.loads(json.dumps(template))
        mock_result["_id"] = f"mock_id_{i}"
        mock_result["_score"] = 1.0 - (i * 0.5 / num_results)  # 递减的分数
        
        # 修改向量以确保每个结果都不同,但保持一维数组
        if "_source" in mock_result and "q_1024_vec" in mock_result["_source"]:
            # 创建一个新的一维向量
            mock_result["_source"]["q_1024_vec"] = [
                template["_source"]["q_1024_vec"][j] + (i * 0.001) 
                for j in range(vector_dim)
            ]
        
        mock_results.append(mock_result)
    
    # 运行多次迭代测试
    for iteration in range(num_iterations):
        logger.info(f"\n迭代 {iteration+1}/{num_iterations}")
        results = mock_results
        # 测试重排序性能
        start = timer()
        query_vector = retrieval.embd_mdl.encode([test_query])[0]
        _, keywords = retrieval.queryer.question(test_query)
        reranked_results = retrieval.rerank_results(results[:], query_vector, keywords)
        end = timer()
        logger.info(f"重排序 {len(results)} 个结果耗时: {(end - start) * 1000:.2f}ms")
        
        # 测试提示构建性能
        start = timer()
        messages = retrieval.build_llm_prompt(test_query, reranked_results[:5])
        end = timer()
        logger.info(f"为前5个结果构建提示耗时: {(end - start) * 1000:.2f}ms")
    
    logger.info("\n压力测试完成")

def run_realistic_test(num_iterations=5, index_prefix="test_index_", num_indices=5, docs_per_index=1000):
    """
    运行更真实的测试,通过创建多个索引和大量文档来模拟大规模数据环境
    
    Args:
        num_iterations: 重复测试的次数
        index_prefix: 测试索引的前缀
        num_indices: 要创建的索引数量
        docs_per_index: 每个索引中的文档数量
    """
    logger.info(f"开始真实环境测试 (创建 {num_indices} 个索引,每个包含 {docs_per_index} 个文档)...")
    
    # 创建 ES 连接
    es_conn = ESConnection()
    
    # 创建测试索引和数据
    test_indices = []
    try:
        # 创建测试索引
        for i in range(num_indices):
            index_name = f"{index_prefix}{i}"
            test_indices.append(index_name)
            
            # 检查索引是否已存在
            if es_conn.es.indices.exists(index=index_name):
                logger.info(f"索引 {index_name} 已存在,跳过创建")
                continue
                
            # 创建索引
            logger.info(f"创建索引 {index_name}...")
            # 这里应该使用与实际应用相同的索引映射
            index_mapping = {
                "mappings": {
                    "properties": {
                        "title": {"type": "text"},
                        "content": {"type": "text"},
                        "content_ltks": {"type": "text"},
                        "content_sm_ltks": {"type": "text"},
                        "kb_id": {"type": "keyword"},
                        "doc_id": {"type": "keyword"},
                        "docnm_kwd": {"type": "keyword"},
                        "q_1024_vec": {
                            "type": "dense_vector",
                            "dims": 1024,
                            "index": True,
                            "similarity": "cosine"
                        }
                    }
                }
            }
            es_conn.es.indices.create(index=index_name, body=index_mapping)
            
            # 批量插入测试文档
            logger.info(f"向索引 {index_name} 插入 {docs_per_index} 个文档...")
            bulk_data = []
            for j in range(docs_per_index):
                # 创建随机文档
                doc = {
                    "title": f"测试文档 {j}",
                    "content": f"这是测试文档 {j} 的内容,包含一些随机文本用于测试检索功能。",
                    "content_ltks": f"测试 文档 {j} 内容 随机 文本 检索",
                    "content_sm_ltks": f"测试 文档 内容",
                    "kb_id": "test_kb",
                    "doc_id": f"doc_{j}",
                    "docnm_kwd": f"测试文档{j}",
                    "q_1024_vec": [0.01 * (k % 100) for k in range(1024)]  # 创建随机向量
                }
                
                # 添加到批量操作
                bulk_data.append({"index": {"_index": index_name, "_id": f"test_doc_{j}"}})
                bulk_data.append(doc)
                
                # 每1000个文档执行一次批量操作
                if len(bulk_data) >= 2000:
                    es_conn.es.bulk(body=bulk_data, refresh=True)
                    bulk_data = []
            
            # 处理剩余的文档
            if bulk_data:
                es_conn.es.bulk(body=bulk_data, refresh=True)
        
        # 等待索引刷新
        logger.info("等待索引刷新...")
        es_conn.es.indices.refresh(index=",".join(test_indices))
        
        # 创建检索过程实例
        retrieval = RetrievalProcess(embedding_model_name="BAAI/bge-large-zh-v1.5")
        
        # 测试查询
        test_queries = [
            "如何使用测试文档?"
        ]
        
        # 运行多次迭代测试
        for iteration in range(num_iterations):
            logger.info(f"\n迭代 {iteration+1}/{num_iterations}")
            
            for query in test_queries:
                logger.info(f"\n测试查询: {query}")
                
                # 测量整个搜索过程的性能
                start_total = timer()
                # 将查询作为第一个参数,索引作为第二个参数
                results = retrieval.search(query, test_indices)
                end_total = timer()
                
                logger.info(f"总搜索耗时: {(end_total - start_total) * 1000:.2f}ms")
                logger.info(f"搜索结果数量: {len(results)}")
                
                # 如果有结果,构建提示
                if results:
                    start_prompt = timer()
                    messages = retrieval.build_llm_prompt(query, results[:5])
                    end_prompt = timer()
                    logger.info(f"构建提示耗时: {(end_prompt - start_prompt) * 1000:.2f}ms")
    
    finally:
        # 清理测试索引
        if test_indices:
            logger.info(f"删除测试索引...")
            for index in test_indices:
                if es_conn.es.indices.exists(index=index):
                    es_conn.es.indices.delete(index=index)
            logger.info(f"测试索引已删除")
   

if __name__ == "__main__":
    # 运行测试
    run_realistic_test(num_iterations=3, num_indices=2, docs_per_index=50000)

部分输出结果如下:

迭代 3/3
2025-03-16 22:16:49,318 - retrieval_test - INFO -
测试查询: 如何使用测试文档?
2025-03-16 22:16:49,318 - retrieval_test - INFO - 处理用户查询: 如何使用测试文档?
2025-03-16 22:16:49,319 - retrieval_test - INFO - 查询处理耗时: 0.67ms
2025-03-16 22:16:49,319 - retrieval_test - INFO - 提取的关键词: ['使用测试文档', '测试', '文档', '使用']      
2025-03-16 22:16:49,497 - retrieval_test - INFO - 向量编码耗时: 176.55ms
2025-03-16 22:16:49,581 - elastic_transport.transport - INFO - POST http://localhost:1200/test_index_0,test_index_1/_search [status:200 duration:0.080s]
2025-03-16 22:16:49,584 - retrieval_test - INFO - 搜索耗时: 87.69ms
2025-03-16 22:16:49,584 - retrieval_test - INFO - 搜索结果数量: 5
2025-03-16 22:16:49,588 - retrieval_test - INFO - 重排序耗时: 3.53ms
2025-03-16 22:16:49,588 - retrieval_test - INFO - 重排序后结果数量: 5
2025-03-16 22:16:49,589 - retrieval_test - INFO - 总搜索耗时: 270.33ms
2025-03-16 22:16:49,589 - retrieval_test - INFO - 搜索结果数量: 5
2025-03-16 22:16:49,589 - retrieval_test - WARNING - 警告: 没有找到有效的文档内容!
2025-03-16 22:16:49,589 - retrieval_test - INFO - 构建提示耗时: 0.17ms
2025-03-16 22:16:49,589 - retrieval_test - INFO - 删除测试索引...

测试结果表明,即便文档很多,搜索耗时也没有明显上升,说明ES的索引结构的确非常高效。

4. Infinity简介

既然 es 如此高效,为什么 ragflow 项目还会构建 Infinity 容器?

据了解,Infinity 是专门服务大模型的数据库,采用 C++ 20 标准开发,确保了最优的执行路径,在各种创新算法的共同加持下,Infinity 在向量搜索性能上超越了所有已知向量数据库[3]。

官方的blog[4]给出了不同搜索引擎之间的性能对比,从延迟(latency)和每秒查询率(QPS)两个角度,Infinity 都遥遥领先。

在这里插入图片描述

Infinity 和 Elasticsearch 两者一个显著的技术区别是 Infinity 加入了执行引擎:当收到查询请求时,执行引擎会将其编译成计算 DAG(有向无环图),请求以流水线的方式在图中流动,引擎会根据可用资源动态确定图中每个节点的并行度。而 Elasticsearch 等搜索引擎没有执行引擎,而是直接通过倒排索引检索数据,并在排序后返回结果,所有操作都在一个线程内完成[4]。

此外,Elasticsearch 底层语言是JAVA, Infinity 底层语言是C++,通过运用C++多种特性实现了查询加速。

Infinity 并没有可视化的数据管理工具,参考 Infinity 连接的相关接口,通过以下脚本,可以查看Infinity 的一些信息:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Infinity数据库连接和探索工具
# 基于ragflow项目中的infinity_conn.py
#

import sys
import os
import logging
import json
import time
import argparse
from typing import List, Dict, Any, Optional

# 添加项目根目录到路径
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

# 导入infinity相关模块
import infinity
from infinity.common import ConflictType, InfinityException, SortType
from infinity.index import IndexInfo, IndexType
from infinity.connection_pool import ConnectionPool
from infinity.errors import ErrorCode

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('infinity_explorer')

class InfinityExplorer:
    """Infinity数据库连接和探索工具"""
    
    def __init__(self, host: str = "localhost", port: int = 23817, db_name: str = "default_db"):
        """初始化Infinity连接"""
        self.host = host
        self.port = port
        self.db_name = db_name
        self.conn_pool = None
        self.connect()
        
    def connect(self) -> bool:
        """连接到Infinity数据库"""
        infinity_uri = infinity.common.NetworkAddress(self.host, self.port)
        logger.info(f"正在连接Infinity数据库: {self.host}:{self.port}")
        
        for _ in range(5):  # 尝试5次
            try:
                conn_pool = ConnectionPool(infinity_uri)
                inf_conn = conn_pool.get_conn()
                res = inf_conn.show_current_node()
                
                if res.error_code == ErrorCode.OK and res.server_status == "started":
                    self.conn_pool = conn_pool
                    conn_pool.release_conn(inf_conn)
                    logger.info(f"成功连接到Infinity数据库: {self.host}:{self.port}")
                    return True
                
                conn_pool.release_conn(inf_conn)
                logger.warning(f"Infinity状态: {res.server_status},等待Infinity服务就绪...")
                time.sleep(2)
            except Exception as e:
                logger.warning(f"连接Infinity失败: {str(e)},正在重试...")
                time.sleep(2)
                
        logger.error(f"无法连接到Infinity数据库: {self.host}:{self.port}")
        return False
    
    def list_databases(self) -> List[str]:
        """列出所有数据库"""
        if not self.conn_pool:
            logger.error("未连接到Infinity数据库")
            return []
            
        try:
            inf_conn = self.conn_pool.get_conn()
            dbs = inf_conn.list_databases()
            self.conn_pool.release_conn(inf_conn)
            return dbs.database_names
        except Exception as e:
            logger.error(f"列出数据库失败: {str(e)}")
            return []
    
    def list_tables(self, db_name: Optional[str] = None) -> List[str]:
        """列出指定数据库中的所有表"""
        if not self.conn_pool:
            logger.error("未连接到Infinity数据库")
            return []
            
        db_name = db_name or self.db_name
        
        try:
            inf_conn = self.conn_pool.get_conn()
            db_instance = inf_conn.get_database(db_name)
            tables = db_instance.list_tables()
            self.conn_pool.release_conn(inf_conn)
            return tables.table_names
        except Exception as e:
            logger.error(f"列出表失败: {str(e)}")
            return []
    
    def show_table_schema(self, table_name: str, db_name: Optional[str] = None) -> List[tuple]:
        """显示表结构"""
        if not self.conn_pool:
            logger.error("未连接到Infinity数据库")
            return []
            
        db_name = db_name or self.db_name
        
        try:
            inf_conn = self.conn_pool.get_conn()
            db_instance = inf_conn.get_database(db_name)
            table_instance = db_instance.get_table(table_name)
            columns = table_instance.show_columns()
            self.conn_pool.release_conn(inf_conn)
            
            # 返回列信息: (name, type, default_value, constraint)
            return columns.rows()
        except Exception as e:
            logger.error(f"获取表结构失败: {str(e)}")
            return []
    
    def list_indexes(self, table_name: str, db_name: Optional[str] = None) -> List[str]:
        """列出表的所有索引"""
        if not self.conn_pool:
            logger.error("未连接到Infinity数据库")
            return []
            
        db_name = db_name or self.db_name
        
        try:
            inf_conn = self.conn_pool.get_conn()
            db_instance = inf_conn.get_database(db_name)
            table_instance = db_instance.get_table(table_name)
            indexes = table_instance.list_indexes()
            self.conn_pool.release_conn(inf_conn)
            return indexes.index_names
        except Exception as e:
            logger.error(f"列出索引失败: {str(e)}")
            return []
    
    def query_table(self, table_name: str, limit: int = 10, db_name: Optional[str] = None) -> Any:
        """查询表数据"""
        if not self.conn_pool:
            logger.error("未连接到Infinity数据库")
            return None
            
        db_name = db_name or self.db_name
        
        try:
            inf_conn = self.conn_pool.get_conn()
            db_instance = inf_conn.get_database(db_name)
            table_instance = db_instance.get_table(table_name)
            
            # 查询所有列的前limit行数据
            result, _ = table_instance.output(["*"]).limit(limit).to_pl()
            self.conn_pool.release_conn(inf_conn)
            return result
        except Exception as e:
            logger.error(f"查询表数据失败: {str(e)}")
            return None
    
    def close(self):
        """关闭连接"""
        if self.conn_pool:
            logger.info("关闭Infinity数据库连接")
            self.conn_pool = None


def main():
    """主函数"""
    parser = argparse.ArgumentParser(description='Infinity数据库探索工具')
    parser.add_argument('--host', default='localhost', help='Infinity主机地址')
    parser.add_argument('--port', type=int, default=23817, help='Infinity端口')
    parser.add_argument('--db', default='default_db', help='数据库名称')
    parser.add_argument('--list-dbs', action='store_true', help='列出所有数据库')
    parser.add_argument('--list-tables', action='store_true', help='列出所有表')
    parser.add_argument('--table', help='指定要操作的表名')
    parser.add_argument('--schema', action='store_true', help='显示表结构')
    parser.add_argument('--indexes', action='store_true', help='显示表索引')
    parser.add_argument('--query', action='store_true', help='查询表数据')
    parser.add_argument('--limit', type=int, default=10, help='查询限制行数')
    
    args = parser.parse_args()
    
    explorer = InfinityExplorer(args.host, args.port, args.db)
    
    if args.list_dbs:
        dbs = explorer.list_databases()
        print(f"数据库列表 ({len(dbs)}):")
        for db in dbs:
            print(f"  - {db}")
    
    if args.list_tables:
        tables = explorer.list_tables(args.db)
        print(f"表列表 ({len(tables)}):")
        for table in tables:
            print(f"  - {table}")
    
    if args.table:
        if args.schema:
            columns = explorer.show_table_schema(args.table)
            print(f"表 {args.table} 的结构:")
            print("  名称\t\t类型\t\t默认值\t\t约束")
            print("  " + "-" * 60)
            for name, type_, default, constraint in columns:
                print(f"  {name}\t\t{type_}\t\t{default}\t\t{constraint}")
        
        if args.indexes:
            indexes = explorer.list_indexes(args.table)
            print(f"表 {args.table} 的索引 ({len(indexes)}):")
            for idx in indexes:
                print(f"  - {idx}")
        
        if args.query:
            result = explorer.query_table(args.table, args.limit)
            if result is not None:
                print(f"表 {args.table} 的数据 (前 {args.limit} 行):")
                print(result)
    
    explorer.close()


if __name__ == "__main__":
    main()

查看数据表:

python task_test/infinity_explorer.py --list-tables

查看表索引:

python task_test/infinity_explorer.py --table table_name --indexes

查看表数据:

python task_test/infinity_explorer.py --table your_table_name --query --limit 20

在Ragflow中,替换Elasticsearch为Infinity很容易,只需要修改 docker/.env

将 DOC_ENGINE 设置为 infinity :

DOC_ENGINE=${DOC_ENGINE:-infinity}

再次执行:

docker compose -f docker/docker-compose.yml up -d

这样就完成替换。

5. Elasticsearch和Infinity实际检索效率对比

ragflow中,在每轮对话都可以点这个按钮,查看具体的响应统计时间:

在这里插入图片描述

输出示例如下:

Query:
事件相机是什么?

Total: 15737.4ms
Check LLM: 6.3ms
Create retriever: 2.0ms
Bind embedding: 6099.0ms
Bind LLM: 48.4ms
Tune question: 2.0ms
Bind reranker: 0.0ms
Generate keyword: 0.0ms
Retrieval: 1554.9ms
Generate answer: 8024.8ms

各个参数含义如下:

  • Total (总耗时) : 整个对话回合的总耗时,包括检索和生成答案的全部过程
  • Check LLM (检查LLM) : 验证指定的大语言模型是否可用的时间
  • Create retriever (创建检索器) : 创建文档块检索器的时间
  • Bind embedding (绑定嵌入模型) : 初始化嵌入模型实例的时间,用于将文本转换为向量表示
  • Bind LLM (绑定LLM) : 初始化大语言模型实例的时间
  • Tune question (优化问题) : 使用多轮对话上下文优化用户查询的时间
  • Bind reranker (绑定重排序器) : 初始化重排序模型实例的时间,用于文档块检索的结果排序
  • Generate keyword (生成关键词) : 从用户查询中提取关键词的时间
  • Retrieval (检索) : 检索相关文档块的时间
  • Generate answer (生成答案) : 生成最终答案的时间

对于检索效率,实际只需要对比Retrieval的数值就行了。

下面我进行了一个小规模实验,比较一篇文章和十篇文章规模的情况下,Elasticsearch 和 Infinity 分别的检索用时,结果如下表所示:

引擎一篇文章检索用时十篇文章检索用时
Elasticsearch524.3ms3052.0ms
Infinity293.8ms1554.9ms

从数值角度看,nfinity 比 Elasticsearch 检索效率快一倍左右,因此,选Infinity替代Elasticsearch基本没什么问题。

总结

本文对比了Infinity和Elasticsearch两种搜索引擎的检索效率,虽然验证了Infinity的确效率高,但仍存在一个问题未解决。随知识库规模增长,检索时间的提升究竟是哪部分引起的?在3.3的压力测试上,发现大批量的索引,并没有带来检索时间的明显增加。合理怀疑是插入的内容较短,多内容重复,es对其有进行单独的优化策略。对于这一点,有了解的读者欢迎在评论区交流。

参考资料

[1] Elasticsearch 索引、类型、文档、分片与副本等核心概念介绍 :https://blog.csdn.net/weixin_53269650/article/details/138609908

[2] 一篇文章带你搞定 ElasticSearch 术语: https://zhuanlan.zhihu.com/p/109578675

[3] AI 原生数据库 Infinity 正式开源:https://zhuanlan.zhihu.com/p/673618509

[4] infiniflow blog:https://infiniflow.org/blog

附录

requirements.txt,用于环境安装:

academicagent==0.1.2
accelerate==1.5.2
aiohappyeyeballs==2.5.0
aiohttp==3.11.13
aiosignal==1.3.2
annotated-types==0.7.0
anyio==4.8.0
async-timeout==4.0.3
attrs==25.1.0
backoff==2.2.1
backports.tarfile==1.2.0
backtrader==1.9.78.123
beartype==0.20.0
beautifulsoup4==4.13.3
bs4==0.0.2
cachetools==5.5.2
cbor==1.0.0
certifi==2025.1.31
cffi==1.17.1
chardet==5.2.0
charset-normalizer==3.4.1
click==8.1.8
cn2an==0.5.23
cnki-agent==0.1.2
CnkiSpider==1.1.0
colorama==0.4.6
coloredlogs==15.0.1
colpali_engine==0.3.8
contourpy==1.3.1
cramjam==2.9.1
cryptography==44.0.2
csscompressor==0.9.5
cssselect==1.3.0
cssutils==2.11.1
ctranslate2==4.5.0
cycler==0.12.1
dashscope==1.22.2
dataclasses-json==0.6.7
DataRecorder==3.6.2
datasets==3.4.0
datrie==0.8.2
dill==0.3.8
diskcache==5.6.3
distro==1.9.0
docutils==0.21.2
DownloadKit==2.0.7
DrissionPage==4.1.0.17
einops==0.8.1
elastic-transport==8.17.1
elasticsearch==8.17.2
elasticsearch-dsl==8.17.1
et_xmlfile==2.0.0
evaluate==0.4.3
exceptiongroup==1.2.2
fastapi==0.115.11
fastparquet==2024.11.0
filelock==3.17.0
FlagEmbedding==1.3.4
flatbuffers==25.2.10
fonttools==4.56.0
frozenlist==1.5.0
fsspec==2024.12.0
google-ai-generativelanguage==0.6.15
google-api-core==2.24.2
google-api-python-client==2.164.0
google-auth==2.38.0
google-auth-httplib2==0.2.0
google-generativeai==0.8.4
googleapis-common-protos==1.69.1
GPUtil==1.4.0
greenlet==3.1.1
grpcio==1.71.0
grpcio-status==1.71.0
h11==0.14.0
hanziconv==0.3.2
hf_transfer==0.1.9
html-minifier==0.0.4
httpcore==1.0.7
httplib2==0.22.0
httptools==0.6.4
httpx==0.28.1
httpx-sse==0.4.0
huggingface-hub==0.29.3
humanfriendly==10.0
id==1.5.0
idna==3.10
ijson==3.3.0
importlib_metadata==8.6.1
infinity-sdk==0.6.0.dev3
infinity_emb==0.0.75
iniconfig==2.0.0
inscriptis==2.5.3
ir_datasets==0.5.10
jaraco.classes==3.4.0
jaraco.context==6.0.1
jaraco.functools==4.1.0
Jinja2==3.1.6
jiter==0.9.0
joblib==1.4.2
jsmin==3.0.1
json_repair==0.39.1
jsonpatch==1.33
jsonpointer==3.0.0
keyring==25.6.0
kiwisolver==1.4.8
langchain==0.3.20
langchain-community==0.3.19
langchain-core==0.3.41
langchain-ollama==0.2.3
langchain-text-splitters==0.3.6
langsmith==0.3.12
lxml==5.3.1
lz4==4.4.3
markdown-it-py==3.0.0
MarkupSafe==3.0.2
marshmallow==3.26.1
matplotlib==3.10.0
mdurl==0.1.2
monotonic==1.6
more-itertools==10.6.0
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
mypy-extensions==1.0.0
mysql==0.0.3
mysql-connector-python==9.2.0
mysqlclient==2.2.7
networkx==3.4.2
nh3==0.2.21
nltk==3.9.1
numpy==1.26.4
ollama==0.4.7
onnx==1.17.0
onnxruntime==1.21.0
openai==1.66.3
openpyxl==3.1.5
optimum==1.24.0
orjson==3.10.15
ormsgpack==1.8.0
outcome==1.3.0.post0
packaging==24.2
pandas==2.2.3
pdfminer.six==20231228
pdfplumber==0.11.5
peft==0.14.0
pillow==11.1.0
pluggy==1.5.0
polars-lts-cpu==1.9.0
posthog==3.20.0
proces==0.1.7
prometheus-fastapi-instrumentator==7.0.2
prometheus_client==0.21.1
propcache==0.3.0
proto-plus==1.26.1
protobuf==5.29.3
psutil==7.0.0
pyarrow==17.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.1
pycparser==2.22
pycryptodome==3.21.0
pycryptodomex==3.20.0
pydantic==2.9.2
pydantic-settings==2.8.1
pydantic_core==2.23.4
Pygments==2.19.1
PyJWT==2.8.0
PyMuPDF==1.25.3
PyMySQL==1.1.1
pyparsing==3.2.1
pypdfium2==4.30.1
pyreadline3==3.5.4
PySocks==1.7.1
pytest==8.3.5
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
pytz==2025.1
pywin32-ctypes==0.2.3
PyYAML==6.0.2
readerwriterlock==1.0.9
readme_renderer==44.0
regex==2024.11.6
requests==2.32.3
requests-file==2.1.0
requests-toolbelt==1.0.0
rfc3986==2.0.0
rich==13.9.4
roman-numbers==1.0.2
rsa==4.9
ruamel.yaml==0.18.10
ruamel.yaml.clib==0.2.12
safetensors==0.5.3
scikit-learn==1.6.1
scipy==1.15.2
selenium==4.29.0
sentence-transformers==3.4.1
sentencepiece==0.2.0
shellingham==1.5.4
simplejson==3.20.1
six==1.17.0
sniffio==1.3.1
sortedcontainers==2.4.0
soundfile==0.12.1
soupsieve==2.6
SQLAlchemy==2.0.38
sqlglot==11.7.1
starlette==0.46.1
StrEnum==0.4.15
sympy==1.13.1
tenacity==9.0.0
threadpoolctl==3.6.0
thrift==0.20.0
tiktoken==0.9.0
timm==1.0.15
tldextract==5.1.3
tokenizers==0.21.1
tomli==2.2.1
torch==2.6.0
torchvision==0.21.0
tqdm==4.67.1
transformers==4.47.1
trec-car-tools==2.6
trio==0.29.0
trio-websocket==0.12.2
tushare==1.4.18
twine==6.1.0
typer==0.12.5
typing-inspect==0.9.0
typing_extensions==4.12.2
tzdata==2025.1
unlzw3==0.2.3
uritemplate==4.1.1
urllib3==2.3.0
uvicorn==0.32.1
valkey==6.1.0
warc3-wet==0.2.5
warc3-wet-clueweb09==0.2.5
watchfiles==1.0.4
webdriver-manager==4.0.2
websocket-client==1.8.0
websockets==15.0.1
Werkzeug==3.1.3
word2number==1.1
wsproto==1.2.0
xxhash==3.5.0
yarl==1.18.3
zhipuai==2.1.5.20250106
zipp==3.21.0
zlib-state==0.1.9
zstandard==0.23.0
原文地址:https://blog.csdn.net/qq1198768105/article/details/146287335
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.kler.cn/a/594709.html

相关文章:

  • 私域电商的进化逻辑与技术赋能:基于开源AI大模型与S2B2C商城的创新融合研究
  • 深度学习pytorch笔记:TCN
  • Vs code搭建uniapp-vue项目
  • DAY35贪心算法Ⅳ 重叠区间问题
  • Java 大视界 -- Java 大数据在智能政务舆情引导与公共危机管理中的应用(138)
  • Flask 模版引擎的语法
  • 【FAQ】HarmonyOS SDK 闭源开放能力 —Push Kit(10)
  • Redis 在windows下的下载安装与配置
  • 医院信息系统平台总体架构原则
  • 创造型设计模式
  • canvas数据标注功能简单实现:矩形、圆形
  • <el-form >ref数据监测不到的原因
  • 将MySQL数据同步到Elasticsearch作为全文检索数据的实战指南
  • 【从零开始学习计算机科学与技术】计算机网络(五)网络层
  • RocketMQ 架构
  • ssm_mysql_校园二手交易系统
  • 数据结构:用C语言实现插入排序
  • 设计模式,如单例模式、观察者模式在什么场景下使用
  • 在Oracle Linux 7上安装Oracle 11gr2数据库
  • 【 Kubernetes 风云录 】- Istio 实现流量染色及透传