当前位置: 首页 > article >正文

基于Langchain的txt文本向量库搭建与检索

这里的源码主要来自于Langchain-ChatGLM中的向量库部分,做了一些代码上的修改和封装,以适用于基于问题包含数据库表描述的txt文件(文件名为库表名,文件内容为库表中的字段及描述)对数据库表进行快速检索。

中文分词类

splitter.py

from langchain.text_splitter import CharacterTextSplitter
import re
from typing import List


class ChineseTextSplitter(CharacterTextSplitter):
    def __init__(self, pdf: bool = False, sentence_size: int = 100, **kwargs):
        super().__init__(**kwargs)
        self.pdf = pdf
        self.sentence_size = sentence_size

    def split_text1(self, text: str) -> List[str]:
        if self.pdf:
            text = re.sub(r"\n{3,}", "\n", text)
            text = re.sub('\s', ' ', text)
            text = text.replace("\n\n", "")
        sent_sep_pattern = re.compile('([﹒﹔﹖﹗。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')  # del :;
        sent_list = []
        for ele in sent_sep_pattern.split(text):
            if sent_sep_pattern.match(ele) and sent_list:
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list

    def split_text(self, text: str) -> List[str]:   ##此处需要进一步优化逻辑
        if self.pdf:
            text = re.sub(r"\n{3,}", r"\n", text)
            text = re.sub('\s', " ", text)
            text = re.sub("\n\n", "", text)

        text = re.sub(r'([;;!?。!?\?])([^”’])', r"\1\n\2", text)  # 单字符断句符
        text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)  # 英文省略号
        text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)  # 中文省略号
        text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text)
        # 如果双引号前有终止符,那么双引号才是句子的终点,把分句符\n放到双引号后,注意前面的几句都小心保留了双引号
        text = text.rstrip()  # 段尾如果有多余的\n就去掉它
        # 很多规则中会考虑分号;,但是这里我把它忽略不计,破折号、英文双引号等同样忽略,需要的再做些简单调整即可。
        ls = [i for i in text.split("\n") if i]
        for ele in ls:
            if len(ele) > self.sentence_size:
                ele1 = re.sub(r'([,,]["’”」』]{0,2})([^,,])', r'\1\n\2', ele)
                ele1_ls = ele1.split("\n")
                for ele_ele1 in ele1_ls:
                    if len(ele_ele1) > self.sentence_size:
                        ele_ele2 = re.sub(r'([\n]{1,}| {2,}["’”」』]{0,2})([^\s])', r'\1\n\2', ele_ele1)
                        ele2_ls = ele_ele2.split("\n")
                        for ele_ele2 in ele2_ls:
                            if len(ele_ele2) > self.sentence_size:
                                ele_ele3 = re.sub('( ["’”」』]{0,2})([^ ])', r'\1\n\2', ele_ele2)
                                ele2_id = ele2_ls.index(ele_ele2)
                                ele2_ls = ele2_ls[:ele2_id] + [i for i in ele_ele3.split("\n") if i] + ele2_ls[
                                                                                                       ele2_id + 1:]
                        ele_id = ele1_ls.index(ele_ele1)
                        ele1_ls = ele1_ls[:ele_id] + [i for i in ele2_ls if i] + ele1_ls[ele_id + 1:]

                id = ls.index(ele)
                ls = ls[:id] + [i for i in ele1_ls if i] + ls[id + 1:]
        return ls

faiss向量库类

myfaiss.py

from langchain.vectorstores import FAISS
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.faiss import dependable_faiss_import
from typing import Any, Callable, List, Dict
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
import numpy as np
import copy
import os


class MyFAISS(FAISS, VectorStore):
    def __init__(
            self,
            embedding_function: Callable,
            index: Any,
            docstore: Docstore,
            index_to_docstore_id: Dict[int, str],
            normalize_L2: bool = False,
    ):
        super().__init__(embedding_function=embedding_function,
                         index=index,
                         docstore=docstore,
                         index_to_docstore_id=index_to_docstore_id,
                         normalize_L2=normalize_L2)

    def seperate_list(self, ls: List[int]) -> List[List[int]]:
        lists = []
        ls1 = [ls[0]]
        source1 = self.index_to_docstore_source(ls[0])
        for i in range(1, len(ls)):
            if ls[i - 1] + 1 == ls[i] and self.index_to_docstore_source(ls[i]) == source1:
                ls1.append(ls[i])
            else:
                lists.append(ls1)
                ls1 = [ls[i]]
                source1 = self.index_to_docstore_source(ls[i])
        lists.append(ls1)
        return lists

    def similarity_search_with_score_by_vector(
            self, embedding: List[float], k: int = 4
    ) -> List[Document]:
        faiss = dependable_faiss_import()
        # (1,1024)
        vector = np.array([embedding], dtype=np.float32)
        # 默认FALSE
        if self._normalize_L2:
            faiss.normalize_L2(vector)
        # shape均为(1, k)
        scores, indices = self.index.search(vector, k)
        docs = []
        id_set = set()
        # 存储关键句
        keysentences = []
        # 遍历找到的k个最近相关文档的索引
        # top-k是第一次的筛选条件,score是第二次的筛选条件
        for j, i in enumerate(indices[0]):
            if i in self.index_to_docstore_id:
                _id = self.index_to_docstore_id[i]
            # 执行接下来的操作
            else:
                continue
            # index→id→content
            doc = self.docstore.search(_id)
            doc.metadata["score"] = int(scores[0][j])
            docs.append(doc)
            # 其实存的都是index
            id_set.add(i)
        docs.sort(key=lambda doc: doc.metadata['score'])
        return docs

嵌入检索类

embedder.py

from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.document_loaders import TextLoader
from embeddings.splitter import ChineseTextSplitter
from embeddings.myfaiss import MyFAISS
import os
import torch
from config import *

def torch_gc():
    if torch.cuda.is_available():
        # with torch.cuda.device(DEVICE):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif torch.backends.mps.is_available():
        try:
            from torch.mps import empty_cache
            empty_cache()
        except Exception as e:
            print(e)
            print("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,以支持及时清理 torch 产生的内存占用。")

class Embedder:
    def __init__(self, config):
        self.model = HuggingFaceEmbeddings(model_name=
            "/home/df1500/NLP/LLM/pretrained_model/WordEmbeddings/"+config.emb_model,
            model_kwargs={'device': 'cuda'})
        self.config = config
        self.create_vector_score()
        self.vector_store = MyFAISS.load_local(self.config.db_vs_path, self.model)

    def load_file(self, filepath):
        # 对文件分词
        if filepath.lower().endswith(".txt"):
            loader = TextLoader(filepath, autodetect_encoding=True)
            textsplitter = ChineseTextSplitter(pdf=False, sentence_size=self.config.sentence_size)
            docs = loader.load_and_split(textsplitter)
        else:
            raise Exception("{}文件不是txt格式".format(filepath))
        return docs
    
    def txt2vector_store(self, filepaths):
        # 批量建立知识库
        docs = []
        for filepath in filepaths:
            try:
                docs += self.load_file(filepath)
            except Exception as e:
                raise Exception("{}文件加载失败".format(filepath))
        print("文件加载完毕,正在生成向量库")
        vector_store = MyFAISS.from_documents(docs, self.model)
        torch_gc()
        vector_store.save_local(self.config.db_vs_path)

    def create_vector_score(self):
        if "index.faiss" not in os.listdir(self.config.db_vs_path):
            filepaths = os.listdir(self.config.db_doc_path)
            filepaths = [os.path.join(self.config.db_doc_path, filepath) for filepath in filepaths]
            self.txt2vector_store(filepaths)
        print("向量库已建立成功")

    def get_topk_db(self, query):
        related_dbs_with_score = self.vector_store.similarity_search_with_score(query, k=self.config.sim_k)
        topk_db = [{'匹配句': db_data.page_content, 
                    '数据库': os.path.basename(db_data.metadata['source'])[:-4], 
                    '得分': db_data.metadata['score']} 
                   for db_data in related_dbs_with_score]
        return topk_db

测试代码

Config是用来传参的类,这里略去定义

if __name__ == '__main__':
    Conf = Config()
    configs = Conf.get_config()
    embedder = Embedder(configs)
    query = "公司哪个月的出勤率是最高的?"
    topk_db = embedder.get_topk_db(query)
    print(topk_db)

http://www.kler.cn/a/156857.html

相关文章:

  • 《MYSQL45讲》误删数据怎么办
  • 软件测试面试2024最新热点问题
  • SOLIDWORKS代理商鑫辰信息科技
  • golang分布式缓存项目 Day1 LRU 缓存淘汰策略
  • 2024 kali操作系统安装Docker步骤
  • 运行springBlade项目历程
  • 菜鸟学习日记(python)——数据类型转换
  • 记一次ThreadPoolTaskExecutor的坑
  • 2023年道路运输企业主要负责人证模拟考试题库及道路运输企业主要负责人理论考试试题
  • IRS辅助的隐蔽通信 (IRS aided covert communication)
  • csapp-linklab之第3阶段“输出学号”实验报告(强弱符号)
  • qt 安装
  • [C/C++]数据结构 堆排序(详细图解)
  • C++ 基础篇
  • 预约按摩小程序有哪些功能特点?
  • autojs-ui悬浮按钮模板
  • 【C语言】存储类型说明符——auto、static、extern、register
  • 华为OD机试真题-电脑病毒感染-2023年OD统一考试(C卷)
  • 【数据库设计和SQL基础语法】--SQL语言概述--SQL的基本结构和语法规则(二)
  • 预测胶质瘤预后的铜结合蛋白的转录组学特征
  • 优维低代码实践:搜索功能
  • 前端工作总结03
  • Ubuntu20.04/Linux中常用软件的安装
  • 翻硬币(第四届蓝桥杯省赛C++B组)(java版)
  • csdn语法说明/csdn新手指导/csdn入门指导/csdn博文助手
  • 初试占比7成!只考一门数据结构+学硕复录比1:1的神仙学校,大连交通大学考情分析