当前位置: 首页 > article >正文

LlamaIndex中使用本地LLM和Embedding

LlamaIndex默认会调用OpenAI的text-davinci-002模型对应的API,用于获得大模型输出,这种方式在很多情况下对国内用户不太方便,如果本地有大模型可以部署,可以按照以下方式在LlamaIndex中使用本地的LLM和Embedding(这里LLM使用chatglm2-6b,Embedding使用m3e-base):

import torch
from transformers import AutoModel, AutoTokenizer
from llama_index.llms import HuggingFaceLLM
from llama_index import VectorStoreIndex, ServiceContext
from llama_index import LangchainEmbedding, ServiceContext
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from llama_index import VectorStoreIndex, SimpleDirectoryReader
from llama_index import Prompt, PromptHelper
from llama_index.node_parser import SimpleNodeParser
from llama_index.langchain_helpers.text_splitter import TextSplitter, TokenTextSplitter
from llama_index import set_global_service_context

# 需要使用GPU才能运行
device = 'cuda'

# 自定义输入大模型的prompt
TEMPLATE_STR = """我们在下面提供了上下文信息

{context_str}
根据此信息,请回答问题:{query_str}
"""
QA_TEMPLATE = Prompt(TEMPLATE_STR)

# 加载本地LLM,需提供本地LLM模型文件的路径
llm_tokenizer = AutoTokenizer.from_pretrained('/models/chatglm2-6b/', trust_remote_code=True, device=device)
llm_model = AutoModel.from_pretrained('/models/chatglm2-6b/', trust_remote_code=True, device=device)
chatglm2 = HuggingFaceLLM(model=llm_model, tokenizer=llm_tokenizer)

# 加载本地Embedding,需提供本地Embedding模型文件的路径
embed_tokenizer = AutoTokenizer.from_pretrained('/models/moka-ai/m3e-base/', trust_remote_code=True, device=device)
embed_model = LangchainEmbedding(
                HuggingFaceEmbeddings(model_name='/models/moka-ai/m3e-base/'), 
                tokenizer=embed_tokenizer)

node_parser = SimpleNodeParser(text_splitter=TokenTextSplitter(tokenizer=embed_tokenizer))
prompt_helper = PromptHelper(tokenizer=llm_tokenizer)
service_context = ServiceContext.from_defaults(
                llm=chatglm2, 
                prompt_helper=prompt_helper, 
                embed_model=embed_model, 
                node_parser=node_parser)
set_global_service_context(service_context)

documents = SimpleDirectoryReader('/path/to/your/files').load_data()
index = VectorStoreIndex.from_documents(documents, service_context=service_context)

 # 查询引擎
query_engine = index.as_query_engine(text_qa_template=QA_TEMPLATE)
# 聊天引擎
chat_eigine = index.as_chat_engine()

response = query_engine.query("your question")
print(response)

http://www.kler.cn/a/557725.html

相关文章:

  • 图表控件Aspose.Diagram入门教程:使用 Python 将 VSDX 转换为 PDF
  • QEMU源码全解析 —— 内存虚拟化(17)
  • LeetCode 热题 100 283. 移动零
  • 【JT/T 808协议】808 协议开发笔记 ② ( 终端注册 | 终端注册应答 | 字符编码转换网站 )
  • 软件集成测试的技术要求
  • AF3 _parse_template_hit_files类解读
  • python使用httpx_sse调用sse流式接口对响应格式为application/json的错误信息的处理
  • 零基础学QT、C++(六)制作桌面摄像头软件
  • 计算机考研复试上机07
  • EVM系区块链开发网节点搭建及测试详细文档
  • unordered_map和 unordered_set
  • 20250221 NLP
  • 基于C++ Qt的图形绘制与XML序列化系统
  • HW面试经验分享 | 北京蓝中研判岗
  • 【Java】File 类
  • 水果生鲜农产品推荐系统 协同过滤余弦函数推荐水果生鲜农产品 Springboot Vue Element-UI前后端分离 代码+开发文档+视频教程
  • WPF实现打印机控制及打印
  • ACWing蓝桥杯集训·每日一题2025-6122. 农夫约翰的奶酪块-Java
  • malloc如何分配内存
  • 区块链相关方法-SWOT分析