毕昇入门学习
schemas.py
概述
这段代码主要定义了一系列基于 Pydantic 的数据模型(BaseModel
),用于数据验证和序列化,通常用于构建 API(如使用 FastAPI)。这些模型涵盖了用户认证、聊天消息、知识库管理、模型配置等多个方面。此外,还定义了一些通用的响应模型和辅助函数,以标准化 API 的响应格式。
详细解析
导入部分
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
from uuid import UUID
from langchain.docstore.document import Document
from orjson import orjson
from pydantic import BaseModel, Field, validator, root_validator
from bisheng.database.models.assistant import AssistantBase
from bisheng.database.models.finetune import TrainMethod
from bisheng.database.models.flow import FlowCreate, FlowRead
from bisheng.database.models.gpts_tools import GptsToolsRead, AuthMethod, AuthType
from bisheng.database.models.knowledge import KnowledgeRead
from bisheng.database.models.llm_server import LLMServerBase, LLMModelBase, LLMServerType, LLMModelType
from bisheng.database.models.message import ChatMessageRead
from bisheng.database.models.tag import Tag
- 标准库导入:
datetime
:处理日期和时间。Enum
:定义枚举类型。typing
:用于类型注解,增强代码可读性和可维护性。UUID
:处理唯一标识符。
- 第三方库导入:
langchain.docstore.document.Document
:可能用于文档存储和检索。orjson
:高性能的 JSON 库,用于快速序列化和反序列化 JSON 数据。pydantic
:用于数据验证和设置管理,广泛用于 FastAPI。
- 项目内部模块导入:
- 从
bisheng.database.models.*
导入多个 ORM 模型,代表数据库中的不同实体(如助手、微调方法、流程、工具、知识库、LLM 服务器和模型、聊天消息、标签等)。
- 从
数据模型
以下是代码中定义的主要数据模型和相关函数的详细解释。
基础输入模型
-
验证码输入
class CaptchaInput(BaseModel): captcha_key: str captcha: str
- 用途:用于用户提交验证码时的输入验证。
- 字段:
captcha_key
:验证码的唯一标识符。captcha
:用户输入的验证码内容。
-
分块输入
class ChunkInput(BaseModel): knowledge_id: int documents: List[Document]
- 用途:用于将文档分块处理的输入。
- 字段:
knowledge_id
:知识库的唯一标识符。documents
:文档列表,每个文档为langchain.docstore.document.Document
类型。
枚举类型
-
构建状态
class BuildStatus(Enum): """Status of the build.""" SUCCESS = 'success' FAILURE = 'failure' STARTED = 'started' IN_PROGRESS = 'in_progress'
- 用途:表示构建过程的状态。
- 枚举值:
SUCCESS
:构建成功。FAILURE
:构建失败。STARTED
:构建已启动。IN_PROGRESS
:构建进行中。
图数据模型
-
图数据
class GraphData(BaseModel): """Data inside the exported flow.""" nodes: List[Dict[str, Any]] edges: List[Dict[str, Any]]
- 用途:表示导出流程中的图数据。
- 字段:
nodes
:节点列表,每个节点为字典类型,包含节点的详细信息。edges
:边列表,每条边为字典类型,描述节点之间的连接关系。
-
导出流程
class ExportedFlow(BaseModel): """Exported flow from bisheng.""" description: str name: str id: str data: GraphData
- 用途:表示从
bisheng
导出的流程数据。 - 字段:
description
:流程描述。name
:流程名称。id
:流程的唯一标识符。data
:流程的图数据,类型为GraphData
。
- 用途:表示从
请求模型
-
输入请求
class InputRequest(BaseModel): input: str = Field(description='question or command asked LLM to do')
- 用途:表示用户向 LLM(大型语言模型)提出的问题或命令。
- 字段:
input
:用户的输入内容(问题或命令)。
-
调整请求
class TweaksRequest(BaseModel): tweaks: Optional[Dict[str, Dict[str, str]]] = Field(default_factory=dict)
- 用途:表示对某些设置或参数的调整请求。
- 字段:
tweaks
:一个可选的字典,嵌套字典用于描述具体的调整内容。
-
更新模板请求
class UpdateTemplateRequest(BaseModel): template: dict
- 用途:用于更新模板的请求。
- 字段:
template
:模板内容,类型为字典。
通用响应模型
-
统一响应模型
DataT = TypeVar('DataT') class UnifiedResponseModel(Generic[DataT], BaseModel): """统一响应模型""" status_code: int status_message: str data: DataT = None
- 用途:提供一个通用的 API 响应结构,适用于各种类型的数据。
- 泛型:
DataT
:数据的具体类型,可以是任意类型。
- 字段:
status_code
:响应状态码(如 200、500 等)。status_message
:响应状态信息(如 “SUCCESS”、“BAD REQUEST” 等)。data
:响应的数据,类型为泛型DataT
。
-
成功响应函数
def resp_200(data: Union[list, dict, str, Any] = None, message: str = 'SUCCESS') -> UnifiedResponseModel: """成功的代码""" return UnifiedResponseModel(status_code=200, status_message=message, data=data)
- 用途:生成一个状态码为 200 的成功响应。
- 参数:
data
:响应的数据,可以是列表、字典、字符串或任意类型。message
:响应的状态信息,默认值为 “SUCCESS”。
-
错误响应函数
def resp_500(code: int = 500, data: Union[list, dict, str, Any] = None, message: str = 'BAD REQUEST') -> UnifiedResponseModel: """错误的逻辑回复""" return UnifiedResponseModel(status_code=code, status_message=message, data=data)
- 用途:生成一个错误响应,默认状态码为 500。
- 参数:
code
:错误状态码,默认值为 500。data
:响应的数据,可以是列表、字典、字符串或任意类型。message
:响应的状态信息,默认值为 “BAD REQUEST”。
处理响应模型
-
处理响应
class ProcessResponse(BaseModel): """Process response schema.""" result: Any # task: Optional[TaskResponse] = None session_id: Optional[str] = None backend: Optional[str] = None
- 用途:表示处理请求后的响应结构。
- 字段:
result
:处理结果,类型为任意类型。session_id
:会话标识符,可选。backend
:后端标识,可选。
聊天相关模型
-
聊天输入
class ChatInput(BaseModel): message_id: int comment: str = None liked: int = 0
- 用途:表示用户对某条聊天消息的输入(如评论、点赞)。
- 字段:
message_id
:消息的唯一标识符。comment
:用户的评论,可选。liked
:点赞数,默认值为 0。
-
添加聊天消息
class AddChatMessages(BaseModel): """Add a pair of chat messages.""" flow_id: UUID # 技能或助手ID chat_id: str # 会话ID human_message: str = None # 用户问题 answer_message: str = None # 执行结果
-
用途:用于向聊天会话中添加一对消息(用户消息和助手回复)。
-
字段
:
flow_id
:技能或助手的唯一标识符,类型为UUID
。chat_id
:会话的唯一标识符。human_message
:用户的问题,可选。answer_message
:助手的回复,可选。
-
-
聊天列表
class ChatList(BaseModel): """Chat message list.""" flow_name: str = None flow_description: str = None flow_id: UUID = None chat_id: str = None create_time: datetime = None update_time: datetime = None flow_type: str = None # flow: 技能 assistant:gpts助手 latest_message: ChatMessageRead = None logo: Optional[str] = None
- 用途:表示一个聊天会话的简要信息,用于展示聊天列表。
- 字段:
flow_name
:技能或助手的名称。flow_description
:技能或助手的描述。flow_id
:技能或助手的唯一标识符,类型为UUID
。chat_id
:会话的唯一标识符。create_time
:会话的创建时间。update_time
:会话的更新时间。flow_type
:技能类型,flow
表示技能,assistant
表示 GPTs 助手。latest_message
:最新的聊天消息,类型为ChatMessageRead
。logo
:技能或助手的 logo 地址,可选。
-
在线流程列表
class FlowGptsOnlineList(BaseModel): id: str = Field('唯一ID') name: str = None desc: str = None logo: str = None create_time: datetime = None update_time: datetime = None flow_type: str = None # flow: 技能 assistant:gpts助手 count: int = 0
- 用途:表示在线流程(技能或助手)的列表信息。
- 字段:
id
:流程的唯一标识符。name
:流程名称。desc
:流程描述。logo
:流程的 logo 地址。create_time
:流程的创建时间。update_time
:流程的更新时间。flow_type
:流程类型,flow
表示技能,assistant
表示 GPTs 助手。count
:关联的某些数量(具体含义需结合业务逻辑理解)。
-
聊天消息
class ChatMessage(BaseModel): """Chat message schema.""" is_bot: bool = False message: Union[str, None, dict] = '' type: str = 'human' category: str = 'processing' # system processing answer tool intermediate_steps: str = None files: list = [] user_id: int = None message_id: int = None source: int = 0 sender: str = None receiver: dict = None liked: int = 0 extra: str = '{}' flow_id: str = None chat_id: str = None
- 用途:表示一条聊天消息的详细信息。
- 字段:
is_bot
:是否为机器人消息,默认值为False
。message
:消息内容,可以是字符串、字典或None
。type
:消息类型,默认值为'human'
(表示用户消息)。category
:消息分类,如'system'
、'processing'
、'answer'
、'tool'
。intermediate_steps
:中间步骤信息,可选。files
:关联的文件列表,默认空列表。user_id
:用户的唯一标识符,可选。message_id
:消息的唯一标识符,可选。source
:消息来源,默认值为0
。sender
:发送者信息,可选。receiver
:接收者信息,字典类型,可选。liked
:点赞数,默认值为0
。extra
:额外信息,默认值为'{}'
。flow_id
:关联的流程(技能或助手)的唯一标识符,可选。chat_id
:关联的聊天会话的唯一标识符,可选。
-
聊天响应
class ChatResponse(ChatMessage): """Chat response schema.""" intermediate_steps: str = '' type: str is_bot: bool = True files: list = [] category: str = 'processing' @validator('type') def validate_message_type(cls, v): """ end_cover: 结束并覆盖上一条message """ if v not in ['start', 'stream', 'end', 'error', 'info', 'file', 'begin', 'close', 'end_cover']: raise ValueError('type must be start, stream, end, error, info, or file') return v
- 用途:表示机器人(助手)的聊天响应消息。
- 继承自:
ChatMessage
,并对部分字段进行了重写或扩展。 - 字段:
intermediate_steps
:中间步骤信息,默认值为空字符串。type
:消息类型,必须在指定的枚举值中。is_bot
:是否为机器人消息,固定为True
。files
:关联的文件列表,默认空列表。category
:消息分类,固定为'processing'
。
- 验证器:
validate_message_type
:确保type
字段的值在指定的枚举值范围内,否则抛出错误。
-
文件响应
class FileResponse(ChatMessage): """File response schema.""" data: Any data_type: str type: str = 'file' is_bot: bool = True @validator('data_type') def validate_data_type(cls, v): if v not in ['image', 'csv']: raise ValueError('data_type must be image or csv') return v
- 用途:表示机器人发送的文件类型消息。
- 继承自:
ChatMessage
,并对部分字段进行了重写或扩展。 - 字段:
data
:文件数据,类型为任意类型。data_type
:文件类型,必须是'image'
或'csv'
。type
:消息类型,固定为'file'
。is_bot
:是否为机器人消息,固定为True
。
- 验证器:
validate_data_type
:确保data_type
字段的值为'image'
或'csv'
,否则抛出错误。
流程相关模型
-
流程列表创建
class FlowListCreate(BaseModel): flows: List[FlowCreate]
- 用途:用于批量创建流程(技能或助手)。
- 字段:
flows
:流程创建请求的列表,类型为List[FlowCreate]
。
-
流程列表读取
class FlowListRead(BaseModel): flows: List[FlowRead]
- 用途:用于批量读取流程的信息。
- 字段:
flows
:流程读取响应的列表,类型为List[FlowRead]
。
-
初始化响应
class InitResponse(BaseModel): flowId: str
- 用途:表示初始化操作后的响应,通常返回一个流程的唯一标识符。
- 字段:
flowId
:流程的唯一标识符。
-
构建响应
class BuiltResponse(BaseModel): built: bool
- 用途:表示构建操作的结果。
- 字段:
built
:布尔值,表示是否构建成功。
-
上传文件响应
class UploadFileResponse(BaseModel): """Upload file response schema.""" flowId: Optional[str] file_path: str relative_path: Optional[str] # minio的相对路径,即object_name
- 用途:表示文件上传后的响应。
- 字段:
flowId
:关联的流程(技能或助手)的唯一标识符,可选。file_path
:文件的完整路径。relative_path
:文件在存储系统(如 MinIO)中的相对路径,即对象名称,可选。
-
流数据
class StreamData(BaseModel): event: str data: dict | str def __str__(self) -> str: if isinstance(self.data, dict): return f'event: {self.event}\ndata: {orjson.dumps(self.data).decode()}\n\n' return f'event: {self.event}\ndata: {self.data}\n\n'
- 用途:表示流式数据传输的结构,常用于服务器发送事件(SSE)。
- 字段:
event
:事件类型。data
:事件数据,可以是字典或字符串。
- 方法:
__str__
:将StreamData
对象转换为字符串格式,适用于流式传输。
微调相关模型
-
微调创建请求
class FinetuneCreateReq(BaseModel): server: int = Field(description='关联的RT服务ID') base_model: int = Field(description='基础模型ID') model_name: str = Field(max_length=50, description='模型名称') method: TrainMethod = Field(description='训练方法') extra_params: Dict = Field(default_factory=dict, description='训练任务所需额外参数') train_data: Optional[List[Dict]] = Field(default=None, description='个人训练数据') preset_data: Optional[List[Dict]] = Field(default=None, description='预设训练数据')
- 用途:表示创建微调任务的请求。
- 字段:
server
:关联的 RT 服务的唯一标识符。base_model
:基础模型的唯一标识符。model_name
:模型名称,最大长度为 50。method
:训练方法,类型为TrainMethod
(枚举类型)。extra_params
:训练任务所需的额外参数,默认为空字典。train_data
:个人训练数据,可选。preset_data
:预设训练数据,可选。
组件相关模型
-
创建组件请求
class CreateComponentReq(BaseModel): name: str = Field(max_length=50, description='组件名称') data: Any = Field(default='', description='组件数据') description: Optional[str] = Field(default='', description='组件描述')
- 用途:表示创建组件的请求。
- 字段:
name
:组件名称,最大长度为 50。data
:组件的数据,类型为任意类型,默认值为空字符串。description
:组件描述,可选,默认值为空字符串。
-
自定义组件代码
class CustomComponentCode(BaseModel): code: str field: Optional[str] = None frontend_node: Optional[dict] = None
- 用途:表示自定义组件的代码。
- 字段:
code
:组件的代码。field
:可选字段,可能用于描述组件的特定字段。frontend_node
:前端节点的信息,类型为字典,可选。
助手相关模型
-
创建助手请求
class AssistantCreateReq(BaseModel): name: str = Field(max_length=50, description='助手名称') prompt: str = Field(min_length=20, max_length=1000, description='助手提示词') logo: str = Field(description='logo文件的相对地址')
- 用途:表示创建助手的请求。
- 字段:
name
:助手名称,最大长度为 50。prompt
:助手的提示词,最小长度为 20,最大长度为 1000。logo
:logo 文件的相对地址。
-
更新助手请求
class AssistantUpdateReq(BaseModel): id: UUID = Field(description='助手ID') name: Optional[str] = Field('', description='助手名称, 为空则不更新') desc: Optional[str] = Field('', description='助手描述, 为空则不更新') logo: Optional[str] = Field('', description='logo文件的相对地址,为空则不更新') prompt: Optional[str] = Field('', description='用户可见prompt, 为空则不更新') guide_word: Optional[str] = Field('', description='开场白, 为空则不更新') guide_question: Optional[List] = Field([], description='引导问题列表, 为空则不更新') model_name: Optional[str] = Field('', description='选择的模型名, 为空则不更新') temperature: Optional[float] = Field(None, description='模型温度, 不传则不更新') tool_list: List[int] | None = Field(default=None, description='助手的工具ID列表,空列表则清空绑定的工具,为None则不更新') flow_list: List[str] | None = Field(default=None, description='助手的技能ID列表,为None则不更新') knowledge_list: List[int] | None = Field(default=None, description='知识库ID列表,为None则不更新')
- 用途:表示更新助手的请求。
- 字段:
id
:助手的唯一标识符,类型为UUID
。name
:助手名称,可选,默认为空字符串,空则不更新。desc
:助手描述,可选,默认为空字符串,空则不更新。logo
:logo 文件的相对地址,可选,默认为空字符串,空则不更新。prompt
:用户可见的提示词,可选,默认为空字符串,空则不更新。guide_word
:开场白,可选,默认为空字符串,空则不更新。guide_question
:引导问题列表,可选,默认为空列表,空则不更新。model_name
:选择的模型名称,可选,默认为空字符串,空则不更新。temperature
:模型温度,可选,默认为None
,不传则不更新。tool_list
:助手的工具 ID 列表,可选,None
表示不更新,空列表表示清空绑定的工具。flow_list
:助手的技能 ID 列表,可选,None
表示不更新。knowledge_list
:知识库 ID 列表,可选,None
表示不更新。
-
助手简单信息
class AssistantSimpleInfo(BaseModel): id: UUID name: str desc: str logo: str user_id: int user_name: str status: int write: Optional[bool] = Field(default=False) group_ids: Optional[List[int]] tags: Optional[List[Tag]] create_time: datetime update_time: datetime
- 用途:表示助手的基本信息,用于展示列表或概要信息。
- 字段:
id
:助手的唯一标识符,类型为UUID
。name
:助手名称。desc
:助手描述。logo
:助手的 logo 地址。user_id
:创建者的用户 ID。user_name
:创建者的用户名。status
:助手的状态,类型为整数(具体含义需结合业务逻辑理解)。write
:是否可写,默认值为False
。group_ids
:关联的用户组 ID 列表,可选。tags
:关联的标签列表,类型为List[Tag]
,可选。create_time
:创建时间。update_time
:更新时间。
-
助手信息
class AssistantInfo(AssistantBase): tool_list: List[GptsToolsRead] = Field(default=[], description='助手的工具ID列表') flow_list: List[FlowRead] = Field(default=[], description='助手的技能ID列表') knowledge_list: List[KnowledgeRead] = Field(default=[], description='知识库ID列表')
- 用途:表示助手的详细信息,继承自
AssistantBase
(数据库模型)。 - 字段:
tool_list
:助手的工具列表,类型为List[GptsToolsRead]
。flow_list
:助手的技能列表,类型为List[FlowRead]
。knowledge_list
:知识库列表,类型为List[KnowledgeRead]
。
- 用途:表示助手的详细信息,继承自
流程版本与对比模型
-
流程版本创建
class FlowVersionCreate(BaseModel): name: Optional[str] = Field(default=..., description="版本的名字") description: Optional[str] = Field(default=None, description="版本的描述") data: Optional[Dict] = Field(default=None, description='技能版本的节点数据数据') original_version_id: Optional[int] = Field(default=None, description="版本的来源版本ID")
- 用途:表示创建流程版本的请求。
- 字段:
name
:版本名称。description
:版本描述。data
:技能版本的节点数据,类型为字典,可选。original_version_id
:来源版本的 ID,可选。
-
流程对比请求
class FlowCompareReq(BaseModel): inputs: Any = Field(default=None, description='技能运行所需要的输入') question_list: List[str] = Field(default=[], description='测试case列表') version_list: List[int] = Field(default=[], description='对比版本ID列表') node_id: str = Field(default=None, description='需要对比的节点唯一ID') thread_num: Optional[int] = Field(default=1, description='对比线程数')
- 用途:表示对比不同版本流程的请求。
- 字段:
inputs
:技能运行所需的输入,类型为任意类型。question_list
:测试用例列表,类型为List[str]
。version_list
:对比的版本 ID 列表,类型为List[int]
。node_id
:需要对比的节点的唯一标识符,类型为字符串,可选。thread_num
:对比的线程数,默认值为1
。
工具类型与测试模型
-
删除工具类型请求
class DeleteToolTypeReq(BaseModel): tool_type_id: int = Field(description="要删除的工具类别ID")
- 用途:表示删除工具类别的请求。
- 字段:
tool_type_id
:要删除的工具类别的唯一标识符。
-
测试工具请求
class TestToolReq(BaseModel): server_host: str = Field(default='', description="服务的根地址") extra: str = Field(default='', description="Api 对象解析后的extra字段") auth_method: int = Field(default=AuthMethod.NO.value, description="认证类型") auth_type: Optional[str] = Field(default=AuthType.BASIC.value, description="Auth Type") api_key: Optional[str] = Field(default='', description="api key") request_params: Dict = Field(default=None, description="用户填写的请求参数")
- 用途:表示测试工具接口的请求。
- 字段:
server_host
:服务的根地址,默认值为空字符串。extra
:API 对象解析后的额外字段,默认值为空字符串。auth_method
:认证方法,默认值为AuthMethod.NO.value
。auth_type
:认证类型,默认值为AuthType.BASIC.value
,可选。api_key
:API 密钥,默认值为空字符串,可选。request_params
:用户填写的请求参数,类型为字典,可选。
用户与角色模型
-
用户组与角色
class GroupAndRoles(BaseModel): group_id: int role_ids: List[int]
- 用途:表示用户所属的组和角色。
- 字段:
group_id
:用户组的唯一标识符。role_ids
:角色的唯一标识符列表。
-
创建用户请求
class CreateUserReq(BaseModel): user_name: str = Field(max_length=30, description='用户名') password: str = Field(description='密码') group_roles: List[GroupAndRoles] = Field(description='要加入的用户组和角色列表')
- 用途:表示创建新用户的请求。
- 字段:
user_name
:用户名,最大长度为 30。password
:用户密码。group_roles
:用户所属的用户组和角色列表,类型为List[GroupAndRoles]
。
OpenAI 聊天模型相关
-
OpenAI 聊天完成请求
class OpenAIChatCompletionReq(BaseModel): messages: List[dict] = Field(..., description="聊天消息列表,只支持user、assistant。system用数据库内的数据") model: str = Field(..., description="助手的唯一ID") n: int = Field(default=1, description="返回的答案个数, 助手侧默认为1,暂不支持多个回答") stream: bool = Field(default=False, description="是否开启流式回复") temperature: float = Field(default=0.0, description="模型温度, 传入0或者不传表示不覆盖") tools: List[dict] = Field(default=[], description="工具列表, 助手暂不支持,使用助手的配置")
- 用途:表示向 OpenAI 的聊天完成 API 发送请求的结构。
- 字段:
messages
:聊天消息列表,类型为List[dict]
,只支持user
和assistant
,system
消息使用数据库内的数据。model
:助手的唯一标识符。n
:返回的答案个数,默认值为1
,当前不支持多个回答。stream
:是否开启流式回复,默认值为False
。temperature
:模型温度,默认值为0.0
,传入0
或不传表示不覆盖。tools
:工具列表,默认值为空列表,目前助手暂不支持,使用助手的配置。
-
OpenAI 选择
class OpenAIChoice(BaseModel): index: int = Field(..., description="选项的索引") message: dict = Field(default=None, description="对应的消息内容,和输入的格式一致") finish_reason: str = Field(default='stop', description="结束原因, 助手只有stop") delta: dict = Field(default=None, description="对应的openai流式返回消息内容")
- 用途:表示 OpenAI 聊天完成 API 返回的单个选择项。
- 字段:
index
:选项的索引。message
:对应的消息内容,格式与输入一致。finish_reason
:结束原因,助手只有'stop'
。delta
:对应的 OpenAI 流式返回消息内容。
-
OpenAI 聊天完成响应
class OpenAIChatCompletionResp(BaseModel): id: str = Field(..., description="请求的唯一ID") object: str = Field(default='chat.completion', description="返回的类型") created: int = Field(default=..., description="返回的创建时间戳") model: str = Field(..., description="返回的模型,对应助手的id") choices: List[OpenAIChoice] = Field(..., description="返回的答案列表") usage: dict = Field(default=None, description="返回的token用量, 助手此值为空") system_fingerprint: Optional[str] = Field(default=None, description="系统指纹")
- 用途:表示 OpenAI 聊天完成 API 的响应结构。
- 字段:
id
:请求的唯一标识符。object
:返回的类型,默认值为'chat.completion'
。created
:返回的创建时间戳。model
:返回的模型名称,对应助手的 ID。choices
:返回的答案列表,类型为List[OpenAIChoice]
。usage
:返回的 token 用量,助手此值为空。system_fingerprint
:系统指纹,可选。
LLM 模型与服务器配置
-
LLM 模型创建请求
class LLMModelCreateReq(BaseModel): id: Optional[int] = Field(default=None, description="模型唯一ID, 更新时需要传") name: str = Field(..., description="模型展示名称") description: Optional[str] = Field(default='', description="模型描述") model_name: str = Field(..., description="模型名称") model_type: str = Field(..., description="模型类型") online: bool = Field(default=True, description='是否在线') config: Optional[dict] = Field(default=None, description="模型配置")
- 用途:表示创建或更新 LLM 模型的请求。
- 字段:
id
:模型的唯一标识符,更新时需要传。name
:模型的展示名称。description
:模型描述,可选,默认值为空字符串。model_name
:模型名称。model_type
:模型类型。online
:是否在线,默认值为True
。config
:模型的配置,类型为字典,可选。
-
LLM 服务器创建请求
class LLMServerCreateReq(BaseModel): id: Optional[int] = Field(default=None, description="服务提供方ID, 更新时需要传") name: str = Field(..., description="服务提供方名称") description: Optional[str] = Field(default='', description="服务提供方描述") type: str = Field(..., description="服务提供方类型") limit_flag: Optional[bool] = Field(default=False, description="是否开启每日调用次数限制") limit: Optional[int] = Field(default=0, description="每日调用次数限制") config: Optional[dict] = Field(default=None, description="服务提供方配置") models: Optional[List[LLMModelCreateReq]] = Field(default=[], description="服务提供方下的模型列表")
- 用途:表示创建或更新 LLM 服务器的请求。
- 字段:
id
:服务提供方的唯一标识符,更新时需要传。name
:服务提供方名称。description
:服务提供方描述,可选,默认值为空字符串。type
:服务提供方类型。limit_flag
:是否开启每日调用次数限制,默认值为False
。limit
:每日调用次数限制,默认值为0
。config
:服务提供方的配置,类型为字典,可选。models
:服务提供方下的模型列表,类型为List[LLMModelCreateReq]
,默认值为空列表。
-
LLM 模型信息
class LLMModelInfo(LLMModelBase): id: Optional[int]
- 用途:表示 LLM 模型的详细信息,继承自
LLMModelBase
(数据库模型)。 - 字段:
id
:模型的唯一标识符,可选。
- 用途:表示 LLM 模型的详细信息,继承自
-
LLM 服务器信息
class LLMServerInfo(LLMServerBase): id: Optional[int] models: List[LLMModelInfo] = Field(default=[], description="模型列表")
- 用途:表示 LLM 服务器的详细信息,继承自
LLMServerBase
(数据库模型)。 - 字段:
id
:服务器的唯一标识符,可选。models
:服务器下的模型列表,类型为List[LLMModelInfo]
,默认值为空列表。
- 用途:表示 LLM 服务器的详细信息,继承自
知识库配置模型
-
知识库 LLM 配置
class KnowledgeLLMConfig(BaseModel): embedding_model_id: Optional[int] = Field(description="知识库默认embedding模型的ID") source_model_id: Optional[int] = Field(description="知识库溯源模型的ID") extract_title_model_id: Optional[int] = Field(description="文档知识库提取标题模型的ID") qa_similar_model_id: Optional[int] = Field(description="QA知识库相似问模型的ID")
- 用途:表示知识库的 LLM 配置。
- 字段:
embedding_model_id
:默认的 embedding 模型 ID,可选。source_model_id
:溯源模型 ID,可选。extract_title_model_id
:文档知识库提取标题模型的 ID,可选。qa_similar_model_id
:QA 知识库相似问模型的 ID,可选。
-
助手 LLM 项目
class AssistantLLMItem(BaseModel): model_id: Optional[int] = Field(description="模型的ID") agent_executor_type: Optional[str] = Field(default="ReAct", description="执行模式。function call 或者 ReAct") knowledge_max_content: Optional[int] = Field(default=15000, description="知识库检索最大字符串数") knowledge_sort_index: Optional[bool] = Field(default=False, description="知识库检索后是否重排") streaming: Optional[bool] = Field(default=True, description="是否开启流式") default: Optional[bool] = Field(default=False, description="是否为默认模型")
- 用途:表示助手的单个 LLM 配置项。
- 字段:
model_id
:模型的唯一标识符,可选。agent_executor_type
:执行模式,默认值为"ReAct"
,可选,其他值如"function call"
。knowledge_max_content
:知识库检索的最大字符串数,默认值为15000
。knowledge_sort_index
:知识库检索后是否重排,默认值为False
。streaming
:是否开启流式回复,默认值为True
。default
:是否为默认模型,默认值为False
。
-
助手 LLM 配置
class AssistantLLMConfig(BaseModel): llm_list: Optional[List[AssistantLLMItem]] = Field(default=[], description="助手可选的LLM列表") auto_llm: Optional[AssistantLLMItem] = Field(description="助手画像自动优化模型的配置")
- 用途:表示助手的 LLM 配置。
- 字段:
llm_list
:助手可选的 LLM 列表,类型为List[AssistantLLMItem]
,默认值为空列表。auto_llm
:助手画像自动优化模型的配置,类型为AssistantLLMItem
,可选。
-
评测 LLM 配置
class EvaluationLLMConfig(BaseModel): model_id: Optional[int] = Field(description="评测功能默认模型的ID")
- 用途:表示评测功能的 LLM 配置。
- 字段:
model_id
:评测功能默认模型的 ID,可选。
文件处理模型
-
文件处理基础请求
class FileProcessBase(BaseModel): knowledge_id: int = Field(..., description="知识库ID") separator: Optional[List[str]] = Field(default=['\n\n'], description="切分文本规则, 不传则为默认") separator_rule: Optional[List[str]] = Field(default=['after'], description="切分规则前还是后进行切分;before/after") chunk_size: Optional[int] = Field(default=1000, description="切分文本长度,不传则为默认") chunk_overlap: Optional[int] = Field(default=100, description="切分文本重叠长度,不传则为默认") @root_validator def check_separator_rule(cls, values): if values['separator'] is None: values['separator'] = ['\n\n'] if values['separator_rule'] is None: values['separator_rule'] = ['after' for _ in values["separator"]] if values['chunk_size'] is None: values['chunk_size'] = 1000 if values['chunk_overlap'] is None: values['chunk_overlap'] = 100 return values
- 用途:表示文件处理的基础请求参数,用于文件切分。
- 字段:
knowledge_id
:知识库的唯一标识符。separator
:切分文本的规则,默认值为['\n\n']
,可选。separator_rule
:切分规则前还是后进行切分,默认值为['after']
,可选。chunk_size
:切分文本的长度,默认值为1000
,可选。chunk_overlap
:切分文本的重叠长度,默认值为100
,可选。
- 验证器:
check_separator_rule
:确保所有可选字段在未提供时设置为默认值。
-
文件分块元数据
class FileChunkMetadata(BaseModel): source: str = Field(default='', description="源文件名") title: str = Field(default='', description="源文件内容总结的标题") chunk_index: int = Field(default=0, description="文本块索引") bbox: str = Field(default='', description="文本块bbox信息") page: int = Field(default=0, description="文本块所在页码") extra: str = Field(default='', description="文本块额外信息") file_id: int = Field(default=0, description="文本块所属文件ID")
- 用途:表示文件分块的元数据信息。
- 字段:
source
:源文件名,默认值为空字符串。title
:源文件内容总结的标题,默认值为空字符串。chunk_index
:文本块的索引,默认值为0
。bbox
:文本块的 bounding box 信息,默认值为空字符串。page
:文本块所在的页码,默认值为0
。extra
:文本块的额外信息,默认值为空字符串。file_id
:文本块所属文件的唯一标识符,默认值为0
。
-
文件分块
class FileChunk(BaseModel): text: str = Field(..., description="文本块内容") parse_type: Optional[str] = Field(default=None, description="文本所属的文件解析类型") metadata: FileChunkMetadata = Field(..., description="文本块元数据")
- 用途:表示文件的一个分块。
- 字段:
text
:文本块的内容。parse_type
:文本所属的文件解析类型,可选。metadata
:文本块的元数据,类型为FileChunkMetadata
。
-
预览文件分块请求
class PreviewFileChunk(FileProcessBase): file_path: str = Field(..., description="文件路径") cache: bool = Field(default=True, description="是否从缓存获取")
- 用途:表示预览文件分块内容的请求。
- 字段:
file_path
:文件的路径。cache
:是否从缓存获取,默认值为True
。
-
更新预览文件分块
class UpdatePreviewFileChunk(BaseModel): knowledge_id: int = Field(..., description="知识库ID") file_path: str = Field(..., description="文件路径") text: str = Field(..., description="文本块内容") chunk_index: int = Field(..., description="文本块索引, 在metadata里") bbox: Optional[str] = Field(default='', description="文本块bbox信息")
- 用途:表示更新预览文件分块内容的请求。
- 字段:
knowledge_id
:知识库的唯一标识符。file_path
:文件的路径。text
:文本块的内容。chunk_index
:文本块的索引。bbox
:文本块的 bounding box 信息,可选,默认值为空字符串。
-
知识库文件单项
class KnowledgeFileOne(BaseModel): file_path: str = Field(..., description="文件路径")
- 用途:表示知识库中单个文件的路径。
- 字段:
file_path
:文件的路径。
-
知识库文件处理
class KnowledgeFileProcess(FileProcessBase): file_list: List[KnowledgeFileOne] = Field(..., description="文件列表") callback_url: Optional[str] = Field(default=None, description="异步任务回调地址") extra: Optional[str] = Field(default=None, description="附加信息")
- 用途:表示知识库文件处理的请求。
- 字段:
file_list
:文件列表,类型为List[KnowledgeFileOne]
。callback_url
:异步任务的回调地址,可选。extra
:附加信息,可选。
总结
这段代码定义了一系列用于数据验证和序列化的模型,涵盖了用户认证、聊天管理、知识库处理、模型和服务器配置等多个方面。以下是关键点总结:
- 数据验证和序列化:通过 Pydantic 的
BaseModel
,确保 API 接收和返回的数据结构符合预期,提升代码的可靠性和可维护性。 - 通用响应结构:使用泛型的
UnifiedResponseModel
和辅助函数resp_200
、resp_500
,统一 API 的响应格式,便于前端处理和错误管理。 - 业务逻辑覆盖:
- 用户认证:如
CaptchaInput
、CreateUserReq
。 - 聊天管理:如
ChatMessage
、ChatResponse
、AddChatMessages
。 - 知识库处理:如
KnowledgeFileProcess
、FileChunk
。 - 模型和服务器配置:如
LLMModelCreateReq
、LLMServerCreateReq
。 - 流程管理:如
FlowListCreate
、FlowCompareReq
。
- 用户认证:如
- 验证器:通过
@validator
和@root_validator
,进一步确保数据的合法性和完整性。 - 类型注解:广泛使用类型注解(如
Optional
、List
、Dict
等),提高代码的可读性和静态分析工具的效果。
base.py
import uuid
from contextlib import contextmanager
from bisheng.database.service import DatabaseService
from bisheng.settings import settings
from bisheng.utils.logger import logger
from sqlmodel import Session
db_service: 'DatabaseService' = DatabaseService(settings.database_url)
@contextmanager
def session_getter() -> Session:
"""轻量级session context"""
try:
session = Session(db_service.engine)
yield session
except Exception as e:
logger.info('Session rollback because of exception:{}', e)
session.rollback()
raise
finally:
session.close()
def generate_uuid() -> str:
"""
生成uuid的字符串
"""
return uuid.uuid4().hex
1. 导入部分
标准库导入
import uuid
from contextlib import contextmanager
uuid
:- 用途:用于生成唯一标识符(UUID)。
- 常用功能:
uuid.uuid4()
:生成一个随机的 UUID。uuid.uuid4().hex
:获取 UUID 的 32 位十六进制字符串表示。
contextmanager
:- 用途:用于创建上下文管理器,通常与
with
语句一起使用。 - 功能:简化上下文管理器的创建,避免手动编写类。
- 用途:用于创建上下文管理器,通常与
项目内部模块导入
from bisheng.database.service import DatabaseService
from bisheng.settings import settings
from bisheng.utils.logger import logger
bisheng.database.service.DatabaseService
:- 用途:数据库服务类,通常封装了与数据库的连接和操作逻辑。
- 假设功能:
- 初始化数据库引擎。
- 提供数据库会话管理。
- 可能包含其他数据库相关的服务方法。
bisheng.settings.settings
:- 用途:项目的配置对象,包含了各种配置参数,如数据库 URL、API 密钥等。
- 功能:通过
settings.database_url
访问数据库连接字符串。
bisheng.utils.logger.logger
:- 用途:日志记录器,用于记录日志信息。
- 功能:
- 记录信息、警告、错误等不同级别的日志。
- 便于调试和监控应用程序运行状态。
第三方库导入
from sqlmodel import Session
sqlmodel.Session
:- 用途:用于与数据库交互的会话对象。
- 功能:
- 提供 CRUD(创建、读取、更新、删除)操作。
- 管理数据库事务。
- 处理数据库连接的打开和关闭。
2. 数据库服务实例化
db_service: 'DatabaseService' = DatabaseService(settings.database_url)
db_service
:- 类型注解:
'DatabaseService'
(使用字符串注解以避免循环导入问题)。 - 用途:创建一个
DatabaseService
的实例,用于管理与数据库的连接。 - 初始化:
- 参数:
settings.database_url
,即数据库的连接字符串。 - 功能:内部会根据
database_url
配置数据库引擎(如 SQLAlchemy 引擎)。
- 参数:
- 类型注解:
示例
假设 settings.database_url
是 "postgresql://user:password@localhost/dbname"
,那么 DatabaseService
会使用这个 URL 初始化与 PostgreSQL 数据库的连接。
3. 上下文管理器:session_getter
@contextmanager
def session_getter() -> Session:
"""轻量级session context"""
try:
session = Session(db_service.engine)
yield session
except Exception as e:
logger.info('Session rollback because of exception:{}', e)
session.rollback()
raise
finally:
session.close()
功能和用途
- 目的:提供一个安全且便捷的方式来管理数据库会话,确保会话在使用后正确关闭,并在发生异常时回滚事务。
- 使用场景:
- 在执行数据库操作时使用
with session_getter() as session:
,确保会话的正确管理。 - 处理事务,确保数据一致性和完整性。
- 在执行数据库操作时使用
工作原理
-
创建会话:
session = Session(db_service.engine)
db_service.engine
:获取数据库引擎,用于创建会话。
-
使用
yield
传递会话对象:yield session
- 允许
with
语句块内的代码使用该会话对象。
- 允许
-
异常处理:
except Exception as e: logger.info('Session rollback because of exception:{}', e) session.rollback() raise
- 记录日志:使用
logger
记录异常信息。 - 事务回滚:确保在发生异常时,事务不会部分提交,保持数据一致性。
- 重新抛出异常:让上层调用者知道发生了错误。
- 记录日志:使用
-
关闭会话:
finally: session.close()
- 确保资源释放:无论是否发生异常,都会关闭会话,释放数据库连接资源。
示例用法
from bisheng.database.utils import session_getter
def get_user(user_id: int):
with session_getter() as session:
user = session.query(User).filter(User.id == user_id).first()
return user
在这个示例中:
with session_getter() as session
:获取一个数据库会话。- 执行查询:通过会话对象执行数据库查询。
- 自动管理:退出
with
块时,会自动关闭会话,确保资源释放。
4. UUID 生成函数:generate_uuid
def generate_uuid() -> str:
"""
生成uuid的字符串
"""
return uuid.uuid4().hex
功能和用途
- 目的:生成一个唯一的 UUID 字符串,用于标识对象或记录。
- 用途:
- 为数据库记录生成唯一标识符。
- 生成会话 ID、交易 ID 等。
- 确保标识符的唯一性,避免冲突。
工作原理
-
生成 UUID:
uuid.uuid4()
uuid4
:生成一个随机的 UUID(基于随机数)。
-
获取十六进制表示:
.hex
hex
属性:将 UUID 对象转换为 32 位十六进制字符串,没有连字符(-
)。
示例
unique_id = generate_uuid()
print(unique_id) # 输出示例:e4eaaaf2dcd24d5f8e4e8c1b0b3f1f6a
-
优势
:
- 唯一性:UUID 生成的概率几乎为零,确保每个 ID 的唯一性。
- 不可预测性:基于随机数生成,难以预测下一个 UUID 的值。
注意事项
- 长度:UUID 的
hex
表示长度为 32 个字符。 - 存储:在数据库中存储时,考虑使用合适的数据类型(如
VARCHAR(32)
或CHAR(32)
)。 - 性能:UUID 的随机性可能导致数据库索引性能下降,特别是在高频率写入的情况下。可以考虑使用有序 UUID(如
UUID1
)以优化索引性能。
5. 总结
这段代码主要涵盖了以下几个关键方面:
- 数据库服务初始化:
- 使用
DatabaseService
类初始化数据库连接,基于项目配置中的database_url
。
- 使用
- 会话管理:
- 定义了一个上下文管理器
session_getter
,用于安全地管理数据库会话,确保会话在使用后正确关闭,并在异常发生时回滚事务。
- 定义了一个上下文管理器
- UUID 生成:
- 提供了一个简单的函数
generate_uuid
,用于生成唯一的 UUID 字符串,确保标识符的唯一性和不可预测性。
- 提供了一个简单的函数
如何在项目中使用
-
数据库操作:
- 使用
session_getter
上下文管理器来执行数据库操作,确保会话的正确管理和资源的释放。
from bisheng.database.utils import session_getter, generate_uuid def create_user(name: str, email: str): user_id = generate_uuid() with session_getter() as session: new_user = User(id=user_id, name=name, email=email) session.add(new_user) session.commit() return new_user
- 使用
-
生成唯一标识符:
- 使用
generate_uuid
函数为新记录生成唯一的 ID。
unique_id = generate_uuid() print(f"Generated UUID: {unique_id}")
- 使用
进一步阅读
DatabaseService
类:- 查看
bisheng.database.service
模块中的DatabaseService
类实现,了解其具体功能和方法。
- 查看
- 项目配置:
- 查看
bisheng.settings
模块,了解项目的配置管理,尤其是database_url
的设置方式。
- 查看
- 日志记录器:
- 查看
bisheng.utils.logger
模块,了解日志记录器的配置和使用方法,以便更好地调试和监控应用程序。
- 查看
sqlmodel
库:sqlmodel
是一个用于与 SQL 数据库交互的库,结合了 SQLAlchemy 和 Pydantic 的特性。建议阅读 SQLModel 官方文档 以深入理解其用法和最佳实践。
常见问题
- 为什么使用
contextmanager
装饰器创建session_getter
?- 使用
contextmanager
可以简化上下文管理器的创建,使代码更加简洁和易读。它允许使用with
语句块自动管理资源,如数据库会话的打开和关闭。
- 使用
- 为什么在异常处理中回滚事务?
- 当数据库操作过程中发生异常时,事务可能处于不一致状态。回滚事务可以确保数据库不会部分提交数据,保持数据的一致性和完整性。
- UUID 为什么使用
.hex
属性而不是直接使用uuid4()
生成的对象?- 使用
.hex
属性可以获得一个标准的 32 位十六进制字符串表示,便于存储和传输。相比直接使用UUID
对象,字符串表示更容易与其他系统集成和调试。
- 使用
UserService.py
这段代码主要涉及用户身份验证、权限管理、用户创建以及与数据库交互的服务逻辑。代码使用了 FastAPI 框架,并结合了 Pydantic、JWT、Redis 等技术。以下是对每一部分的详细解析:
import functools
import json
from base64 import b64decode
from typing import List
import rsa
from bisheng.api.errcode.base import UnAuthorizedError
from bisheng.api.errcode.user import (UserLoginOfflineError, UserNameAlreadyExistError,
UserNeedGroupAndRoleError)
from bisheng.api.JWT import ACCESS_TOKEN_EXPIRE_TIME
from bisheng.api.utils import md5_hash
from bisheng.api.v1.schemas import CreateUserReq
from bisheng.cache.redis import redis_client
from bisheng.database.models.assistant import Assistant, AssistantDao
from bisheng.database.models.flow import Flow, FlowDao, FlowRead
from bisheng.database.models.knowledge import Knowledge, KnowledgeDao, KnowledgeRead
from bisheng.database.models.role import AdminRole
from bisheng.database.models.role_access import AccessType, RoleAccessDao
from bisheng.database.models.user import User, UserDao
from bisheng.database.models.user_group import UserGroupDao
from bisheng.database.models.user_role import UserRoleDao
from bisheng.settings import settings
from bisheng.utils.constants import RSA_KEY, USER_CURRENT_SESSION
from fastapi import Depends, HTTPException, Request
from fastapi_jwt_auth import AuthJWT
1. 导入部分
标准库导入
import functools
import json
from base64 import b64decode
from typing import List
-
functools
:-
用途:用于高级函数操作,如装饰器。
-
常用功能
:
functools.wraps
:用于装饰器,保持原函数的元数据。
-
-
json
:-
用途:用于 JSON 数据的序列化和反序列化。
-
常用功能
:
json.loads
:将 JSON 字符串转换为 Python 对象。json.dumps
:将 Python 对象转换为 JSON 字符串。
-
-
base64.b64decode
:- 用途:用于解码 Base64 编码的数据。
- 应用场景:常用于处理加密数据或传输二进制数据。
-
typing.List
:- 用途:用于类型注解,表示列表类型。
- 示例:
List[int]
表示整数列表。
第三方库导入
import rsa
from fastapi import Depends, HTTPException, Request
from fastapi_jwt_auth import AuthJWT
-
rsa
:-
用途:用于 RSA 加密和解密。
-
常用功能
:
rsa.decrypt
:使用私钥解密数据。rsa.encrypt
:使用公钥加密数据。
-
-
fastapi
:-
用途:用于构建 Web API。
-
常用功能
:
Depends
:用于依赖注入。HTTPException
:用于抛出 HTTP 异常。Request
:表示请求对象。
-
-
fastapi_jwt_auth.AuthJWT
:-
用途:用于处理 JWT(JSON Web Token)认证。
-
常用功能
:
create_access_token
:创建访问令牌。create_refresh_token
:创建刷新令牌。jwt_required
:装饰器,用于保护路由,需要 JWT 认证。
-
项目内部模块导入
from bisheng.api.errcode.base import UnAuthorizedError
from bisheng.api.errcode.user import (UserLoginOfflineError, UserNameAlreadyExistError,
UserNeedGroupAndRoleError)
from bisheng.api.JWT import ACCESS_TOKEN_EXPIRE_TIME
from bisheng.api.utils import md5_hash
from bisheng.api.v1.schemas import CreateUserReq
from bisheng.cache.redis import redis_client
from bisheng.database.models.assistant import Assistant, AssistantDao
from bisheng.database.models.flow import Flow, FlowDao, FlowRead
from bisheng.database.models.knowledge import Knowledge, KnowledgeDao, KnowledgeRead
from bisheng.database.models.role import AdminRole
from bisheng.database.models.role_access import AccessType, RoleAccessDao
from bisheng.database.models.user import User, UserDao
from bisheng.database.models.user_group import UserGroupDao
from bisheng.database.models.user_role import UserRoleDao
from bisheng.settings import settings
from bisheng.utils.constants import RSA_KEY, USER_CURRENT_SESSION
-
bisheng.api.errcode.\*
:-
用途:定义自定义的 API 错误码和异常类,用于统一错误处理。
-
常用异常
:
UnAuthorizedError
:未授权错误。UserLoginOfflineError
:用户登录离线错误。UserNameAlreadyExistError
:用户名已存在错误。UserNeedGroupAndRoleError
:用户需要组和角色错误。
-
-
bisheng.api.JWT
:-
用途:存储 JWT 相关的配置常量,如访问令牌过期时间。
-
字段
:
ACCESS_TOKEN_EXPIRE_TIME
:访问令牌的过期时间。
-
-
bisheng.api.utils
:-
用途:存储各种实用函数。
-
常用函数
:
md5_hash
:生成字符串的 MD5 哈希值。
-
-
bisheng.api.v1.schemas.CreateUserReq
:- 用途:定义创建用户请求的数据模型,通常使用 Pydantic
BaseModel
。
- 用途:定义创建用户请求的数据模型,通常使用 Pydantic
-
bisheng.cache.redis.redis_client
:- 用途:Redis 客户端实例,用于与 Redis 缓存交互。
-
bisheng.database.models.\*
:-
用途:导入数据库模型和数据访问对象(DAO)。
-
常用模型和 DAO
:
Assistant
,AssistantDao
:助手模型和 DAO。Flow
,FlowDao
,FlowRead
:流程模型、DAO 和读取模型。Knowledge
,KnowledgeDao
,KnowledgeRead
:知识库模型、DAO 和读取模型。AdminRole
:管理员角色模型。AccessType
,RoleAccessDao
:访问类型和角色访问 DAO。User
,UserDao
:用户模型和 DAO。UserGroupDao
:用户组 DAO。UserRoleDao
:用户角色 DAO。
-
-
bisheng.settings.settings
:- 用途:项目的配置对象,包含了各种配置参数,如数据库 URL、RSA 密钥等。
-
bisheng.utils.constants
:-
用途:存储项目中使用的常量。
-
字段
:
RSA_KEY
:RSA 密钥的常量键。USER_CURRENT_SESSION
:用户当前会话的 Redis 键模板。
-
2. UserPayload
类
class UserPayload:
def __init__(self, **kwargs):
self.user_id = kwargs.get('user_id')
self.user_role = kwargs.get('role')
if self.user_role != 'admin': # 非管理员用户,需要获取他的角色列表
roles = UserRoleDao.get_user_roles(self.user_id)
self.user_role = [one.role_id for one in roles]
self.user_name = kwargs.get('user_name')
def is_admin(self):
if self.user_role == 'admin':
return True
if isinstance(self.user_role, list):
for one in self.user_role:
if one == AdminRole:
return True
return False
@staticmethod
def wrapper_access_check(func):
"""
权限检查的装饰器
如果是admin用户则不执行后续具体的检查逻辑
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
if args[0].is_admin():
return True
return func(*args, **kwargs)
return wrapper
@wrapper_access_check
def access_check(self, owner_user_id: int, target_id: str, access_type: AccessType) -> bool:
"""
检查用户是否有某个资源的权限
"""
# 判断是否属于本人资源
if self.user_id == owner_user_id:
return True
# 判断授权
if RoleAccessDao.judge_role_access(self.user_role, target_id, access_type):
return True
return False
@wrapper_access_check
def check_group_admin(self, group_id: int) -> bool:
"""
检查用户是否是某个组的管理员
"""
# 判断是否是用户组的管理员
user_group = UserGroupDao.get_user_admin_group(self.user_id)
if not user_group:
return False
for one in user_group:
if one.group_id == group_id:
return True
return False
@wrapper_access_check
def check_groups_admin(self, group_ids: List[int]) -> bool:
"""
检查用户是否是用户组列表中的管理员,有一个就是true
"""
user_groups = UserGroupDao.get_user_admin_group(self.user_id)
for one in user_groups:
if one.is_group_admin and one.group_id in group_ids:
return True
return False
功能和用途
UserPayload
类用于封装当前登录用户的相关信息,并提供权限检查的方法。这在构建基于角色的访问控制(RBAC)系统中非常常见。
详细解析
构造函数 __init__
def __init__(self, **kwargs):
self.user_id = kwargs.get('user_id')
self.user_role = kwargs.get('role')
if self.user_role != 'admin': # 非管理员用户,需要获取他的角色列表
roles = UserRoleDao.get_user_roles(self.user_id)
self.user_role = [one.role_id for one in roles]
self.user_name = kwargs.get('user_name')
-
功能:初始化用户的基本信息,包括
user_id
、user_role
和user_name
。 -
逻辑
:
- 从
kwargs
获取user_id
、role
和user_name
。 - 如果用户不是管理员 (
self.user_role != 'admin'
),则从数据库获取该用户的所有角色,并将其存储为角色 ID 的列表。
- 从
方法 is_admin
def is_admin(self):
if self.user_role == 'admin':
return True
if isinstance(self.user_role, list):
for one in self.user_role:
if one == AdminRole:
return True
return False
-
功能:判断用户是否具有管理员权限。
-
逻辑
:
- 如果
user_role
是'admin'
,返回True
。 - 如果
user_role
是一个列表,遍历列表,检查是否包含AdminRole
,如果包含,返回True
。 - 否则,返回
False
。
- 如果
装饰器 wrapper_access_check
@staticmethod
def wrapper_access_check(func):
"""
权限检查的装饰器
如果是admin用户则不执行后续具体的检查逻辑
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
if args[0].is_admin():
return True
return func(*args, **kwargs)
return wrapper
-
功能:定义一个装饰器,用于权限检查。
-
逻辑
:
- 如果用户是管理员,直接返回
True
,跳过后续的具体权限检查逻辑。 - 否则,执行被装饰的函数。
- 如果用户是管理员,直接返回
方法 access_check
@wrapper_access_check
def access_check(self, owner_user_id: int, target_id: str, access_type: AccessType) -> bool:
"""
检查用户是否有某个资源的权限
"""
# 判断是否属于本人资源
if self.user_id == owner_user_id:
return True
# 判断授权
if RoleAccessDao.judge_role_access(self.user_role, target_id, access_type):
return True
return False
-
功能:检查用户是否有访问某个资源的权限。
-
逻辑
:
- 如果资源的拥有者是当前用户,返回
True
。 - 否则,调用
RoleAccessDao.judge_role_access
方法检查用户角色是否有访问该资源的权限。 - 返回相应的布尔值。
- 如果资源的拥有者是当前用户,返回
方法 check_group_admin
@wrapper_access_check
def check_group_admin(self, group_id: int) -> bool:
"""
检查用户是否是某个组的管理员
"""
# 判断是否是用户组的管理员
user_group = UserGroupDao.get_user_admin_group(self.user_id)
if not user_group:
return False
for one in user_group:
if one.group_id == group_id:
return True
return False
-
功能:检查用户是否是某个用户组的管理员。
-
逻辑
:
- 调用
UserGroupDao.get_user_admin_group
获取用户所属的管理员组。 - 如果用户不属于任何管理员组,返回
False
。 - 遍历管理员组,检查是否包含指定的
group_id
,如果包含,返回True
。 - 否则,返回
False
。
- 调用
方法 check_groups_admin
@wrapper_access_check
def check_groups_admin(self, group_ids: List[int]) -> bool:
"""
检查用户是否是用户组列表中的管理员,有一个就是true
"""
user_groups = UserGroupDao.get_user_admin_group(self.user_id)
for one in user_groups:
if one.is_group_admin and one.group_id in group_ids:
return True
return False
-
功能:检查用户是否是给定用户组列表中的任意一个用户组的管理员。
-
逻辑
:
- 调用
UserGroupDao.get_user_admin_group
获取用户所属的管理员组。 - 遍历管理员组,检查是否有任何一个用户组的
group_id
在给定的group_ids
列表中,并且用户是该组的管理员 (one.is_group_admin
)。 - 如果找到符合条件的用户组,返回
True
。 - 否则,返回
False
。
- 调用
3. UserService
类
class UserService:
@classmethod
def decrypt_md5_password(cls, password: str):
if value := redis_client.get(RSA_KEY):
private_key = value[1]
password = md5_hash(rsa.decrypt(b64decode(password), private_key).decode('utf-8'))
else:
password = md5_hash(password)
return password
@classmethod
def create_user(cls, request: Request, login_user: UserPayload, req_data: CreateUserReq):
"""
创建用户
"""
exists_user = UserDao.get_user_by_username(req_data.user_name)
if exists_user:
# 抛出异常
raise UserNameAlreadyExistError.http_exception()
user = User(
user_name=req_data.user_name,
password=cls.decrypt_md5_password(req_data.password),
)
group_ids = []
role_ids = []
for one in req_data.group_roles:
group_ids.append(one.group_id)
role_ids.extend(one.role_ids)
if not group_ids or not role_ids:
raise UserNeedGroupAndRoleError.http_exception()
user = UserDao.add_user_with_groups_and_roles(user, group_ids, role_ids)
return user
功能和用途
UserService
类提供了与用户相关的业务逻辑,如密码解密和用户创建。这种类通常包含静态方法或类方法,用于处理特定的业务操作。
详细解析
方法 decrypt_md5_password
@classmethod
def decrypt_md5_password(cls, password: str):
if value := redis_client.get(RSA_KEY):
private_key = value[1]
password = md5_hash(rsa.decrypt(b64decode(password), private_key).decode('utf-8'))
else:
password = md5_hash(password)
return password
- 功能:解密密码并生成 MD5 哈希值。
- 逻辑:
- 尝试从 Redis 获取 RSA 密钥 (
RSA_KEY
)。 - 如果存在 RSA 密钥:
- 使用 RSA 私钥解密 Base64 编码的密码。
- 将解密后的密码转换为 UTF-8 字符串。
- 对解密后的密码进行 MD5 哈希。
- 如果不存在 RSA 密钥:
- 直接对原始密码进行 MD5 哈希。
- 返回哈希后的密码。
- 尝试从 Redis 获取 RSA 密钥 (
- 应用场景:
- 用户注册或密码更新时,将用户输入的密码加密存储到数据库。
- 增强密码的安全性,通过加密和哈希减少明文密码的存储风险。
方法 create_user
@classmethod
def create_user(cls, request: Request, login_user: UserPayload, req_data: CreateUserReq):
"""
创建用户
"""
exists_user = UserDao.get_user_by_username(req_data.user_name)
if exists_user:
# 抛出异常
raise UserNameAlreadyExistError.http_exception()
user = User(
user_name=req_data.user_name,
password=cls.decrypt_md5_password(req_data.password),
)
group_ids = []
role_ids = []
for one in req_data.group_roles:
group_ids.append(one.group_id)
role_ids.extend(one.role_ids)
if not group_ids or not role_ids:
raise UserNeedGroupAndRoleError.http_exception()
user = UserDao.add_user_with_groups_and_roles(user, group_ids, role_ids)
return user
- 功能:创建新用户,并将其分配到指定的用户组和角色中。
- 参数:
request
:请求对象,包含请求相关的信息。login_user
:当前登录用户的信息,类型为UserPayload
。req_data
:创建用户的请求数据,类型为CreateUserReq
。
- 逻辑:
- 检查用户名是否存在:
- 调用
UserDao.get_user_by_username
检查请求中的用户名是否已经存在。 - 如果用户已存在,抛出
UserNameAlreadyExistError
异常。
- 调用
- 创建用户实例:
- 使用
CreateUserReq
中的user_name
和解密后的密码创建一个新的User
实例。
- 使用
- 收集用户组和角色 ID:
- 遍历
req_data.group_roles
,收集所有的group_id
和role_id
。
- 遍历
- 检查用户组和角色:
- 如果用户没有指定任何用户组或角色,抛出
UserNeedGroupAndRoleError
异常。
- 如果用户没有指定任何用户组或角色,抛出
- 添加用户并关联组和角色:
- 调用
UserDao.add_user_with_groups_and_roles
方法,将用户添加到数据库,并关联指定的用户组和角色。
- 调用
- 返回创建的用户:
- 返回新创建的
User
实例。
- 返回新创建的
- 检查用户名是否存在:
- 应用场景:
- 用户注册功能,允许管理员或系统自动创建新用户,并分配相应的权限和组。
4. 辅助函数
函数 sso_login
def sso_login():
pass
- 功能:此函数目前为空,实现细节未提供。
- 用途:通常用于单点登录(SSO)功能,允许用户通过单一认证源登录多个系统。
函数 gen_user_role
def gen_user_role(db_user: User):
# 查询用户的角色列表
db_user_role = UserRoleDao.get_user_roles(db_user.user_id)
role = ''
role_ids = []
for user_role in db_user_role:
if user_role.role_id == 1:
# 是管理员,忽略其他的角色
role = 'admin'
else:
role_ids.append(user_role.role_id)
if role != 'admin':
# 判断是否是用户组管理员
db_user_groups = UserGroupDao.get_user_admin_group(db_user.user_id)
if len(db_user_groups) > 0:
role = 'group_admin'
else:
role = role_ids
# 获取用户的菜单栏权限列表
web_menu = RoleAccessDao.get_role_access(role_ids, AccessType.WEB_MENU)
web_menu = list(set([one.third_id for one in web_menu]))
return role, web_menu
- 功能:生成用户的角色和相应的菜单权限。
- 参数:
db_user
:数据库中的User
实例。
- 逻辑:
- 查询用户角色:
- 调用
UserRoleDao.get_user_roles
获取用户的角色列表。
- 调用
- 确定角色:
- 遍历用户角色列表:
- 如果用户具有角色 ID 为
1
,将角色设置为'admin'
,并忽略其他角色。 - 否则,将角色 ID 添加到
role_ids
列表中。
- 如果用户具有角色 ID 为
- 遍历用户角色列表:
- 检查用户组管理员身份:
- 如果用户不是管理员,调用
UserGroupDao.get_user_admin_group
检查用户是否是某个用户组的管理员。 - 如果是用户组管理员,设置角色为
'group_admin'
。 - 否则,角色保持为角色 ID 的列表。
- 如果用户不是管理员,调用
- 获取菜单权限:
- 调用
RoleAccessDao.get_role_access
获取用户角色对应的菜单权限。 - 提取并去重
third_id
,形成菜单权限列表web_menu
。
- 调用
- 返回角色和菜单权限:
- 返回用户的角色(字符串或列表)和菜单权限列表。
- 查询用户角色:
- 应用场景:
- 在用户登录后,生成用户的权限信息,用于前端展示和后续的权限校验。
函数 gen_user_jwt
def gen_user_jwt(db_user: User):
if 1 == db_user.delete:
raise HTTPException(status_code=500, detail='该账号已被禁用,请联系管理员')
# 查询角色
role, web_menu = gen_user_role(db_user)
# 生成JWT令牌
payload = {'user_name': db_user.user_name, 'user_id': db_user.user_id, 'role': role}
# Create the tokens and passing to set_access_cookies or set_refresh_cookies
access_token = AuthJWT().create_access_token(subject=json.dumps(payload),
expires_time=ACCESS_TOKEN_EXPIRE_TIME)
refresh_token = AuthJWT().create_refresh_token(subject=db_user.user_name)
# Set the JWT cookies in the response
return access_token, refresh_token, role, web_menu
- 功能:为用户生成 JWT 访问令牌和刷新令牌。
- 参数:
db_user
:数据库中的User
实例。
- 逻辑:
- 检查用户状态:
- 如果
db_user.delete
等于1
,表示账号被禁用,抛出HTTPException
,状态码为500
,提示用户联系管理员。
- 如果
- 获取用户角色和菜单权限:
- 调用
gen_user_role
函数,获取用户的角色和菜单权限。
- 调用
- 生成 JWT 载荷:
- 创建一个包含
user_name
、user_id
和role
的字典payload
。
- 创建一个包含
- 生成访问令牌和刷新令牌:
- 调用
AuthJWT().create_access_token
生成访问令牌,主题为payload
的 JSON 字符串,过期时间为ACCESS_TOKEN_EXPIRE_TIME
。 - 调用
AuthJWT().create_refresh_token
生成刷新令牌,主题为用户的用户名。
- 调用
- 返回令牌和权限信息:
- 返回访问令牌、刷新令牌、角色和菜单权限。
- 检查用户状态:
- 应用场景:
- 在用户成功登录后,生成并返回 JWT 令牌,用于后续的认证和授权。
函数 get_knowledge_list_by_access
def get_knowledge_list_by_access(role_id: int, name: str, page_num: int, page_size: int):
count_filter = []
if name:
count_filter.append(Knowledge.name.like('%{}%'.format(name)))
db_role_access = KnowledgeDao.get_knowledge_by_access(role_id, page_num, page_size)
total_count = KnowledgeDao.get_count_by_filter(count_filter)
# 补充用户名
user_ids = [access[0].user_id for access in db_role_access]
db_users = UserDao.get_user_by_ids(user_ids)
user_dict = {user.user_id: user.user_name for user in db_users}
return {
'data': [
KnowledgeRead.validate({
'name': access[0].name,
'user_name': user_dict.get(access[0].user_id),
'user_id': access[0].user_id,
'update_time': access[0].update_time,
'id': access[0].id
}) for access in db_role_access
],
'total':
total_count
}
- 功能:根据用户角色获取有访问权限的知识库列表,并支持分页和名称过滤。
- 参数:
role_id
:用户的角色 ID。name
:知识库名称的搜索关键字。page_num
:分页的页码。page_size
:每页显示的记录数。
- 逻辑:
- 构建过滤条件:
- 如果提供了
name
,则添加一个 SQL LIKE 过滤条件,用于模糊匹配知识库名称。
- 如果提供了
- 查询有访问权限的知识库:
- 调用
KnowledgeDao.get_knowledge_by_access
,传入role_id
、page_num
和page_size
,获取用户有权限访问的知识库列表。
- 调用
- 查询总数:
- 调用
KnowledgeDao.get_count_by_filter
,传入过滤条件,获取满足条件的知识库总数。
- 调用
- 补充用户名:
- 从知识库访问列表中提取所有的
user_id
。 - 调用
UserDao.get_user_by_ids
获取对应的用户信息。 - 构建一个字典
user_dict
,将user_id
映射到user_name
。
- 从知识库访问列表中提取所有的
- 构建响应数据:
- 遍历
db_role_access
,为每个知识库构建一个包含name
、user_name
、user_id
、update_time
和id
的字典,并通过KnowledgeRead.validate
进行数据验证和序列化。 - 返回包含
data
和total
的字典。
- 遍历
- 构建过滤条件:
- 应用场景:
- 用于前端页面展示用户有权限访问的知识库列表,并支持搜索和分页功能。
函数 get_flow_list_by_access
def get_flow_list_by_access(role_id: int, name: str, page_num: int, page_size: int):
count_filter = []
if name:
count_filter.append(Flow.name.like('%{}%'.format(name)))
db_role_access = FlowDao.get_flow_by_access(role_id, name, page_num, page_size)
total_count = FlowDao.get_count_by_filters(count_filter)
# 补充用户名
user_ids = [access[0].user_id for access in db_role_access]
db_users = UserDao.get_user_by_ids(user_ids)
user_dict = {user.user_id: user.user_name for user in db_users}
return {
'data': [
FlowRead.validate({
'name': access[0].name,
'user_name': user_dict.get(access[0].user_id),
'user_id': access[0].user_id,
'update_time': access[0].update_time,
'id': access[0].id
}) for access in db_role_access
],
'total':
total_count
}
- 功能:根据用户角色获取有访问权限的流程列表,并支持分页和名称过滤。
- 参数:
role_id
:用户的角色 ID。name
:流程名称的搜索关键字。page_num
:分页的页码。page_size
:每页显示的记录数。
- 逻辑:
- 构建过滤条件:
- 如果提供了
name
,则添加一个 SQL LIKE 过滤条件,用于模糊匹配流程名称。
- 如果提供了
- 查询有访问权限的流程:
- 调用
FlowDao.get_flow_by_access
,传入role_id
、name
、page_num
和page_size
,获取用户有权限访问的流程列表。
- 调用
- 查询总数:
- 调用
FlowDao.get_count_by_filters
,传入过滤条件,获取满足条件的流程总数。
- 调用
- 补充用户名:
- 从流程访问列表中提取所有的
user_id
。 - 调用
UserDao.get_user_by_ids
获取对应的用户信息。 - 构建一个字典
user_dict
,将user_id
映射到user_name
。
- 从流程访问列表中提取所有的
- 构建响应数据:
- 遍历
db_role_access
,为每个流程构建一个包含name
、user_name
、user_id
、update_time
和id
的字典,并通过FlowRead.validate
进行数据验证和序列化。 - 返回包含
data
和total
的字典。
- 遍历
- 构建过滤条件:
- 应用场景:
- 用于前端页面展示用户有权限访问的流程列表,并支持搜索和分页功能。
函数 get_assistant_list_by_access
def get_assistant_list_by_access(role_id: int, name: str, page_num: int, page_size: int):
count_filter = []
if name:
count_filter.append(Assistant.name.like('%{}%'.format(name)))
db_role_access = AssistantDao.get_assistants_by_access(role_id, name, page_size, page_num)
total_count = AssistantDao.get_count_by_filters(count_filter)
# 补充用户名
user_ids = [access[0].user_id for access in db_role_access]
db_users = UserDao.get_user_by_ids(user_ids)
user_dict = {user.user_id: user.user_name for user in db_users}
return {
'data': [{
'name': access[0].name,
'user_name': user_dict.get(access[0].user_id),
'user_id': access[0].user_id,
'update_time': access[0].update_time,
'id': access[0].id
} for access in db_role_access],
'total':
total_count
}
- 功能:根据用户角色获取有访问权限的助手列表,并支持分页和名称过滤。
- 参数:
role_id
:用户的角色 ID。name
:助手名称的搜索关键字。page_num
:分页的页码。page_size
:每页显示的记录数。
- 逻辑:
- 构建过滤条件:
- 如果提供了
name
,则添加一个 SQL LIKE 过滤条件,用于模糊匹配助手名称。
- 如果提供了
- 查询有访问权限的助手:
- 调用
AssistantDao.get_assistants_by_access
,传入role_id
、name
、page_size
和page_num
,获取用户有权限访问的助手列表。
- 调用
- 查询总数:
- 调用
AssistantDao.get_count_by_filters
,传入过滤条件,获取满足条件的助手总数。
- 调用
- 补充用户名:
- 从助手访问列表中提取所有的
user_id
。 - 调用
UserDao.get_user_by_ids
获取对应的用户信息。 - 构建一个字典
user_dict
,将user_id
映射到user_name
。
- 从助手访问列表中提取所有的
- 构建响应数据:
- 遍历
db_role_access
,为每个助手构建一个包含name
、user_name
、user_id
、update_time
和id
的字典。 - 返回包含
data
和total
的字典。
- 遍历
- 构建过滤条件:
- 应用场景:
- 用于前端页面展示用户有权限访问的助手列表,并支持搜索和分页功能。
5. 获取当前登录用户的依赖项
函数 get_login_user
async def get_login_user(authorize: AuthJWT = Depends()) -> UserPayload:
"""
获取当前登录的用户
"""
# 校验是否过期,过期则直接返回http 状态码的 401
authorize.jwt_required()
current_user = json.loads(authorize.get_jwt_subject())
user = UserPayload(**current_user)
# 判断是否允许多点登录
if not settings.get_system_login_method().allow_multi_login:
# 获取access_token
current_token = redis_client.get(USER_CURRENT_SESSION.format(user.user_id))
# 登录被挤下线了,http状态码是200, status_code是特殊code
if current_token != authorize._token:
raise UserLoginOfflineError.http_exception()
return user
- 功能:获取当前登录的用户信息,并进行相关的登录状态检查。
- 参数:
authorize
:AuthJWT
实例,通过 FastAPI 的Depends
注入。
- 逻辑:
- JWT 验证:
- 调用
authorize.jwt_required()
,确保请求包含有效的 JWT 令牌。 - 如果 JWT 过期或无效,自动抛出
401 Unauthorized
异常。
- 调用
- 获取 JWT 载荷:
- 调用
authorize.get_jwt_subject()
获取 JWT 的主题部分(通常是用户信息的 JSON 字符串)。 - 使用
json.loads
将 JSON 字符串转换为 Python 字典。 - 使用
UserPayload
类初始化用户信息对象。
- 调用
- 检查多点登录:
- 调用
settings.get_system_login_method().allow_multi_login
判断系统是否允许多点登录。 - 如果不允许多点登录:
- 从 Redis 获取当前用户的
access_token
,键格式为USER_CURRENT_SESSION.format(user.user_id)
。 - 比较 Redis 中存储的
current_token
与当前请求中的_token
。 - 如果两者不匹配,说明用户在其他地方登录,抛出
UserLoginOfflineError
异常。
- 从 Redis 获取当前用户的
- 调用
- 返回用户信息:
- 返回
UserPayload
实例,包含用户的 ID、角色和用户名。
- 返回
- JWT 验证:
- 应用场景:
- 用于保护需要登录才能访问的 API 路由,确保只有认证用户可以访问。
- 结合 FastAPI 的依赖注入机制,简化路由函数中的用户信息获取。
函数 get_admin_user
async def get_admin_user(authorize: AuthJWT = Depends()) -> UserPayload:
"""
获取超级管理账号,非超级管理员用户,抛出异常
"""
login_user = await get_login_user(authorize)
if not login_user.is_admin():
raise UnAuthorizedError.http_exception()
return login_user
- 功能:获取当前登录的超级管理员用户,如果用户不是管理员,则抛出未授权异常。
- 参数:
authorize
:AuthJWT
实例,通过 FastAPI 的Depends
注入。
- 逻辑:
- 获取当前登录用户:
- 调用
get_login_user
函数,获取当前登录用户的UserPayload
实例。
- 调用
- 检查管理员权限:
- 调用
login_user.is_admin()
判断用户是否是管理员。 - 如果用户不是管理员,抛出
UnAuthorizedError
异常。
- 调用
- 返回管理员用户信息:
- 如果用户是管理员,返回
UserPayload
实例。
- 如果用户是管理员,返回
- 获取当前登录用户:
- 应用场景:
- 用于保护需要管理员权限才能访问的 API 路由,确保只有管理员用户可以访问。
- 结合 FastAPI 的依赖注入机制,简化路由函数中的管理员用户信息获取和权限校验。
6. 辅助函数 gen_user_role
、gen_user_jwt
和 get_*_list_by_access
函数 gen_user_role
def gen_user_role(db_user: User):
# 查询用户的角色列表
db_user_role = UserRoleDao.get_user_roles(db_user.user_id)
role = ''
role_ids = []
for user_role in db_user_role:
if user_role.role_id == 1:
# 是管理员,忽略其他的角色
role = 'admin'
else:
role_ids.append(user_role.role_id)
if role != 'admin':
# 判断是否是用户组管理员
db_user_groups = UserGroupDao.get_user_admin_group(db_user.user_id)
if len(db_user_groups) > 0:
role = 'group_admin'
else:
role = role_ids
# 获取用户的菜单栏权限列表
web_menu = RoleAccessDao.get_role_access(role_ids, AccessType.WEB_MENU)
web_menu = list(set([one.third_id for one in web_menu]))
return role, web_menu
- 功能:生成用户的角色和菜单权限。
- 参数:
db_user
:数据库中的User
实例。
- 逻辑:
- 查询用户角色:
- 调用
UserRoleDao.get_user_roles
获取用户的角色列表。
- 调用
- 确定用户角色:
- 遍历用户的角色列表:
- 如果用户拥有角色 ID 为
1
,则将角色设置为'admin'
,并忽略其他角色。 - 否则,将角色 ID 添加到
role_ids
列表中。
- 如果用户拥有角色 ID 为
- 遍历用户的角色列表:
- 检查用户组管理员身份:
- 如果用户不是管理员,调用
UserGroupDao.get_user_admin_group
检查用户是否是某个用户组的管理员。 - 如果是用户组管理员,设置角色为
'group_admin'
。 - 否则,角色保持为角色 ID 的列表。
- 如果用户不是管理员,调用
- 获取菜单权限:
- 调用
RoleAccessDao.get_role_access
,传入role_ids
和AccessType.WEB_MENU
,获取用户的菜单权限。 - 提取并去重
third_id
,形成web_menu
列表。
- 调用
- 返回角色和菜单权限:
- 返回用户的角色(字符串或列表)和菜单权限列表。
- 查询用户角色:
- 应用场景:
- 在用户登录后,生成用户的角色和权限信息,用于前端展示和权限校验。
函数 gen_user_jwt
def gen_user_jwt(db_user: User):
if 1 == db_user.delete:
raise HTTPException(status_code=500, detail='该账号已被禁用,请联系管理员')
# 查询角色
role, web_menu = gen_user_role(db_user)
# 生成JWT令牌
payload = {'user_name': db_user.user_name, 'user_id': db_user.user_id, 'role': role}
# Create the tokens and passing to set_access_cookies or set_refresh_cookies
access_token = AuthJWT().create_access_token(subject=json.dumps(payload),
expires_time=ACCESS_TOKEN_EXPIRE_TIME)
refresh_token = AuthJWT().create_refresh_token(subject=db_user.user_name)
# Set the JWT cookies in the response
return access_token, refresh_token, role, web_menu
- 功能:为用户生成 JWT 访问令牌和刷新令牌,并返回角色和菜单权限信息。
- 参数:
db_user
:数据库中的User
实例。
- 逻辑:
- 检查用户状态:
- 如果
db_user.delete
等于1
,表示账号被禁用,抛出HTTPException
,状态码为500
,提示用户联系管理员。
- 如果
- 获取用户角色和菜单权限:
- 调用
gen_user_role
函数,获取用户的角色和菜单权限。
- 调用
- 生成 JWT 载荷:
- 创建一个包含
user_name
、user_id
和role
的字典payload
。
- 创建一个包含
- 生成访问令牌和刷新令牌:
- 调用
AuthJWT().create_access_token
生成访问令牌,主题为payload
的 JSON 字符串,过期时间为ACCESS_TOKEN_EXPIRE_TIME
。 - 调用
AuthJWT().create_refresh_token
生成刷新令牌,主题为用户的用户名。
- 调用
- 返回令牌和权限信息:
- 返回访问令牌、刷新令牌、角色和菜单权限。
- 检查用户状态:
- 应用场景:
- 在用户成功登录后,生成并返回 JWT 令牌,用于后续的认证和授权。
函数 get_knowledge_list_by_access
def get_knowledge_list_by_access(role_id: int, name: str, page_num: int, page_size: int):
count_filter = []
if name:
count_filter.append(Knowledge.name.like('%{}%'.format(name)))
db_role_access = KnowledgeDao.get_knowledge_by_access(role_id, page_num, page_size)
total_count = KnowledgeDao.get_count_by_filter(count_filter)
# 补充用户名
user_ids = [access[0].user_id for access in db_role_access]
db_users = UserDao.get_user_by_ids(user_ids)
user_dict = {user.user_id: user.user_name for user in db_users}
return {
'data': [
KnowledgeRead.validate({
'name': access[0].name,
'user_name': user_dict.get(access[0].user_id),
'user_id': access[0].user_id,
'update_time': access[0].update_time,
'id': access[0].id
}) for access in db_role_access
],
'total':
total_count
}
- 功能:根据用户角色获取有访问权限的知识库列表,并支持分页和名称过滤。
- 参数:
role_id
:用户的角色 ID。name
:知识库名称的搜索关键字。page_num
:分页的页码。page_size
:每页显示的记录数。
- 逻辑:
- 构建过滤条件:
- 如果提供了
name
,则添加一个 SQL LIKE 过滤条件,用于模糊匹配知识库名称。
- 如果提供了
- 查询有访问权限的知识库:
- 调用
KnowledgeDao.get_knowledge_by_access
,传入role_id
、page_num
和page_size
,获取用户有权限访问的知识库列表。
- 调用
- 查询总数:
- 调用
KnowledgeDao.get_count_by_filter
,传入过滤条件,获取满足条件的知识库总数。
- 调用
- 补充用户名:
- 从知识库访问列表中提取所有的
user_id
。 - 调用
UserDao.get_user_by_ids
获取对应的用户信息。 - 构建一个字典
user_dict
,将user_id
映射到user_name
。
- 从知识库访问列表中提取所有的
- 构建响应数据:
- 遍历
db_role_access
,为每个知识库构建一个包含name
、user_name
、user_id
、update_time
和id
的字典,并通过KnowledgeRead.validate
进行数据验证和序列化。 - 返回包含
data
和total
的字典。
- 遍历
- 构建过滤条件:
- 应用场景:
- 用于前端页面展示用户有权限访问的知识库列表,并支持搜索和分页功能。
函数 get_flow_list_by_access
def get_flow_list_by_access(role_id: int, name: str, page_num: int, page_size: int):
count_filter = []
if name:
count_filter.append(Flow.name.like('%{}%'.format(name)))
db_role_access = FlowDao.get_flow_by_access(role_id, name, page_num, page_size)
total_count = FlowDao.get_count_by_filters(count_filter)
# 补充用户名
user_ids = [access[0].user_id for access in db_role_access]
db_users = UserDao.get_user_by_ids(user_ids)
user_dict = {user.user_id: user.user_name for user in db_users}
return {
'data': [
FlowRead.validate({
'name': access[0].name,
'user_name': user_dict.get(access[0].user_id),
'user_id': access[0].user_id,
'update_time': access[0].update_time,
'id': access[0].id
}) for access in db_role_access
],
'total':
total_count
}
- 功能:根据用户角色获取有访问权限的流程列表,并支持分页和名称过滤。
- 参数:
role_id
:用户的角色 ID。name
:流程名称的搜索关键字。page_num
:分页的页码。page_size
:每页显示的记录数。
- 逻辑:
- 构建过滤条件:
- 如果提供了
name
,则添加一个 SQL LIKE 过滤条件,用于模糊匹配流程名称。
- 如果提供了
- 查询有访问权限的流程:
- 调用
FlowDao.get_flow_by_access
,传入role_id
、name
、page_num
和page_size
,获取用户有权限访问的流程列表。
- 调用
- 查询总数:
- 调用
FlowDao.get_count_by_filters
,传入过滤条件,获取满足条件的流程总数。
- 调用
- 补充用户名:
- 从流程访问列表中提取所有的
user_id
。 - 调用
UserDao.get_user_by_ids
获取对应的用户信息。 - 构建一个字典
user_dict
,将user_id
映射到user_name
。
- 从流程访问列表中提取所有的
- 构建响应数据:
- 遍历
db_role_access
,为每个流程构建一个包含name
、user_name
、user_id
、update_time
和id
的字典,并通过FlowRead.validate
进行数据验证和序列化。 - 返回包含
data
和total
的字典。
- 遍历
- 构建过滤条件:
- 应用场景:
- 用于前端页面展示用户有权限访问的流程列表,并支持搜索和分页功能。
函数 get_assistant_list_by_access
def get_assistant_list_by_access(role_id: int, name: str, page_num: int, page_size: int):
count_filter = []
if name:
count_filter.append(Assistant.name.like('%{}%'.format(name)))
db_role_access = AssistantDao.get_assistants_by_access(role_id, name, page_size, page_num)
total_count = AssistantDao.get_count_by_filters(count_filter)
# 补充用户名
user_ids = [access[0].user_id for access in db_role_access]
db_users = UserDao.get_user_by_ids(user_ids)
user_dict = {user.user_id: user.user_name for user in db_users}
return {
'data': [{
'name': access[0].name,
'user_name': user_dict.get(access[0].user_id),
'user_id': access[0].user_id,
'update_time': access[0].update_time,
'id': access[0].id
} for access in db_role_access],
'total':
total_count
}
- 功能:根据用户角色获取有访问权限的助手列表,并支持分页和名称过滤。
- 参数:
role_id
:用户的角色 ID。name
:助手名称的搜索关键字。page_num
:分页的页码。page_size
:每页显示的记录数。
- 逻辑:
- 构建过滤条件:
- 如果提供了
name
,则添加一个 SQL LIKE 过滤条件,用于模糊匹配助手名称。
- 如果提供了
- 查询有访问权限的助手:
- 调用
AssistantDao.get_assistants_by_access
,传入role_id
、name
、page_size
和page_num
,获取用户有权限访问的助手列表。
- 调用
- 查询总数:
- 调用
AssistantDao.get_count_by_filters
,传入过滤条件,获取满足条件的助手总数。
- 调用
- 补充用户名:
- 从助手访问列表中提取所有的
user_id
。 - 调用
UserDao.get_user_by_ids
获取对应的用户信息。 - 构建一个字典
user_dict
,将user_id
映射到user_name
。
- 从助手访问列表中提取所有的
- 构建响应数据:
- 遍历
db_role_access
,为每个助手构建一个包含name
、user_name
、user_id
、update_time
和id
的字典。 - 返回包含
data
和total
的字典。
- 遍历
- 构建过滤条件:
- 应用场景:
- 用于前端页面展示用户有权限访问的助手列表,并支持搜索和分页功能。
7. 认证和权限获取
函数 get_login_user
async def get_login_user(authorize: AuthJWT = Depends()) -> UserPayload:
"""
获取当前登录的用户
"""
# 校验是否过期,过期则直接返回http 状态码的 401
authorize.jwt_required()
current_user = json.loads(authorize.get_jwt_subject())
user = UserPayload(**current_user)
# 判断是否允许多点登录
if not settings.get_system_login_method().allow_multi_login:
# 获取access_token
current_token = redis_client.get(USER_CURRENT_SESSION.format(user.user_id))
# 登录被挤下线了,http状态码是200, status_code是特殊code
if current_token != authorize._token:
raise UserLoginOfflineError.http_exception()
return user
- 功能:获取当前登录的用户信息,并进行多点登录状态检查。
- 参数:
authorize
:AuthJWT
实例,通过 FastAPI 的Depends
注入。
- 逻辑:
- JWT 验证:
- 调用
authorize.jwt_required()
,确保请求包含有效的 JWT 令牌。 - 如果 JWT 过期或无效,自动抛出
401 Unauthorized
异常。
- 调用
- 获取 JWT 载荷:
- 调用
authorize.get_jwt_subject()
获取 JWT 的主题部分(通常是用户信息的 JSON 字符串)。 - 使用
json.loads
将 JSON 字符串转换为 Python 字典。 - 使用
UserPayload
类初始化用户信息对象。
- 调用
- 检查多点登录:
- 调用
settings.get_system_login_method().allow_multi_login
判断系统是否允许多点登录。 - 如果不允许多点登录:
- 从 Redis 获取当前用户的
access_token
,键格式为USER_CURRENT_SESSION.format(user.user_id)
。 - 比较 Redis 中存储的
current_token
与当前请求中的_token
。 - 如果两者不匹配,说明用户在其他地方登录,抛出
UserLoginOfflineError
异常。
- 从 Redis 获取当前用户的
- 调用
- 返回用户信息:
- 返回
UserPayload
实例,包含用户的 ID、角色和用户名。
- 返回
- JWT 验证:
- 应用场景:
- 用于保护需要登录才能访问的 API 路由,确保只有认证用户可以访问。
- 结合 FastAPI 的依赖注入机制,简化路由函数中的用户信息获取。
函数 get_admin_user
async def get_admin_user(authorize: AuthJWT = Depends()) -> UserPayload:
"""
获取超级管理账号,非超级管理员用户,抛出异常
"""
login_user = await get_login_user(authorize)
if not login_user.is_admin():
raise UnAuthorizedError.http_exception()
return login_user
- 功能:获取当前登录的超级管理员用户,如果用户不是管理员,则抛出未授权异常。
- 参数:
authorize
:AuthJWT
实例,通过 FastAPI 的Depends
注入。
- 逻辑:
- 获取当前登录用户:
- 调用
get_login_user
函数,获取当前登录用户的UserPayload
实例。
- 调用
- 检查管理员权限:
- 调用
login_user.is_admin()
判断用户是否是管理员。 - 如果用户不是管理员,抛出
UnAuthorizedError
异常。
- 调用
- 返回管理员用户信息:
- 如果用户是管理员,返回
UserPayload
实例。
- 如果用户是管理员,返回
- 获取当前登录用户:
- 应用场景:
- 用于保护需要管理员权限才能访问的 API 路由,确保只有管理员用户可以访问。
- 结合 FastAPI 的依赖注入机制,简化路由函数中的管理员用户信息获取和权限校验。
8. 其他辅助函数
函数 gen_user_role
此函数已在前面详细讲解,此处不再重复。
函数 gen_user_jwt
此函数已在前面详细讲解,此处不再重复。
函数 get_knowledge_list_by_access
此函数已在前面详细讲解,此处不再重复。
函数 get_flow_list_by_access
此函数已在前面详细讲解,此处不再重复。
函数 get_assistant_list_by_access
此函数已在前面详细讲解,此处不再重复。
9. 总结
这段代码主要涵盖了以下几个关键方面:
- 用户身份验证:
- 使用 JWT 进行用户认证,确保每个请求都携带有效的令牌。
- 提供了
get_login_user
和get_admin_user
函数,用于获取当前登录用户的信息,并进行权限校验。
- 权限管理:
UserPayload
类封装了用户的角色和权限信息,提供了权限检查的方法。- 通过装饰器
wrapper_access_check
简化权限检查逻辑,提高代码复用性。
- 用户管理:
UserService
类提供了用户创建和密码解密的功能。- 处理用户注册时的逻辑,如检查用户名是否存在、分配用户组和角色等。
- 与数据库交互:
- 使用 DAO(数据访问对象)模式,通过
UserDao
、FlowDao
、KnowledgeDao
等类与数据库进行交互。 - 通过
UserRoleDao
、RoleAccessDao
等类管理用户角色和权限。
- 使用 DAO(数据访问对象)模式,通过
- 缓存和配置管理:
- 使用 Redis 作为缓存,存储和获取用户会话信息和 RSA 密钥。
- 通过
settings
对象管理项目的配置参数,如数据库连接字符串、是否允许多点登录等。
- 异常处理:
- 定义了自定义的异常类,用于统一错误响应和处理。
- 在关键业务逻辑中抛出自定义异常,确保错误信息的一致性和可读性。
- 实用功能:
- 提供了
generate_uuid
函数,用于生成唯一的标识符。 - 提供了
md5_hash
和 RSA 解密功能,增强密码的安全性。
- 提供了
如何在项目中使用
-
保护路由:
- 使用
Depends(get_login_user)
或Depends(get_admin_user)
作为路由的依赖项,确保只有经过认证的用户或管理员可以访问特定的路由。
from fastapi import APIRouter, Depends router = APIRouter() @router.get("/protected-route") async def protected_route(user: UserPayload = Depends(get_login_user)): return {"message": f"Hello, {user.user_name}!"} @router.get("/admin-route") async def admin_route(user: UserPayload = Depends(get_admin_user)): return {"message": f"Welcome, admin {user.user_name}!"}
- 使用
-
创建用户:
- 使用
UserService.create_user
方法处理用户创建逻辑。
@router.post("/create-user") async def create_user(request: Request, req_data: CreateUserReq, user: UserPayload = Depends(get_admin_user)): new_user = UserService.create_user(request, user, req_data) return {"user_id": new_user.user_id, "user_name": new_user.user_name}
- 使用
-
生成和返回 JWT:
- 在用户登录成功后,调用
gen_user_jwt
函数生成 JWT 令牌,并将其返回给前端。
@router.post("/login") async def login(credentials: LoginSchema): db_user = UserDao.authenticate(credentials.username, credentials.password) if not db_user: raise HTTPException(status_code=401, detail="Invalid credentials") access_token, refresh_token, role, web_menu = gen_user_jwt(db_user) return {"access_token": access_token, "refresh_token": refresh_token, "role": role, "web_menu": web_menu}
- 在用户登录成功后,调用
进一步阅读
- FastAPI 文档:
- FastAPI 官方文档 提供了详细的框架使用指南和最佳实践。
- Pydantic 文档:
- Pydantic 官方文档 详细介绍了数据模型和验证功能。
- JWT 认证:
- 学习 JWT(JSON Web Token) 的基本概念和使用方法,了解如何在 API 中实现安全的认证机制。
- Redis 缓存:
- Redis 官方文档 了解如何使用 Redis 进行缓存和会话管理。
- RSA 加密:
- 学习 RSA 加密算法 的基本原理和应用场景,了解如何在 Python 中使用
rsa
库进行加密和解密操作。
- 学习 RSA 加密算法 的基本原理和应用场景,了解如何在 Python 中使用
常见问题
- 为什么使用装饰器进行权限检查?
- 使用装饰器可以将权限检查逻辑与业务逻辑分离,提高代码的可读性和可维护性。通过装饰器,可以复用权限检查逻辑,避免在每个方法中重复编写相同的代码。
- 如何确保 JWT 的安全性?
- 确保 JWT 使用强密码进行签名,避免暴露私钥。
- 设置合理的过期时间,防止令牌被滥用。
- 使用 HTTPS 保护令牌在传输过程中的安全性。
- 为什么需要检查多点登录?
- 检查多点登录可以防止用户账号被多人同时使用,提升账号的安全性。
- 如果系统不允许多点登录,当用户在其他设备登录时,可以自动将之前的会话踢下线。
- 如何处理 RSA 密钥管理?
- 确保 RSA 密钥安全存储,不要将私钥暴露在代码库中。
- 使用环境变量或安全的密钥管理服务(如 AWS KMS)来存储和管理 RSA 密钥。
- 如何优化数据库查询性能?
- 确保数据库表中有适当的索引,尤其是在常用的查询字段上。
- 使用分页查询减少一次查询返回的数据量。
- 避免 N+1 查询问题,通过使用 JOIN 或预加载相关数据提高查询效率。