大语言模型(LLM)和嵌入模型的统一调用接口
ChatModelFactory、EmbeddingModelFactory
讲解代码:import os
from dotenv import load_dotenv, find_dotenv
_ = load_dotenv(find_dotenv())
from langchain_openai import ChatOpenAI, OpenAIEmbeddings, AzureChatOpenAI, AzureOpenAIEmbeddings
class ChatModelFactory:
model_params = {
"temperature": 0,
"seed": 42,
}
@classmethod
def get_model(cls, model_name: str, use_azure: bool = False):
if "gpt" in model_name:
if not use_azure:
return ChatOpenAI(model=model_name, **cls.model_params)
else:
return AzureChatOpenAI(
azure_deployment=model_name,
api_version="2024-05-01-preview",
**cls.model_params
)
elif model_name == "deepseek":
# 换成开源模型试试
# https://siliconflow.cn/
# 一个 Model-as-a-Service 平台
# 可以通过与 OpenAI API 兼容的方式调用各种开源语言模型。
return ChatOpenAI(
model="deepseek-ai/DeepSeek-V2-Chat", # 模型名称
openai_api_key=os.getenv("SILICONFLOW_API_KEY"), # 在平台注册账号后获取
openai_api_base="https://api.siliconflow.cn/v1", # 平台 API 地址
**cls.model_params,
)
@classmethod
def get_default_model(cls):
return cls.get_model("gpt-3.5-turbo")
class EmbeddingModelFactory:
@classmethod
def get_model(cls, model_name: str, use_azure: bool = False):
if model_name.startswith("text-embedding"):
if not use_azure:
return OpenAIEmbeddings(model=model_name)
else:
return AzureOpenAIEmbeddings(
azure_deployment=model_name,
openai_api_version="2024-05-01-preview",
)
else:
raise NotImplementedError(f"Model {model_name} not implemented.")
@classmethod
def get_default_model(cls):
return cls.get_model("text-embedding-ada-002")
以下是这段代码的逐层解析,该代码实现了大语言模型(LLM)和嵌入模型的统一调用接口,支持多种云服务和开源模型:
一、环境初始化
import os
from dotenv import load_dotenv, find_dotenv
# 加载.env文件中的环境变量
_ = load_dotenv(find_dotenv())
- 关键技术:
dotenv
库自动寻找项目根目录下的.env
文件(包括上级目录)- 将环境变量(如API Key)注入系统环境变量
- 典型
.env
内容:OPENAI_API_KEY = "sk-xxx" AZURE_OPENAI_KEY = "azure-xxx" SILICONFLOW_API_KEY = "sf-xxx"
二、聊天模型工厂(ChatModelFactory)
1. 基础配置
class ChatModelFactory:
model_params = {
"temperature": 0, # 控制生成随机性(0=确定性最高)
"seed": 42, # 固定随机种子保证可复现性
}
- 核心参数:
temperature
: 影响生成多样性(0为最保守,1为最随机)seed
: 确保相同输入得到相同输出(对测试和调试至关重要)
2. 模型选择逻辑
@classmethod
def get_model(cls, model_name: str, use_azure: bool = False):
# OpenAI官方服务
if "gpt" in model_name:
if not use_azure:
return ChatOpenAI(model=model_name, **cls.model_params)
else:
return AzureChatOpenAI(
azure_deployment=model_name, # Azure部署名称
api_version="2024-05-01-preview", # 最新API版本
**cls.model_params
)
# 第三方开源模型平台
elif model_name == "deepseek":
return ChatOpenAI(
model="deepseek-ai/DeepSeek-V2-Chat",
openai_api_key=os.getenv("SILICONFLOW_API_KEY"),
openai_api_base="https://api.siliconflow.cn/v1",
**cls.model_params,
)
-
多服务支持:
服务类型 适用场景 关键参数 OpenAI官方 直接使用OpenAI的API model="gpt-3.5-turbo"
Azure OpenAI 企业级Azure云服务 azure_deployment
SiliconFlow平台 调用DeepSeek等国产开源模型 自定义API Base URL -
设计亮点:
- 统一接口:不同服务使用相同调用方式(均继承自
BaseChatModel
) - 无缝切换:通过
use_azure
布尔参数切换云服务商 - 开箱即用:预置最新API版本(
2024-05-01-preview
)
- 统一接口:不同服务使用相同调用方式(均继承自
3. 默认模型配置
@classmethod
def get_default_model(cls):
return cls.get_model("gpt-3.5-turbo")
- 最佳实践:将
gpt-3.5-turbo
设为默认模型(性价比与性能平衡)
三、嵌入模型工厂(EmbeddingModelFactory)
1. 模型分发逻辑
class EmbeddingModelFactory:
@classmethod
def get_model(cls, model_name: str, use_azure: bool = False):
if model_name.startswith("text-embedding"):
if not use_azure:
return OpenAIEmbeddings(model=model_name)
else:
return AzureOpenAIEmbeddings(
azure_deployment=model_name,
openai_api_version="2024-05-01-preview",
)
else:
raise NotImplementedError(f"Model {model_name} not implemented.")
- 模型识别:通过
text-embedding
前缀判断为嵌入模型 - 服务兼容性:
- OpenAI:直接指定模型名称(如
text-embedding-ada-002
) - Azure:使用部署名称(需与Azure门户中的部署名一致)
- OpenAI:直接指定模型名称(如
2. 默认配置
@classmethod
def get_default_model(cls):
return cls.get_model("text-embedding-ada-002")
- 行业标准:
text-embedding-ada-002
是OpenAI最常用的嵌入模型
四、关键技术点解析
1. 环境变量管理
- 安全实践:通过
.env
文件隔离敏感信息 - 多平台支持:
即使使用非OpenAI服务,也复用# SiliconFlow示例 openai_api_key=os.getenv("SILICONFLOW_API_KEY") openai_api_base="https://api.siliconflow.cn/v1"
openai_
前缀的参数名,保持接口统一
2. 工厂模式优势
优势 | 具体体现 |
---|---|
解耦合 | 业务代码无需关心模型初始化细节 |
扩展性 | 新增模型只需修改工厂类 |
配置集中化 | 所有模型参数在单个类中管理 |
3. 多云服务兼容
# Azure的特殊参数
AzureChatOpenAI(
azure_deployment="gpt-35-turbo", # 部署名称(可能与模型名不同)
api_version="2024-05-01-preview" # 必须指定
)
- 注意差异:Azure的部署名称可能与OpenAI模型名称不同(如
gpt-35-turbo
对应OpenAI的gpt-3.5-turbo
)
五、使用示例
1. 基础调用
# 获取默认聊天模型
chat_model = ChatModelFactory.get_default_model()
# 获取Azure版嵌入模型
embed_model = EmbeddingModelFactory.get_model(
"text-embedding-ada-002",
use_azure=True
)
2. 完整工作流
from langchain.schema import HumanMessage
# 初始化模型
model = ChatModelFactory.get_model("gpt-4", use_azure=False)
# 执行对话
messages = [HumanMessage(content="你好!")]
response = model.invoke(messages)
print(response.content)
六、扩展建议
1. 添加新模型支持
# 在ChatModelFactory中添加
elif model_name == "moonshot":
return ChatOpenAI(
model="moonshot-v1",
openai_api_key=os.getenv("MOONSHOT_API_KEY"),
openai_api_base="https://api.moonshot.cn/v1",
**cls.model_params
)
2. 动态参数配置
# 允许运行时修改参数
@classmethod
def get_model(cls, model_name: str, **kwargs):
params = {**cls.model_params, **kwargs}
return ChatOpenAI(model=model_name, **params)
3. 错误处理增强
try:
return AzureChatOpenAI(...)
except AuthenticationError as e:
print("Azure认证失败,请检查API Key和部署名称")
七、架构图示
该代码为构建多模型AI应用提供了标准化基础设施,开发者可以通过简单的参数切换实现:
- 不同云服务商之间的迁移
- 开源模型与商业模型的混合使用
- 嵌入模型与对话模型的协同工作
@Classmethod
@classmethod
是 Python 中一个重要的装饰器,用于定义类方法。它的核心作用是为类提供一种不依赖实例却能操作类本身或类层级数据的途径。以下是逐层解析:
一、核心特性
1. 方法签名
class MyClass:
@classmethod
def my_method(cls, arg1, arg2): # 第一个参数必须是类本身(惯例命名为`cls`)
pass
cls
参数:自动接收当前类(而非实例)作为第一个参数- 调用方式:可通过
类名.方法名()
或实例.方法名()
调用
2. 与普通实例方法的对比
方法类型 | 第一个参数 | 可访问内容 | 典型用途 |
---|---|---|---|
实例方法 | self | 实例属性/方法 | 操作实例级数据 |
类方法 | cls | 类属性/其他类方法 | 工厂模式、类配置操作 |
二、核心应用场景
1. *工厂模式(核心代码中的用法)
class ChatModelFactory:
@classmethod
def get_model(cls, model_name: str): # 不创建实例即可获取模型
if model_name == "gpt":
return cls._create_openai_model()
else:
return cls._create_custom_model()
# 直接通过类调用
model = ChatModelFactory.get_model("gpt")
- 优势:无需实例化工厂类即可创建目标对象
- 扩展性:子类可重写方法改变创建逻辑
2. 替代构造函数
class Date:
def __init__(self, year, month, day):
self.year = year
@classmethod
def from_string(cls, date_str): # 接受不同格式的输入
year, month, day = map(int, date_str.split("-"))
return cls(year, month, day) # 调用主构造函数
# 使用示例
date = Date.from_string("2023-08-01")
3. 类状态管理
class Counter:
_count = 0
@classmethod
def increment(cls):
cls._count += 1 # 修改类属性
@classmethod
def get_count(cls):
return cls._count # 读取类属性
Counter.increment()
print(Counter.get_count()) # 输出: 1
三、继承场景下的行为
class Animal:
@classmethod
def identify(cls):
print(f"I am {cls.__name__}")
class Dog(Animal):
pass
Animal.identify() # 输出: I am Animal
Dog.identify() # 输出: I am Dog 🐛 cls自动指向实际调用类
- 多态性:子类调用父类方法时,
cls
指向子类 - 应用价值:实现继承链中的工厂方法
四、与 @staticmethod
的对比
特性 | @classmethod | @staticmethod |
---|---|---|
第一个参数 | 自动传入类 (cls ) | 无自动参数 |
访问类属性 | ✅ 可直接通过cls 访问 | ❌ 需通过类名硬编码访问 |
调用其他类方法 | ✅ 可通过cls.method() | ❌ 需硬编码类名 |
典型应用场景 | 需要类引用的操作 | 工具函数 |
代码示例对比
class Utilities:
@classmethod
def class_util(cls):
print(f"Called from {cls.__name__}")
@staticmethod
def static_util():
print("Generic utility")
class SubUtilities(Utilities):
pass
SubUtilities.class_util() # 输出: Called from SubUtilities
SubUtilities.static_util() # 输出: Generic utility
五、在示例代码中的作用
class ChatModelFactory:
model_params = {"temperature": 0} # 类属性
@classmethod
def get_model(cls, model_name: str):
return ChatOpenAI(**cls.model_params) # 通过cls访问类属性
- 统一配置管理:通过
cls.model_params
确保所有方法使用相同配置 - 无需实例化:直接
ChatModelFactory.get_model()
调用 - 扩展性:子类可重写
model_params
改变默认参数
六、最佳实践原则
- 优先用于工厂方法:当需要创建类实例但需封装复杂逻辑时
- 避免滥用:不需要访问类状态时使用
@staticmethod
- 命名规范:始终将第一个参数命名为
cls
(非强制但强烈建议) - 继承友好设计:通过
cls
实现多态,而非硬编码类名
掌握 @classmethod
的使用,能大幅提升面向对象代码的灵活性和可维护性,尤其在需要类层级操作而非实例层级操作的场景下表现出色。