wow-rag学习|搞定模型
本文为Datawhale 开源项目wow-rag的学习笔记与分享,仅供参考
开源项目:链接
可以使用Llmam-index,Langchain来做RAG,这里使用Llama-index来做。在干活前,需要准备一个llm模型和一个embedding模型 想要借助Llama-index构建llm和embedding模型。
大体上有四种思路:
- 第一个思路:使用Llama-index为各个厂家构建的服务,比如Llama-index为智谱和零一万物构建了专门的包,我们可以直接安装使用。
- 第二个思路:如果Llama-index没有为某个厂家构建服务,我们可以借助Llama-index为openai构建的库。只要我们国内的模型是openai兼容的,我们可以稍微修改一下源码就可以直接使用
- 第三个思路:我们可以利用Llama-index提供的自定义类来定义模型。
- 第四个思路:我们可以在本地安装Ollama,在本地安装好模型,然后再Llama-index中使用Ollama的服务。
思路一
用Llama-index为智谱构建的专门的包,直接安装最新版本即可
import os
from dotenv import load_dotenv
# 加载环境变量
load_dotenv()
# 从环境变量中读取api_key
api_key = os.getenv('ZHIPU_API_KEY')
base_url = "https://open.bigmodel.cn/api/paas/v4/"
chat_model = "glm-4-flash"
emb_model = "embedding-2"
配置对话模型
from llama_index.llms.zhipuai import ZhipuAI
llm = ZhipuAI(
api_key = api_key,
model = chat_model,
)
测试对话模型
配置嵌入模型
# 配置嵌入模型
from llama_index.embeddings.zhipuai import ZhipuAIEmbedding
embedding = ZhipuAIEmbedding(
api_key = api_key,
model = emb_model,
)
# 测试嵌入模型
emb = embedding.get_text_embedding("你好呀呀")
len(emb), type(emb)
#(1024, list)
思路二
我们借助llama_index的OpenAI接口使用其他厂家的模型,通过翻阅源码,发现llama_index把OpenAIEmbedding的模型名称写死在代码里面了,他会检查每个模型的输入上下文大小,如果模型没有在他的列表中,就会报错,所以我们可以重写一下llama_index中的OpenAI类,通过构建一个NewOpenAI类,并继承OpenAI类,我们直接把输入上下文写死,不让他检查。github源码地址
重写OpenAI类:
from llama_index.llms.openai import OpenAI
from llama_index.core.base.llms.types import LLMMetadata, MessageRole
# 复用原生OpenAI类的所有功能(如API调用、参数处理),仅修改元数据校验部分。
class NewOpenAI(OpenAI):# 继承原生OpenAI类
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) # 保留父类初始化逻辑
# 重写元数据属性
@property
def metadata(self) -> LLMMetadata:
# 创建一个新的LLMMetadata实例,只修改context_window
return LLMMetadata(
context_window=8192, # 强制指定上下文长度为8192
num_output=self.max_tokens or -1, # 继承父类的输出token数
is_chat_model=True, # 标记为聊天模型
is_function_calling_model=True, #支持函数调用
model_name=self.model,
system_role=MessageRole.USER, # 系统角色设为USER
)
- 绕过校验:原生代码会根据模型名称查找预设的context_window,此处直接写死该值,避免因模型名不在列表而报错。
- 参数选择:8192是常见大模型的上下文长度(如GPT-4),需根据实际模型调整。
重写完后,我们用NewOpenAI这个类来配置llm。
llm = NewOpenAI(
temperature=0.95,
api_key = api_key,
model = chat_model,
api_base = base_url # 注意这里单词不一样
)
response = llm.complete("你是谁?")
print(response)
思路三
自定义可以利用openai-like的包,来封装任何openai类似的大模型 这个思路的缺点很明显,只有对话模型,没有嵌入模型。
对话模型可以直接使用
from llama_index.llms.openai_like import OpenAILike
llm = OpenAILike(
model = chat_model,
api_base = base_url,
api_key = api_key,
is_chat_model=True
)
response=llm.complete("你是谁")
print(response)
自定义对话模型
# 导入必要的库和模块
from openai import OpenAI
from pydantic import Field # 导入Field,用于Pydantic模型中定义字段的元数据
from typing import Optional, List, Mapping, Any, Generator
import os
from llama_index.core import SimpleDirectoryReader, SummaryIndex
from llama_index.core.callbacks import CallbackManager
from llama_index.core.llms import (
CustomLLM,
CompletionResponse,
CompletionResponseGen,
LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback
from llama_index.core import Settings
# 定义OurLLM类,继承自CustomLLM基类
class OurLLM(CustomLLM):
api_key: str = Field(default=api_key)
base_url: str = Field(default=base_url)
model_name: str = Field(default=chat_model)
client: OpenAI = Field(default=None, exclude=True) # 显式声明 client 字段
def __init__(self, api_key: str, base_url: str, model_name: str = chat_model, **data: Any):
super().__init__(**data)
self.api_key = api_key
self.base_url = base_url
self.model_name = model_name
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url) # 使用传入的api_key和base_url初始化 client 实例
@property
def metadata(self) -> LLMMetadata:
"""Get LLM metadata."""
return LLMMetadata(
model_name=self.model_name,
)
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
response = self.client.chat.completions.create(model=self.model_name, messages=[{"role": "user", "content": prompt}])
if hasattr(response, 'choices') and len(response.choices) > 0:
response_text = response.choices[0].message.content
return CompletionResponse(text=response_text)
else:
raise Exception(f"Unexpected response format: {response}")
@llm_completion_callback()
def stream_complete(
self, prompt: str, **kwargs: Any
) -> Generator[CompletionResponse, None, None]:
response = self.client.chat.completions.create(
model=self.model_name,
messages=[{"role": "user", "content": prompt}],
stream=True
)
try:
for chunk in response:
chunk_message = chunk.choices[0].delta
if not chunk_message.content:
continue
content = chunk_message.content
yield CompletionResponse(text=content, delta=content)
except Exception as e:
raise Exception(f"Unexpected response format: {e}")
# 测试对话模型
llm = OurLLM(api_key=api_key, base_url=base_url, model_name=chat_model)
response = llm.complete("你是谁?")
print(response)
自定义嵌入模型
from openai import OpenAI
from typing import Any, List
from llama_index.core.embeddings import BaseEmbedding
from pydantic import Field
class OurEmbeddings(BaseEmbedding):
api_key: str = Field(default=api_key)
base_url: str = Field(default=base_url)
model_name: str = Field(default=emb_model)
client: OpenAI = Field(default=None, exclude=True) # 显式声明 client 字段
def __init__(
self,
api_key: str = api_key,
base_url: str = base_url,
model_name: str = emb_model,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.api_key = api_key
self.base_url = base_url
self.model_name = model_name
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
def invoke_embedding(self, query: str) -> List[float]:
response = self.client.embeddings.create(model=self.model_name, input=[query])
# 检查响应是否成功
if response.data and len(response.data) > 0:
return response.data[0].embedding
else:
raise ValueError("Failed to get embedding from ZhipuAI API")
def _get_query_embedding(self, query: str) -> List[float]:
return self.invoke_embedding(query)
def _get_text_embedding(self, text: str) -> List[float]:
return self.invoke_embedding(text)
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
return [self._get_text_embedding(text) for text in texts]
async def _aget_query_embedding(self, query: str) -> List[float]:
return self._get_query_embedding(query)
async def _aget_text_embedding(self, text: str) -> List[float]:
return self._get_text_embedding(text)
async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
return self._get_text_embeddings(texts)
embedding = OurEmbeddings(api_key=api_key, base_url=base_url, model_name=emb_model)
emb = embedding.get_text_embedding("你好呀")
len(emb), type(emb)
思路四
在本地安装Ollama