加载MiniLM-L12-v2模型及知识库,调用Deepseek进行问答
这段 Python 代码实现了一个基于知识库的问答系统 KnowledgeBaseSystem
。该系统主要具备以下功能:
- 加载本地模型对文本进行编码。
- 从指定路径加载知识库文件。
- 对知识库中的文本进行向量化处理。
- 根据用户的问题,找出知识库中最相关的知识。
- 调用 DeepSeek API 结合相关知识生成回答。
代码详细说明
导入必要的库
python
import os
import numpy as np
import requests
from typing import Dict, List, Tuple
import time
from sentence_transformers import SentenceTransformer
os
:用于处理文件和目录操作。numpy
:用于数值计算,如向量运算。requests
:用于发送 HTTP 请求,调用 DeepSeek API。typing
:用于类型提示,增强代码的可读性和可维护性。time
:用于记录查询耗时。SentenceTransformer
:用于加载和使用预训练的句子嵌入模型。
KnowledgeBaseSystem
类
python
class KnowledgeBaseSystem:
def __init__(self, knowledge_base_path: str, deepseek_api_key: str):
self.knowledge_base_path = knowledge_base_path
self.deepseek_api_key = deepseek_api_key
self.model = self._load_local_model() # 加载本地模型
self.knowledge_base: Dict[str, str] = {}
self.vectorized_knowledge: Dict[str, np.ndarray] = {}
self._load_knowledge_base()
self._vectorize_knowledge()
__init__
方法:类的构造函数,初始化知识库系统。knowledge_base_path
:知识库文件所在的路径。deepseek_api_key
:DeepSeek API 的密钥。self.model
:通过_load_local_model
方法加载的本地模型。self.knowledge_base
:存储知识库文件内容的字典,键为文件名,值为文件内容。self.vectorized_knowledge
:存储知识库文件向量化结果的字典,键为文件名,值为向量表示。
_load_local_model
方法
python
def _load_local_model(self):
"""从本地文件加载模型"""
model_paths = [
XXXX, # 本地模型路径
os.path.join("models", "paraphrase-multilingual-MiniLM-L12-v2"),
os.path.join("models", "all-MiniLM-L6-v2"),
"all-MiniLM-L6-v2" # 最后尝试从缓存加载
]
for path in model_paths:
try:
if os.path.exists(path):
print(f"尝试从本地加载模型: {path}")
return SentenceTransformer(path)
else:
print(f"尝试加载模型: {path} (未找到本地文件)")
return SentenceTransformer(path.split('/')[-1]) # 尝试从名称加载
except Exception as e:
print(f"加载模型 {path} 失败: {e}")
continue
raise RuntimeError("""
无法加载任何模型,请按以下步骤操作:
1. 手动下载模型文件:
- 访问 https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
- 点击"↓"按钮下载整个仓库
- 解压到项目目录下的 models/ 文件夹中
2. 或者运行以下命令自动下载(需要有网络连接):
python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')"
""")
- 功能:尝试从多个路径加载本地模型。
- 步骤:
- 定义一个包含多个模型路径的列表
model_paths
。 - 遍历
model_paths
,尝试加载模型。 - 如果路径存在,使用
SentenceTransformer
从本地路径加载模型。 - 如果路径不存在,尝试从模型名称加载模型。
- 如果所有路径都无法加载模型,抛出
RuntimeError
并给出解决建议。
- 定义一个包含多个模型路径的列表
_load_knowledge_base
方法
python
def _load_knowledge_base(self):
"""加载知识库"""
print("正在加载知识库...")
for root, _, files in os.walk(self.knowledge_base_path):
for file in files:
if file.endswith(".txt"):
file_path = os.path.join(root, file)
try:
with open(file_path, 'r', encoding='utf-8') as f:
self.knowledge_base[file] = f.read()
except Exception as e:
print(f"加载文件 {file_path} 失败: {e}")
print(f"已加载 {len(self.knowledge_base)} 个知识文件")
- 功能:从指定路径加载所有
.txt
文件作为知识库。 - 步骤:
- 使用
os.walk
遍历指定路径下的所有文件。 - 筛选出
.txt
文件,读取文件内容并存储到self.knowledge_base
字典中。 - 打印加载的文件数量。
- 使用
_vectorize_knowledge
方法
python
def _vectorize_knowledge(self):
"""向量化知识库"""
print("正在向量化知识库...")
for key, content in self.knowledge_base.items():
self.vectorized_knowledge[key] = self.model.encode(content)
print("知识库向量化完成")
- 功能:使用加载的模型对知识库中的文本进行向量化处理。
- 步骤:
- 遍历
self.knowledge_base
字典,对每个文件的内容进行编码。 - 将编码结果存储到
self.vectorized_knowledge
字典中。 - 打印向量化完成的信息。
- 遍历
_get_most_relevant_knowledge
方法
python
def _get_most_relevant_knowledge(self, query: str, top_k: int = 3) -> List[Tuple[str, float]]:
"""获取最相关知识"""
query_vector = self.model.encode(query)
similarities = []
for key, vector in self.vectorized_knowledge.items():
similarity = np.dot(query_vector, vector) / (np.linalg.norm(query_vector) * np.linalg.norm(vector))
similarities.append((key, similarity))
return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]
- 功能:根据用户的问题,找出知识库中最相关的
top_k
个知识。 - 步骤:
- 对用户的问题进行编码,得到查询向量
query_vector
。 - 计算查询向量与知识库中每个文件向量的余弦相似度。
- 将文件名和相似度组成元组,存储到
similarities
列表中。 - 对
similarities
列表按相似度降序排序,取前top_k
个结果。
- 对用户的问题进行编码,得到查询向量
_call_deepseek_api
方法
python
def _call_deepseek_api(self, context: str, query: str) -> str:
"""调用DeepSeek API"""
headers = {
"Authorization": f"Bearer {self.deepseek_api_key}",
"Content-Type": "application/json"
}
prompt = f"""基于以下上下文回答问题:
【上下文】
{context}
【问题】
{query}
请给出专业、准确的回答:"""
try:
response = requests.post(
"https://api.deepseek.com/v1/chat/completions",
json={
"model": "deepseek-chat",
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.7,
"max_tokens": 1000
},
headers=headers,
timeout=30
)
response.raise_for_status()
return response.json()["choices"][0]["message"]["content"]
except Exception as e:
print(f"API调用失败: {e}")
return "无法获取回答,请检查网络连接和API密钥。"
- 功能:调用 DeepSeek API,结合相关知识和用户问题生成回答。
- 步骤:
- 设置请求头,包含 API 密钥和内容类型。
- 构造请求的提示信息,包含上下文和用户问题。
- 使用
requests.post
方法发送请求到 DeepSeek API。 - 检查响应状态码,如果正常,返回 API 的回答;否则,打印错误信息并返回错误提示。
query
方法
python
def query(self, question: str) -> str:
"""查询知识库"""
try:
start_time = time.time()
relevant = self._get_most_relevant_knowledge(question)
if not relevant:
return "未找到相关信息。"
context = "\n\n".join(f"【{k}】\n{self.knowledge_base[k]}" for k, _ in relevant)
answer = self._call_deepseek_api(context, question)
print(f"查询耗时: {time.time()-start_time:.2f}秒")
return answer
except Exception as e:
return f"查询失败: {str(e)}"
- 功能:处理用户的查询请求,返回回答。
- 步骤:
- 记录查询开始时间。
- 调用
_get_most_relevant_knowledge
方法找出最相关的知识。 - 如果没有找到相关知识,返回提示信息。
- 构造上下文信息,调用
_call_deepseek_api
方法生成回答。 - 记录查询结束时间,打印查询耗时。
- 返回回答,如果出现异常,返回错误信息。
main
函数
python
def main():
# 配置参数
KNOWLEDGE_BASE_PATH = r"D:\06_Python\20250328_Graph_knowledge\laws"
DEEPSEEK_API_KEY = "XXXX" # 替换为你的API密钥
try:
print("初始化知识库系统...")
kb = KnowledgeBaseSystem(KNOWLEDGE_BASE_PATH, DEEPSEEK_API_KEY)
print("系统已就绪,输入问题开始查询('退出'结束)")
while True:
try:
q = input("\n问题: ").strip()
if q.lower() in ['退出', 'exit', 'quit']:
break
if q:
print("\n回答:", kb.query(q))
except KeyboardInterrupt:
print("\n输入'退出'结束程序")
continue
except Exception as e:
print(f"系统初始化失败: {str(e)}")
finally:
print("系统已关闭")
- 功能:程序的入口函数,初始化知识库系统并处理用户的查询请求。
- 步骤:
- 设置知识库路径和 DeepSeek API 密钥。
- 初始化
KnowledgeBaseSystem
类的实例。 - 进入循环,等待用户输入问题。
- 如果用户输入
退出
、exit
或quit
,退出循环。 - 如果用户输入有效问题,调用
query
方法获取回答并打印。 - 处理异常,确保系统关闭时打印关闭信息。
程序入口
python
if __name__ == "__main__":
main()
- 确保代码作为脚本直接运行时,调用
main
函数。
使用说明
- 确保已经安装了所需的库:
numpy
、requests
和sentence-transformers
。 - 将
XXXX
替换为实际的本地模型路径和 DeepSeek API 密钥。 - 将知识库文件(
.txt
格式)放在指定的路径下。 - 运行脚本,按照提示输入问题进行查询。输入
退出
、exit
或quit
结束程序。
完整代码(需添加本地模型路径及deep seek的API)
import os
import numpy as np
import requests
from typing import Dict, List, Tuple
import time
from sentence_transformers import SentenceTransformer
class KnowledgeBaseSystem:
def __init__(self, knowledge_base_path: str, deepseek_api_key: str):
self.knowledge_base_path = knowledge_base_path
self.deepseek_api_key = deepseek_api_key
self.model = self._load_local_model() # 加载本地模型
self.knowledge_base: Dict[str, str] = {}
self.vectorized_knowledge: Dict[str, np.ndarray] = {}
self._load_knowledge_base()
self._vectorize_knowledge()
def _load_local_model(self):
"""从本地文件加载模型"""
model_paths = [
XXXX, # 本地模型路径
os.path.join("models", "paraphrase-multilingual-MiniLM-L12-v2"),
os.path.join("models", "all-MiniLM-L6-v2"),
"all-MiniLM-L6-v2" # 最后尝试从缓存加载
]
for path in model_paths:
try:
if os.path.exists(path):
print(f"尝试从本地加载模型: {path}")
return SentenceTransformer(path)
else:
print(f"尝试加载模型: {path} (未找到本地文件)")
return SentenceTransformer(path.split('/')[-1]) # 尝试从名称加载
except Exception as e:
print(f"加载模型 {path} 失败: {e}")
continue
raise RuntimeError("""
无法加载任何模型,请按以下步骤操作:
1. 手动下载模型文件:
- 访问 https://huggingface.co/sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2
- 点击"↓"按钮下载整个仓库
- 解压到项目目录下的 models/ 文件夹中
2. 或者运行以下命令自动下载(需要有网络连接):
python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')"
""")
def _load_knowledge_base(self):
"""加载知识库"""
print("正在加载知识库...")
for root, _, files in os.walk(self.knowledge_base_path):
for file in files:
if file.endswith(".txt"):
file_path = os.path.join(root, file)
try:
with open(file_path, 'r', encoding='utf-8') as f:
self.knowledge_base[file] = f.read()
except Exception as e:
print(f"加载文件 {file_path} 失败: {e}")
print(f"已加载 {len(self.knowledge_base)} 个知识文件")
def _vectorize_knowledge(self):
"""向量化知识库"""
print("正在向量化知识库...")
for key, content in self.knowledge_base.items():
self.vectorized_knowledge[key] = self.model.encode(content)
print("知识库向量化完成")
def _get_most_relevant_knowledge(self, query: str, top_k: int = 3) -> List[Tuple[str, float]]:
"""获取最相关知识"""
query_vector = self.model.encode(query)
similarities = []
for key, vector in self.vectorized_knowledge.items():
similarity = np.dot(query_vector, vector) / (np.linalg.norm(query_vector) * np.linalg.norm(vector))
similarities.append((key, similarity))
return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]
def _call_deepseek_api(self, context: str, query: str) -> str:
"""调用DeepSeek API"""
headers = {
"Authorization": f"Bearer {self.deepseek_api_key}",
"Content-Type": "application/json"
}
prompt = f"""基于以下上下文回答问题:
【上下文】
{context}
【问题】
{query}
请给出专业、准确的回答:"""
try:
response = requests.post(
"https://api.deepseek.com/v1/chat/completions",
json={
"model": "deepseek-chat",
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.7,
"max_tokens": 1000
},
headers=headers,
timeout=30
)
response.raise_for_status()
return response.json()["choices"][0]["message"]["content"]
except Exception as e:
print(f"API调用失败: {e}")
return "无法获取回答,请检查网络连接和API密钥。"
def query(self, question: str) -> str:
"""查询知识库"""
try:
start_time = time.time()
relevant = self._get_most_relevant_knowledge(question)
if not relevant:
return "未找到相关信息。"
context = "\n\n".join(f"【{k}】\n{self.knowledge_base[k]}" for k, _ in relevant)
answer = self._call_deepseek_api(context, question)
print(f"查询耗时: {time.time()-start_time:.2f}秒")
return answer
except Exception as e:
return f"查询失败: {str(e)}"
def main():
# 配置参数
KNOWLEDGE_BASE_PATH = r"D:\06_Python\20250328_Graph_knowledge\laws"
DEEPSEEK_API_KEY = "XXXX" # 替换为你的API密钥
try:
print("初始化知识库系统...")
kb = KnowledgeBaseSystem(KNOWLEDGE_BASE_PATH, DEEPSEEK_API_KEY)
print("系统已就绪,输入问题开始查询('退出'结束)")
while True:
try:
q = input("\n问题: ").strip()
if q.lower() in ['退出', 'exit', 'quit']:
break
if q:
print("\n回答:", kb.query(q))
except KeyboardInterrupt:
print("\n输入'退出'结束程序")
continue
except Exception as e:
print(f"系统初始化失败: {str(e)}")
finally:
print("系统已关闭")
if __name__ == "__main__":
main()