当前位置: 首页 > article >正文

加载MiniLM-L12-v2模型及知识库,调用Deepseek进行问答

这段 Python 代码实现了一个基于知识库的问答系统 KnowledgeBaseSystem。该系统主要具备以下功能:

  1. 加载本地模型对文本进行编码。
  2. 从指定路径加载知识库文件。
  3. 对知识库中的文本进行向量化处理。
  4. 根据用户的问题,找出知识库中最相关的知识。
  5. 调用 DeepSeek API 结合相关知识生成回答。

代码详细说明

导入必要的库

python

import os
import numpy as np
import requests
from typing import Dict, List, Tuple
import time
from sentence_transformers import SentenceTransformer

  • os:用于处理文件和目录操作。
  • numpy:用于数值计算,如向量运算。
  • requests:用于发送 HTTP 请求,调用 DeepSeek API。
  • typing:用于类型提示,增强代码的可读性和可维护性。
  • time:用于记录查询耗时。
  • SentenceTransformer:用于加载和使用预训练的句子嵌入模型。
KnowledgeBaseSystem 类

python

class KnowledgeBaseSystem:
    def __init__(self, knowledge_base_path: str, deepseek_api_key: str):
        self.knowledge_base_path = knowledge_base_path
        self.deepseek_api_key = deepseek_api_key
        self.model = self._load_local_model()  # 加载本地模型
        self.knowledge_base: Dict[str, str] = {}
        self.vectorized_knowledge: Dict[str, np.ndarray] = {}
        
        self._load_knowledge_base()
        self._vectorize_knowledge()

  • __init__ 方法:类的构造函数,初始化知识库系统。
    • knowledge_base_path:知识库文件所在的路径。
    • deepseek_api_key:DeepSeek API 的密钥。
    • self.model:通过 _load_local_model 方法加载的本地模型。
    • self.knowledge_base:存储知识库文件内容的字典,键为文件名,值为文件内容。
    • self.vectorized_knowledge:存储知识库文件向量化结果的字典,键为文件名,值为向量表示。
_load_local_model 方法

python

def _load_local_model(self):
    """从本地文件加载模型"""
    model_paths = [
        XXXX,  # 本地模型路径
        os.path.join("models", "paraphrase-multilingual-MiniLM-L12-v2"),
        os.path.join("models", "all-MiniLM-L6-v2"),
        "all-MiniLM-L6-v2"  # 最后尝试从缓存加载
    ]
    
    for path in model_paths:
        try:
            if os.path.exists(path):
                print(f"尝试从本地加载模型: {path}")
                return SentenceTransformer(path)
            else:
                print(f"尝试加载模型: {path} (未找到本地文件)")
                return SentenceTransformer(path.split('/')[-1])  # 尝试从名称加载
        except Exception as e:
            print(f"加载模型 {path} 失败: {e}")
            continue
    
    raise RuntimeError("""
无法加载任何模型,请按以下步骤操作:
1. 手动下载模型文件:
   - 访问 https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
   - 点击"↓"按钮下载整个仓库
   - 解压到项目目录下的 models/ 文件夹中
2. 或者运行以下命令自动下载(需要有网络连接):
   python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')"
""")

  • 功能:尝试从多个路径加载本地模型。
  • 步骤
    1. 定义一个包含多个模型路径的列表 model_paths
    2. 遍历 model_paths,尝试加载模型。
    3. 如果路径存在,使用 SentenceTransformer 从本地路径加载模型。
    4. 如果路径不存在,尝试从模型名称加载模型。
    5. 如果所有路径都无法加载模型,抛出 RuntimeError 并给出解决建议。
_load_knowledge_base 方法

python

def _load_knowledge_base(self):
    """加载知识库"""
    print("正在加载知识库...")
    for root, _, files in os.walk(self.knowledge_base_path):
        for file in files:
            if file.endswith(".txt"):
                file_path = os.path.join(root, file)
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        self.knowledge_base[file] = f.read()
                except Exception as e:
                    print(f"加载文件 {file_path} 失败: {e}")
    print(f"已加载 {len(self.knowledge_base)} 个知识文件")

  • 功能:从指定路径加载所有 .txt 文件作为知识库。
  • 步骤
    1. 使用 os.walk 遍历指定路径下的所有文件。
    2. 筛选出 .txt 文件,读取文件内容并存储到 self.knowledge_base 字典中。
    3. 打印加载的文件数量。
_vectorize_knowledge 方法

python

def _vectorize_knowledge(self):
    """向量化知识库"""
    print("正在向量化知识库...")
    for key, content in self.knowledge_base.items():
        self.vectorized_knowledge[key] = self.model.encode(content)
    print("知识库向量化完成")

  • 功能:使用加载的模型对知识库中的文本进行向量化处理。
  • 步骤
    1. 遍历 self.knowledge_base 字典,对每个文件的内容进行编码。
    2. 将编码结果存储到 self.vectorized_knowledge 字典中。
    3. 打印向量化完成的信息。
_get_most_relevant_knowledge 方法

python

def _get_most_relevant_knowledge(self, query: str, top_k: int = 3) -> List[Tuple[str, float]]:
    """获取最相关知识"""
    query_vector = self.model.encode(query)
    similarities = []
    
    for key, vector in self.vectorized_knowledge.items():
        similarity = np.dot(query_vector, vector) / (np.linalg.norm(query_vector) * np.linalg.norm(vector))
        similarities.append((key, similarity))
    
    return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]

  • 功能:根据用户的问题,找出知识库中最相关的 top_k 个知识。
  • 步骤
    1. 对用户的问题进行编码,得到查询向量 query_vector
    2. 计算查询向量与知识库中每个文件向量的余弦相似度。
    3. 将文件名和相似度组成元组,存储到 similarities 列表中。
    4. 对 similarities 列表按相似度降序排序,取前 top_k 个结果。
_call_deepseek_api 方法

python

def _call_deepseek_api(self, context: str, query: str) -> str:
    """调用DeepSeek API"""
    headers = {
        "Authorization": f"Bearer {self.deepseek_api_key}",
        "Content-Type": "application/json"
    }
    
    prompt = f"""基于以下上下文回答问题:
    
【上下文】
{context}

【问题】
{query}

请给出专业、准确的回答:"""
    
    try:
        response = requests.post(
            "https://api.deepseek.com/v1/chat/completions",
            json={
                "model": "deepseek-chat",
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.7,
                "max_tokens": 1000
            },
            headers=headers,
            timeout=30
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except Exception as e:
        print(f"API调用失败: {e}")
        return "无法获取回答,请检查网络连接和API密钥。"

  • 功能:调用 DeepSeek API,结合相关知识和用户问题生成回答。
  • 步骤
    1. 设置请求头,包含 API 密钥和内容类型。
    2. 构造请求的提示信息,包含上下文和用户问题。
    3. 使用 requests.post 方法发送请求到 DeepSeek API。
    4. 检查响应状态码,如果正常,返回 API 的回答;否则,打印错误信息并返回错误提示。
query 方法

python

def query(self, question: str) -> str:
    """查询知识库"""
    try:
        start_time = time.time()
        
        relevant = self._get_most_relevant_knowledge(question)
        if not relevant:
            return "未找到相关信息。"
        
        context = "\n\n".join(f"【{k}】\n{self.knowledge_base[k]}" for k, _ in relevant)
        answer = self._call_deepseek_api(context, question)
        
        print(f"查询耗时: {time.time()-start_time:.2f}秒")
        return answer
    except Exception as e:
        return f"查询失败: {str(e)}"

  • 功能:处理用户的查询请求,返回回答。
  • 步骤
    1. 记录查询开始时间。
    2. 调用 _get_most_relevant_knowledge 方法找出最相关的知识。
    3. 如果没有找到相关知识,返回提示信息。
    4. 构造上下文信息,调用 _call_deepseek_api 方法生成回答。
    5. 记录查询结束时间,打印查询耗时。
    6. 返回回答,如果出现异常,返回错误信息。
main 函数

python

def main():
    # 配置参数
    KNOWLEDGE_BASE_PATH = r"D:\06_Python\20250328_Graph_knowledge\laws"
    DEEPSEEK_API_KEY = "XXXX"  # 替换为你的API密钥
    
    try:
        print("初始化知识库系统...")
        kb = KnowledgeBaseSystem(KNOWLEDGE_BASE_PATH, DEEPSEEK_API_KEY)
        print("系统已就绪,输入问题开始查询('退出'结束)")
        
        while True:
            try:
                q = input("\n问题: ").strip()
                if q.lower() in ['退出', 'exit', 'quit']:
                    break
                if q:
                    print("\n回答:", kb.query(q))
            except KeyboardInterrupt:
                print("\n输入'退出'结束程序")
                continue
    except Exception as e:
        print(f"系统初始化失败: {str(e)}")
    finally:
        print("系统已关闭")

  • 功能:程序的入口函数,初始化知识库系统并处理用户的查询请求。
  • 步骤
    1. 设置知识库路径和 DeepSeek API 密钥。
    2. 初始化 KnowledgeBaseSystem 类的实例。
    3. 进入循环,等待用户输入问题。
    4. 如果用户输入 退出exit 或 quit,退出循环。
    5. 如果用户输入有效问题,调用 query 方法获取回答并打印。
    6. 处理异常,确保系统关闭时打印关闭信息。
程序入口

python

if __name__ == "__main__":
    main()

  • 确保代码作为脚本直接运行时,调用 main 函数。

使用说明

  1. 确保已经安装了所需的库:numpyrequests 和 sentence-transformers
  2. 将 XXXX 替换为实际的本地模型路径和 DeepSeek API 密钥。
  3. 将知识库文件(.txt 格式)放在指定的路径下。
  4. 运行脚本,按照提示输入问题进行查询。输入 退出exit 或 quit 结束程序。

完整代码(需添加本地模型路径及deep seek的API)

import os
import numpy as np
import requests
from typing import Dict, List, Tuple
import time
from sentence_transformers import SentenceTransformer

class KnowledgeBaseSystem:
    def __init__(self, knowledge_base_path: str, deepseek_api_key: str):
        self.knowledge_base_path = knowledge_base_path
        self.deepseek_api_key = deepseek_api_key
        self.model = self._load_local_model()  # 加载本地模型
        self.knowledge_base: Dict[str, str] = {}
        self.vectorized_knowledge: Dict[str, np.ndarray] = {}
        
        self._load_knowledge_base()
        self._vectorize_knowledge()

    def _load_local_model(self):
        """从本地文件加载模型"""
        model_paths = [
            XXXX,  # 本地模型路径
            os.path.join("models", "paraphrase-multilingual-MiniLM-L12-v2"),
            os.path.join("models", "all-MiniLM-L6-v2"),
            "all-MiniLM-L6-v2"  # 最后尝试从缓存加载
        ]
        
        for path in model_paths:
            try:
                if os.path.exists(path):
                    print(f"尝试从本地加载模型: {path}")
                    return SentenceTransformer(path)
                else:
                    print(f"尝试加载模型: {path} (未找到本地文件)")
                    return SentenceTransformer(path.split('/')[-1])  # 尝试从名称加载
            except Exception as e:
                print(f"加载模型 {path} 失败: {e}")
                continue
        
        raise RuntimeError("""
无法加载任何模型,请按以下步骤操作:
1. 手动下载模型文件:
   - 访问 https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
   - 点击"↓"按钮下载整个仓库
   - 解压到项目目录下的 models/ 文件夹中
2. 或者运行以下命令自动下载(需要有网络连接):
   python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')"
""")

    def _load_knowledge_base(self):
        """加载知识库"""
        print("正在加载知识库...")
        for root, _, files in os.walk(self.knowledge_base_path):
            for file in files:
                if file.endswith(".txt"):
                    file_path = os.path.join(root, file)
                    try:
                        with open(file_path, 'r', encoding='utf-8') as f:
                            self.knowledge_base[file] = f.read()
                    except Exception as e:
                        print(f"加载文件 {file_path} 失败: {e}")
        print(f"已加载 {len(self.knowledge_base)} 个知识文件")

    def _vectorize_knowledge(self):
        """向量化知识库"""
        print("正在向量化知识库...")
        for key, content in self.knowledge_base.items():
            self.vectorized_knowledge[key] = self.model.encode(content)
        print("知识库向量化完成")

    def _get_most_relevant_knowledge(self, query: str, top_k: int = 3) -> List[Tuple[str, float]]:
        """获取最相关知识"""
        query_vector = self.model.encode(query)
        similarities = []
        
        for key, vector in self.vectorized_knowledge.items():
            similarity = np.dot(query_vector, vector) / (np.linalg.norm(query_vector) * np.linalg.norm(vector))
            similarities.append((key, similarity))
        
        return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]

    def _call_deepseek_api(self, context: str, query: str) -> str:
        """调用DeepSeek API"""
        headers = {
            "Authorization": f"Bearer {self.deepseek_api_key}",
            "Content-Type": "application/json"
        }
        
        prompt = f"""基于以下上下文回答问题:
        
【上下文】
{context}

【问题】
{query}

请给出专业、准确的回答:"""
        
        try:
            response = requests.post(
                "https://api.deepseek.com/v1/chat/completions",
                json={
                    "model": "deepseek-chat",
                    "messages": [{"role": "user", "content": prompt}],
                    "temperature": 0.7,
                    "max_tokens": 1000
                },
                headers=headers,
                timeout=30
            )
            response.raise_for_status()
            return response.json()["choices"][0]["message"]["content"]
        except Exception as e:
            print(f"API调用失败: {e}")
            return "无法获取回答,请检查网络连接和API密钥。"

    def query(self, question: str) -> str:
        """查询知识库"""
        try:
            start_time = time.time()
            
            relevant = self._get_most_relevant_knowledge(question)
            if not relevant:
                return "未找到相关信息。"
            
            context = "\n\n".join(f"【{k}】\n{self.knowledge_base[k]}" for k, _ in relevant)
            answer = self._call_deepseek_api(context, question)
            
            print(f"查询耗时: {time.time()-start_time:.2f}秒")
            return answer
        except Exception as e:
            return f"查询失败: {str(e)}"
def main():
    # 配置参数
    KNOWLEDGE_BASE_PATH = r"D:\06_Python\20250328_Graph_knowledge\laws"
    DEEPSEEK_API_KEY = "XXXX"  # 替换为你的API密钥
    
    try:
        print("初始化知识库系统...")
        kb = KnowledgeBaseSystem(KNOWLEDGE_BASE_PATH, DEEPSEEK_API_KEY)
        print("系统已就绪,输入问题开始查询('退出'结束)")
        
        while True:
            try:
                q = input("\n问题: ").strip()
                if q.lower() in ['退出', 'exit', 'quit']:
                    break
                if q:
                    print("\n回答:", kb.query(q))
            except KeyboardInterrupt:
                print("\n输入'退出'结束程序")
                continue
    except Exception as e:
        print(f"系统初始化失败: {str(e)}")
    finally:
        print("系统已关闭")

if __name__ == "__main__":
    main()


http://www.kler.cn/a/614302.html

相关文章:

  • 【Hysteria】部署+测试
  • 虚拟机docker配置ES
  • Docker:ERROR [internal] load metadata for docker.io/library/java:8-alpine问题解决
  • UDS故障码(DTC)SAE格式和HEX相互转换公式
  • B3647 【模板】Floyd
  • ubuntu 安装mysql
  • 【计算机网络】网络原理
  • 智能路由系统-信息泄露漏洞挖掘
  • 第30周Java分布式入门 ThreadLocal
  • Tomcat深度解析:Java Web服务的核心引擎
  • Qwen-0.5b linux部署
  • sql注入语句学习
  • Scala总结(二)
  • 【leetcode】拆解与整合:分治并归的算法逻辑
  • 基于 IEC 61499 标准的开放自动化技术发展现状与展望
  • pycharm与python版本
  • Redis 实现分布式锁详解
  • 网络安全(一):常见的网络威胁及防范
  • 求解AX=XB 方法
  • STM32学习笔记之系统异常和中断(原理篇)