# RAG end-to-end example: choosing a suitable vector database.
# Integrates the Baichuan text-embedding model with the LanceDB vector store
# and the Zhipu (GLM) chat model. We use LanceDB here.
import os
import lancedb
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import BaichuanTextEmbeddings
from langchain_community.vectorstores import LanceDB
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Baichuan API key, consumed by BaichuanTextEmbeddings below.
# SECURITY NOTE(review): a real key must never be committed to source control —
# rotate this one and supply it via the environment / a secrets manager.
# setdefault() keeps an already-exported BAICHUAN_API_KEY instead of clobbering it.
os.environ.setdefault('BAICHUAN_API_KEY', 'sk-fb5ef65021a207c17bf9f772839fbd16')
# Load the source text as UTF-8; TextLoader yields one Document for the file.
loader = TextLoader('state_of_the_union.txt', encoding='utf8')
documents = loader.load()
# Separator priority list: paragraph, then line, then sentence enders
# (ASCII and CJK), then clause commas, finally plain spaces — so chunks
# break on the most natural boundary available within chunk_size chars.
_SEPARATORS = [
    "\n\n",
    "\n",
    ".", "?", "!",    # ASCII sentence enders
    "。", "!", "?",  # CJK sentence enders
    ",", ",",        # clause separators (ASCII + CJK comma)
    " ",
]
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=100,
    chunk_overlap=0,
    length_function=len,
    is_separator_regex=False,
    separators=_SEPARATORS,
)
docs = text_splitter.split_documents(documents)
print('=======', len(docs))
# Embedding model: Baichuan's text-embedding API
# (reads BAICHUAN_API_KEY from the environment).
embeddings = BaichuanTextEmbeddings()
# Connect to LanceDB, persisting vectors under ./lanceDB in the
# current working directory (the folder is created on demand).
db_path = os.path.join(os.getcwd(), 'lanceDB')
connect = lancedb.connect(db_path)
# Embed every chunk and store it in the 'my_vectors' table.
vectorStore = LanceDB.from_documents(docs, embeddings, connection=connect, table_name='my_vectors')
# The question we will ask the RAG chain (Chinese: "How many days did this
# year's Yangtze-Delta railway spring travel transport last in total?").
query = '今年长三角铁路春游运输共经历多少天?'
# Optional sanity check of the vector store — uncomment to inspect raw hits:
# docs_and_score = vectorStore.similarity_search_with_score(query)
# for doc, score in docs_and_score:
#     print('-------------------------')
#     print('Score: ', score)
#     print("Content: ", doc.page_content)
# Expose the store as a retriever so it can be composed into an LCEL chain.
retriever = vectorStore.as_retriever()
# Prompt: restrict the model to answering strictly from the retrieved context.
template = """Answer the question based only on the following context:
{context}
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)
# Chat model: Zhipu GLM, reached through its OpenAI-compatible endpoint.
# SECURITY NOTE(review): the fallback key below is committed to source —
# rotate it and supply ZHIPUAI_API_KEY via the environment instead.
model = ChatOpenAI(
    model='glm-4-0520',
    api_key=os.environ.get('ZHIPUAI_API_KEY', '112534ffd84245e6a1ffe7df3a790289.9rVJE1mN7zswWDjU'),
    base_url='https://open.bigmodel.cn/api/paas/v4/'
)
output_parser = StrOutputParser()
# Fan the user question out in parallel: the retriever fills {context},
# while RunnablePassthrough forwards the raw question into {question}.
start_retriever = RunnableParallel({'context': retriever, 'question': RunnablePassthrough()})
# Full RAG pipeline: retrieve -> format prompt -> LLM -> plain string.
chain = start_retriever | prompt | model | output_parser
res = chain.invoke(query)
print(res)