当前位置: 首页 > article >正文

毕昇入门学习

schemas.py

概述

这段代码主要定义了一系列基于 Pydantic 的数据模型(BaseModel),用于数据验证和序列化,通常用于构建 API(如使用 FastAPI)。这些模型涵盖了用户认证、聊天消息、知识库管理、模型配置等多个方面。此外,还定义了一些通用的响应模型和辅助函数,以标准化 API 的响应格式。

详细解析

导入部分

from datetime import datetime
from enum import Enum
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
from uuid import UUID

from langchain.docstore.document import Document
from orjson import orjson
from pydantic import BaseModel, Field, validator, root_validator

from bisheng.database.models.assistant import AssistantBase
from bisheng.database.models.finetune import TrainMethod
from bisheng.database.models.flow import FlowCreate, FlowRead
from bisheng.database.models.gpts_tools import GptsToolsRead, AuthMethod, AuthType
from bisheng.database.models.knowledge import KnowledgeRead
from bisheng.database.models.llm_server import LLMServerBase, LLMModelBase, LLMServerType, LLMModelType
from bisheng.database.models.message import ChatMessageRead
from bisheng.database.models.tag import Tag
  • 标准库导入
    • datetime:处理日期和时间。
    • Enum:定义枚举类型。
    • typing:用于类型注解,增强代码可读性和可维护性。
    • UUID:处理唯一标识符。
  • 第三方库导入
    • langchain.docstore.document.Document:可能用于文档存储和检索。
    • orjson:高性能的 JSON 库,用于快速序列化和反序列化 JSON 数据。
    • pydantic:用于数据验证和设置管理,广泛用于 FastAPI。
  • 项目内部模块导入
    • bisheng.database.models.* 导入多个 ORM 模型,代表数据库中的不同实体(如助手、微调方法、流程、工具、知识库、LLM 服务器和模型、聊天消息、标签等)。

数据模型

以下是代码中定义的主要数据模型和相关函数的详细解释。

基础输入模型
  1. 验证码输入

    class CaptchaInput(BaseModel):
        captcha_key: str
        captcha: str
    
    • 用途:用于用户提交验证码时的输入验证。
    • 字段:
      • captcha_key:验证码的唯一标识符。
      • captcha:用户输入的验证码内容。
  2. 分块输入

    class ChunkInput(BaseModel):
        knowledge_id: int
        documents: List[Document]
    
    • 用途:用于将文档分块处理的输入。
    • 字段:
      • knowledge_id:知识库的唯一标识符。
      • documents:文档列表,每个文档为 langchain.docstore.document.Document 类型。
枚举类型
  1. 构建状态

    class BuildStatus(Enum):
        """Status of the build."""
    
        SUCCESS = 'success'
        FAILURE = 'failure'
        STARTED = 'started'
        IN_PROGRESS = 'in_progress'
    
    • 用途:表示构建过程的状态。
    • 枚举值:
      • SUCCESS:构建成功。
      • FAILURE:构建失败。
      • STARTED:构建已启动。
      • IN_PROGRESS:构建进行中。
图数据模型
  1. 图数据

    class GraphData(BaseModel):
        """Data inside the exported flow."""
    
        nodes: List[Dict[str, Any]]
        edges: List[Dict[str, Any]]
    
    • 用途:表示导出流程中的图数据。
    • 字段:
      • nodes:节点列表,每个节点为字典类型,包含节点的详细信息。
      • edges:边列表,每条边为字典类型,描述节点之间的连接关系。
  2. 导出流程

    class ExportedFlow(BaseModel):
        """Exported flow from bisheng."""
    
        description: str
        name: str
        id: str
        data: GraphData
    
    • 用途:表示从 bisheng 导出的流程数据。
    • 字段:
      • description:流程描述。
      • name:流程名称。
      • id:流程的唯一标识符。
      • data:流程的图数据,类型为 GraphData
请求模型
  1. 输入请求

    class InputRequest(BaseModel):
        input: str = Field(description='question or command asked LLM to do')
    
    • 用途:表示用户向 LLM(大型语言模型)提出的问题或命令。
    • 字段:
      • input:用户的输入内容(问题或命令)。
  2. 调整请求

    class TweaksRequest(BaseModel):
        tweaks: Optional[Dict[str, Dict[str, str]]] = Field(default_factory=dict)
    
    • 用途:表示对某些设置或参数的调整请求。
    • 字段:
      • tweaks:一个可选的字典,嵌套字典用于描述具体的调整内容。
  3. 更新模板请求

    class UpdateTemplateRequest(BaseModel):
        template: dict
    
    • 用途:用于更新模板的请求。
    • 字段:
      • template:模板内容,类型为字典。
通用响应模型
  1. 统一响应模型

    DataT = TypeVar('DataT')
    
    class UnifiedResponseModel(Generic[DataT], BaseModel):
        """统一响应模型"""
        status_code: int
        status_message: str
        data: DataT = None
    
    • 用途:提供一个通用的 API 响应结构,适用于各种类型的数据。
    • 泛型:
      • DataT:数据的具体类型,可以是任意类型。
    • 字段:
      • status_code:响应状态码(如 200、500 等)。
      • status_message:响应状态信息(如 “SUCCESS”、“BAD REQUEST” 等)。
      • data:响应的数据,类型为泛型 DataT
  2. 成功响应函数

    def resp_200(data: Union[list, dict, str, Any] = None,
                message: str = 'SUCCESS') -> UnifiedResponseModel:
        """成功的代码"""
        return UnifiedResponseModel(status_code=200, status_message=message, data=data)
    
    • 用途:生成一个状态码为 200 的成功响应。
    • 参数:
      • data:响应的数据,可以是列表、字典、字符串或任意类型。
      • message:响应的状态信息,默认值为 “SUCCESS”。
  3. 错误响应函数

    def resp_500(code: int = 500,
                data: Union[list, dict, str, Any] = None,
                message: str = 'BAD REQUEST') -> UnifiedResponseModel:
        """错误的逻辑回复"""
        return UnifiedResponseModel(status_code=code, status_message=message, data=data)
    
    • 用途:生成一个错误响应,默认状态码为 500。
    • 参数:
      • code:错误状态码,默认值为 500。
      • data:响应的数据,可以是列表、字典、字符串或任意类型。
      • message:响应的状态信息,默认值为 “BAD REQUEST”。
处理响应模型
  1. 处理响应

    class ProcessResponse(BaseModel):
        """Process response schema."""
    
        result: Any
        # task: Optional[TaskResponse] = None
        session_id: Optional[str] = None
        backend: Optional[str] = None
    
    • 用途:表示处理请求后的响应结构。
    • 字段:
      • result:处理结果,类型为任意类型。
      • session_id:会话标识符,可选。
      • backend:后端标识,可选。
聊天相关模型
  1. 聊天输入

    class ChatInput(BaseModel):
        message_id: int
        comment: str = None
        liked: int = 0
    
    • 用途:表示用户对某条聊天消息的输入(如评论、点赞)。
    • 字段:
      • message_id:消息的唯一标识符。
      • comment:用户的评论,可选。
      • liked:点赞数,默认值为 0。
  2. 添加聊天消息

    class AddChatMessages(BaseModel):
        """Add a pair of chat messages."""
    
        flow_id: UUID  # 技能或助手ID
        chat_id: str  # 会话ID
        human_message: str = None  # 用户问题
        answer_message: str = None  # 执行结果
    
    • 用途:用于向聊天会话中添加一对消息(用户消息和助手回复)。

    • 字段

      • flow_id:技能或助手的唯一标识符,类型为 UUID
      • chat_id:会话的唯一标识符。
      • human_message:用户的问题,可选。
      • answer_message:助手的回复,可选。
  3. 聊天列表

    class ChatList(BaseModel):
        """Chat message list."""
    
        flow_name: str = None
        flow_description: str = None
        flow_id: UUID = None
        chat_id: str = None
        create_time: datetime = None
        update_time: datetime = None
        flow_type: str = None  # flow: 技能 assistant:gpts助手
        latest_message: ChatMessageRead = None
        logo: Optional[str] = None
    
    • 用途:表示一个聊天会话的简要信息,用于展示聊天列表。
    • 字段:
      • flow_name:技能或助手的名称。
      • flow_description:技能或助手的描述。
      • flow_id:技能或助手的唯一标识符,类型为 UUID
      • chat_id:会话的唯一标识符。
      • create_time:会话的创建时间。
      • update_time:会话的更新时间。
      • flow_type:技能类型,flow 表示技能,assistant 表示 GPTs 助手。
      • latest_message:最新的聊天消息,类型为 ChatMessageRead
      • logo:技能或助手的 logo 地址,可选。
  4. 在线流程列表

    class FlowGptsOnlineList(BaseModel):
        id: str = Field('唯一ID')
        name: str = None
        desc: str = None
        logo: str = None
        create_time: datetime = None
        update_time: datetime = None
        flow_type: str = None  # flow: 技能 assistant:gpts助手
        count: int = 0
    
    • 用途:表示在线流程(技能或助手)的列表信息。
    • 字段:
      • id:流程的唯一标识符。
      • name:流程名称。
      • desc:流程描述。
      • logo:流程的 logo 地址。
      • create_time:流程的创建时间。
      • update_time:流程的更新时间。
      • flow_type:流程类型,flow 表示技能,assistant 表示 GPTs 助手。
      • count:关联的某些数量(具体含义需结合业务逻辑理解)。
  5. 聊天消息

    class ChatMessage(BaseModel):
        """Chat message schema."""
    
        is_bot: bool = False
        message: Union[str, None, dict] = ''
        type: str = 'human'
        category: str = 'processing'  # system processing answer tool
        intermediate_steps: str = None
        files: list = []
        user_id: int = None
        message_id: int = None
        source: int = 0
        sender: str = None
        receiver: dict = None
        liked: int = 0
        extra: str = '{}'
        flow_id: str = None
        chat_id: str = None
    
    • 用途:表示一条聊天消息的详细信息。
    • 字段:
      • is_bot:是否为机器人消息,默认值为 False
      • message:消息内容,可以是字符串、字典或 None
      • type:消息类型,默认值为 'human'(表示用户消息)。
      • category:消息分类,如 'system''processing''answer''tool'
      • intermediate_steps:中间步骤信息,可选。
      • files:关联的文件列表,默认空列表。
      • user_id:用户的唯一标识符,可选。
      • message_id:消息的唯一标识符,可选。
      • source:消息来源,默认值为 0
      • sender:发送者信息,可选。
      • receiver:接收者信息,字典类型,可选。
      • liked:点赞数,默认值为 0
      • extra:额外信息,默认值为 '{}'
      • flow_id:关联的流程(技能或助手)的唯一标识符,可选。
      • chat_id:关联的聊天会话的唯一标识符,可选。
  6. 聊天响应

    class ChatResponse(ChatMessage):
        """Chat response schema."""
    
        intermediate_steps: str = ''
        type: str
        is_bot: bool = True
        files: list = []
        category: str = 'processing'
    
        @validator('type')
        def validate_message_type(cls, v):
            """
            end_cover: 结束并覆盖上一条message
            """
            if v not in ['start', 'stream', 'end', 'error', 'info', 'file', 'begin', 'close', 'end_cover']:
                raise ValueError('type must be start, stream, end, error, info, or file')
            return v
    
    • 用途:表示机器人(助手)的聊天响应消息。
    • 继承自ChatMessage,并对部分字段进行了重写或扩展。
    • 字段
      • intermediate_steps:中间步骤信息,默认值为空字符串。
      • type:消息类型,必须在指定的枚举值中。
      • is_bot:是否为机器人消息,固定为 True
      • files:关联的文件列表,默认空列表。
      • category:消息分类,固定为 'processing'
    • 验证器
      • validate_message_type:确保 type 字段的值在指定的枚举值范围内,否则抛出错误。
  7. 文件响应

    class FileResponse(ChatMessage):
        """File response schema."""
    
        data: Any
        data_type: str
        type: str = 'file'
        is_bot: bool = True
    
        @validator('data_type')
        def validate_data_type(cls, v):
            if v not in ['image', 'csv']:
                raise ValueError('data_type must be image or csv')
            return v
    
    • 用途:表示机器人发送的文件类型消息。
    • 继承自ChatMessage,并对部分字段进行了重写或扩展。
    • 字段
      • data:文件数据,类型为任意类型。
      • data_type:文件类型,必须是 'image''csv'
      • type:消息类型,固定为 'file'
      • is_bot:是否为机器人消息,固定为 True
    • 验证器
      • validate_data_type:确保 data_type 字段的值为 'image''csv',否则抛出错误。
流程相关模型
  1. 流程列表创建

    class FlowListCreate(BaseModel):
        flows: List[FlowCreate]
    
    • 用途:用于批量创建流程(技能或助手)。
    • 字段:
      • flows:流程创建请求的列表,类型为 List[FlowCreate]
  2. 流程列表读取

    class FlowListRead(BaseModel):
        flows: List[FlowRead]
    
    • 用途:用于批量读取流程的信息。
    • 字段:
      • flows:流程读取响应的列表,类型为 List[FlowRead]
  3. 初始化响应

    class InitResponse(BaseModel):
        flowId: str
    
    • 用途:表示初始化操作后的响应,通常返回一个流程的唯一标识符。
    • 字段:
      • flowId:流程的唯一标识符。
  4. 构建响应

    class BuiltResponse(BaseModel):
        built: bool
    
    • 用途:表示构建操作的结果。
    • 字段:
      • built:布尔值,表示是否构建成功。
  5. 上传文件响应

    class UploadFileResponse(BaseModel):
        """Upload file response schema."""
    
        flowId: Optional[str]
        file_path: str
        relative_path: Optional[str]  # minio的相对路径,即object_name
    
    • 用途:表示文件上传后的响应。
    • 字段:
      • flowId:关联的流程(技能或助手)的唯一标识符,可选。
      • file_path:文件的完整路径。
      • relative_path:文件在存储系统(如 MinIO)中的相对路径,即对象名称,可选。
  6. 流数据

    class StreamData(BaseModel):
        event: str
        data: dict | str
    
        def __str__(self) -> str:
            if isinstance(self.data, dict):
                return f'event: {self.event}\ndata: {orjson.dumps(self.data).decode()}\n\n'
            return f'event: {self.event}\ndata: {self.data}\n\n'
    
    • 用途:表示流式数据传输的结构,常用于服务器发送事件(SSE)。
    • 字段
      • event:事件类型。
      • data:事件数据,可以是字典或字符串。
    • 方法
      • __str__:将 StreamData 对象转换为字符串格式,适用于流式传输。
微调相关模型
  1. 微调创建请求

    class FinetuneCreateReq(BaseModel):
        server: int = Field(description='关联的RT服务ID')
        base_model: int = Field(description='基础模型ID')
        model_name: str = Field(max_length=50, description='模型名称')
        method: TrainMethod = Field(description='训练方法')
        extra_params: Dict = Field(default_factory=dict, description='训练任务所需额外参数')
        train_data: Optional[List[Dict]] = Field(default=None, description='个人训练数据')
        preset_data: Optional[List[Dict]] = Field(default=None, description='预设训练数据')
    
    • 用途:表示创建微调任务的请求。
    • 字段:
      • server:关联的 RT 服务的唯一标识符。
      • base_model:基础模型的唯一标识符。
      • model_name:模型名称,最大长度为 50。
      • method:训练方法,类型为 TrainMethod(枚举类型)。
      • extra_params:训练任务所需的额外参数,默认为空字典。
      • train_data:个人训练数据,可选。
      • preset_data:预设训练数据,可选。
组件相关模型
  1. 创建组件请求

    class CreateComponentReq(BaseModel):
        name: str = Field(max_length=50, description='组件名称')
        data: Any = Field(default='', description='组件数据')
        description: Optional[str] = Field(default='', description='组件描述')
    
    • 用途:表示创建组件的请求。
    • 字段:
      • name:组件名称,最大长度为 50。
      • data:组件的数据,类型为任意类型,默认值为空字符串。
      • description:组件描述,可选,默认值为空字符串。
  2. 自定义组件代码

    class CustomComponentCode(BaseModel):
        code: str
        field: Optional[str] = None
        frontend_node: Optional[dict] = None
    
    • 用途:表示自定义组件的代码。
    • 字段:
      • code:组件的代码。
      • field:可选字段,可能用于描述组件的特定字段。
      • frontend_node:前端节点的信息,类型为字典,可选。
助手相关模型
  1. 创建助手请求

    class AssistantCreateReq(BaseModel):
        name: str = Field(max_length=50, description='助手名称')
        prompt: str = Field(min_length=20, max_length=1000, description='助手提示词')
        logo: str = Field(description='logo文件的相对地址')
    
    • 用途:表示创建助手的请求。
    • 字段:
      • name:助手名称,最大长度为 50。
      • prompt:助手的提示词,最小长度为 20,最大长度为 1000。
      • logo:logo 文件的相对地址。
  2. 更新助手请求

    class AssistantUpdateReq(BaseModel):
        id: UUID = Field(description='助手ID')
        name: Optional[str] = Field('', description='助手名称, 为空则不更新')
        desc: Optional[str] = Field('', description='助手描述, 为空则不更新')
        logo: Optional[str] = Field('', description='logo文件的相对地址,为空则不更新')
        prompt: Optional[str] = Field('', description='用户可见prompt, 为空则不更新')
        guide_word: Optional[str] = Field('', description='开场白, 为空则不更新')
        guide_question: Optional[List] = Field([], description='引导问题列表, 为空则不更新')
        model_name: Optional[str] = Field('', description='选择的模型名, 为空则不更新')
        temperature: Optional[float] = Field(None, description='模型温度, 不传则不更新')
    
        tool_list: List[int] | None = Field(default=None,
                                            description='助手的工具ID列表,空列表则清空绑定的工具,为None则不更新')
        flow_list: List[str] | None = Field(default=None, description='助手的技能ID列表,为None则不更新')
        knowledge_list: List[int] | None = Field(default=None, description='知识库ID列表,为None则不更新')
    
    • 用途:表示更新助手的请求。
    • 字段:
      • id:助手的唯一标识符,类型为 UUID
      • name:助手名称,可选,默认为空字符串,空则不更新。
      • desc:助手描述,可选,默认为空字符串,空则不更新。
      • logo:logo 文件的相对地址,可选,默认为空字符串,空则不更新。
      • prompt:用户可见的提示词,可选,默认为空字符串,空则不更新。
      • guide_word:开场白,可选,默认为空字符串,空则不更新。
      • guide_question:引导问题列表,可选,默认为空列表,空则不更新。
      • model_name:选择的模型名称,可选,默认为空字符串,空则不更新。
      • temperature:模型温度,可选,默认为 None,不传则不更新。
      • tool_list:助手的工具 ID 列表,可选,None 表示不更新,空列表表示清空绑定的工具。
      • flow_list:助手的技能 ID 列表,可选,None 表示不更新。
      • knowledge_list:知识库 ID 列表,可选,None 表示不更新。
  3. 助手简单信息

    class AssistantSimpleInfo(BaseModel):
        id: UUID
        name: str
        desc: str
        logo: str
        user_id: int
        user_name: str
        status: int
        write: Optional[bool] = Field(default=False)
        group_ids: Optional[List[int]]
        tags: Optional[List[Tag]]
        create_time: datetime
        update_time: datetime
    
    • 用途:表示助手的基本信息,用于展示列表或概要信息。
    • 字段:
      • id:助手的唯一标识符,类型为 UUID
      • name:助手名称。
      • desc:助手描述。
      • logo:助手的 logo 地址。
      • user_id:创建者的用户 ID。
      • user_name:创建者的用户名。
      • status:助手的状态,类型为整数(具体含义需结合业务逻辑理解)。
      • write:是否可写,默认值为 False
      • group_ids:关联的用户组 ID 列表,可选。
      • tags:关联的标签列表,类型为 List[Tag],可选。
      • create_time:创建时间。
      • update_time:更新时间。
  4. 助手信息

    class AssistantInfo(AssistantBase):
        tool_list: List[GptsToolsRead] = Field(default=[], description='助手的工具ID列表')
        flow_list: List[FlowRead] = Field(default=[], description='助手的技能ID列表')
        knowledge_list: List[KnowledgeRead] = Field(default=[], description='知识库ID列表')
    
    • 用途:表示助手的详细信息,继承自 AssistantBase(数据库模型)。
    • 字段:
      • tool_list:助手的工具列表,类型为 List[GptsToolsRead]
      • flow_list:助手的技能列表,类型为 List[FlowRead]
      • knowledge_list:知识库列表,类型为 List[KnowledgeRead]
流程版本与对比模型
  1. 流程版本创建

    class FlowVersionCreate(BaseModel):
        name: Optional[str] = Field(default=..., description="版本的名字")
        description: Optional[str] = Field(default=None, description="版本的描述")
        data: Optional[Dict] = Field(default=None, description='技能版本的节点数据数据')
        original_version_id: Optional[int] = Field(default=None, description="版本的来源版本ID")
    
    • 用途:表示创建流程版本的请求。
    • 字段:
      • name:版本名称。
      • description:版本描述。
      • data:技能版本的节点数据,类型为字典,可选。
      • original_version_id:来源版本的 ID,可选。
  2. 流程对比请求

    class FlowCompareReq(BaseModel):
        inputs: Any = Field(default=None, description='技能运行所需要的输入')
        question_list: List[str] = Field(default=[], description='测试case列表')
        version_list: List[int] = Field(default=[], description='对比版本ID列表')
        node_id: str = Field(default=None, description='需要对比的节点唯一ID')
        thread_num: Optional[int] = Field(default=1, description='对比线程数')
    
    • 用途:表示对比不同版本流程的请求。
    • 字段:
      • inputs:技能运行所需的输入,类型为任意类型。
      • question_list:测试用例列表,类型为 List[str]
      • version_list:对比的版本 ID 列表,类型为 List[int]
      • node_id:需要对比的节点的唯一标识符,类型为字符串,可选。
      • thread_num:对比的线程数,默认值为 1
工具类型与测试模型
  1. 删除工具类型请求

    class DeleteToolTypeReq(BaseModel):
        tool_type_id: int = Field(description="要删除的工具类别ID")
    
    • 用途:表示删除工具类别的请求。
    • 字段:
      • tool_type_id:要删除的工具类别的唯一标识符。
  2. 测试工具请求

    class TestToolReq(BaseModel):
        server_host: str = Field(default='', description="服务的根地址")
        extra: str = Field(default='', description="Api 对象解析后的extra字段")
        auth_method: int = Field(default=AuthMethod.NO.value, description="认证类型")
        auth_type: Optional[str] = Field(default=AuthType.BASIC.value, description="Auth Type")
        api_key: Optional[str] = Field(default='', description="api key")
    
        request_params: Dict = Field(default=None, description="用户填写的请求参数")
    
    • 用途:表示测试工具接口的请求。
    • 字段:
      • server_host:服务的根地址,默认值为空字符串。
      • extra:API 对象解析后的额外字段,默认值为空字符串。
      • auth_method:认证方法,默认值为 AuthMethod.NO.value
      • auth_type:认证类型,默认值为 AuthType.BASIC.value,可选。
      • api_key:API 密钥,默认值为空字符串,可选。
      • request_params:用户填写的请求参数,类型为字典,可选。
用户与角色模型
  1. 用户组与角色

    class GroupAndRoles(BaseModel):
        group_id: int
        role_ids: List[int]
    
    • 用途:表示用户所属的组和角色。
    • 字段:
      • group_id:用户组的唯一标识符。
      • role_ids:角色的唯一标识符列表。
  2. 创建用户请求

    class CreateUserReq(BaseModel):
        user_name: str = Field(max_length=30, description='用户名')
        password: str = Field(description='密码')
        group_roles: List[GroupAndRoles] = Field(description='要加入的用户组和角色列表')
    
    • 用途:表示创建新用户的请求。
    • 字段:
      • user_name:用户名,最大长度为 30。
      • password:用户密码。
      • group_roles:用户所属的用户组和角色列表,类型为 List[GroupAndRoles]
OpenAI 聊天模型相关
  1. OpenAI 聊天完成请求

    class OpenAIChatCompletionReq(BaseModel):
        messages: List[dict] = Field(..., description="聊天消息列表,只支持user、assistant。system用数据库内的数据")
        model: str = Field(..., description="助手的唯一ID")
        n: int = Field(default=1, description="返回的答案个数, 助手侧默认为1,暂不支持多个回答")
        stream: bool = Field(default=False, description="是否开启流式回复")
        temperature: float = Field(default=0.0, description="模型温度, 传入0或者不传表示不覆盖")
        tools: List[dict] = Field(default=[], description="工具列表, 助手暂不支持,使用助手的配置")
    
    • 用途:表示向 OpenAI 的聊天完成 API 发送请求的结构。
    • 字段:
      • messages:聊天消息列表,类型为 List[dict],只支持 userassistantsystem 消息使用数据库内的数据。
      • model:助手的唯一标识符。
      • n:返回的答案个数,默认值为 1,当前不支持多个回答。
      • stream:是否开启流式回复,默认值为 False
      • temperature:模型温度,默认值为 0.0,传入 0 或不传表示不覆盖。
      • tools:工具列表,默认值为空列表,目前助手暂不支持,使用助手的配置。
  2. OpenAI 选择

    class OpenAIChoice(BaseModel):
        index: int = Field(..., description="选项的索引")
        message: dict = Field(default=None, description="对应的消息内容,和输入的格式一致")
        finish_reason: str = Field(default='stop', description="结束原因, 助手只有stop")
        delta: dict = Field(default=None, description="对应的openai流式返回消息内容")
    
    • 用途:表示 OpenAI 聊天完成 API 返回的单个选择项。
    • 字段:
      • index:选项的索引。
      • message:对应的消息内容,格式与输入一致。
      • finish_reason:结束原因,助手只有 'stop'
      • delta:对应的 OpenAI 流式返回消息内容。
  3. OpenAI 聊天完成响应

    class OpenAIChatCompletionResp(BaseModel):
        id: str = Field(..., description="请求的唯一ID")
        object: str = Field(default='chat.completion', description="返回的类型")
        created: int = Field(default=..., description="返回的创建时间戳")
        model: str = Field(..., description="返回的模型,对应助手的id")
        choices: List[OpenAIChoice] = Field(..., description="返回的答案列表")
        usage: dict = Field(default=None, description="返回的token用量, 助手此值为空")
        system_fingerprint: Optional[str] = Field(default=None, description="系统指纹")
    
    • 用途:表示 OpenAI 聊天完成 API 的响应结构。
    • 字段:
      • id:请求的唯一标识符。
      • object:返回的类型,默认值为 'chat.completion'
      • created:返回的创建时间戳。
      • model:返回的模型名称,对应助手的 ID。
      • choices:返回的答案列表,类型为 List[OpenAIChoice]
      • usage:返回的 token 用量,助手此值为空。
      • system_fingerprint:系统指纹,可选。
LLM 模型与服务器配置
  1. LLM 模型创建请求

    class LLMModelCreateReq(BaseModel):
        id: Optional[int] = Field(default=None, description="模型唯一ID, 更新时需要传")
        name: str = Field(..., description="模型展示名称")
        description: Optional[str] = Field(default='', description="模型描述")
        model_name: str = Field(..., description="模型名称")
        model_type: str = Field(..., description="模型类型")
        online: bool = Field(default=True, description='是否在线')
        config: Optional[dict] = Field(default=None, description="模型配置")
    
    • 用途:表示创建或更新 LLM 模型的请求。
    • 字段:
      • id:模型的唯一标识符,更新时需要传。
      • name:模型的展示名称。
      • description:模型描述,可选,默认值为空字符串。
      • model_name:模型名称。
      • model_type:模型类型。
      • online:是否在线,默认值为 True
      • config:模型的配置,类型为字典,可选。
  2. LLM 服务器创建请求

    class LLMServerCreateReq(BaseModel):
        id: Optional[int] = Field(default=None, description="服务提供方ID, 更新时需要传")
        name: str = Field(..., description="服务提供方名称")
        description: Optional[str] = Field(default='', description="服务提供方描述")
        type: str = Field(..., description="服务提供方类型")
        limit_flag: Optional[bool] = Field(default=False, description="是否开启每日调用次数限制")
        limit: Optional[int] = Field(default=0, description="每日调用次数限制")
        config: Optional[dict] = Field(default=None, description="服务提供方配置")
        models: Optional[List[LLMModelCreateReq]] = Field(default=[], description="服务提供方下的模型列表")
    
    • 用途:表示创建或更新 LLM 服务器的请求。
    • 字段:
      • id:服务提供方的唯一标识符,更新时需要传。
      • name:服务提供方名称。
      • description:服务提供方描述,可选,默认值为空字符串。
      • type:服务提供方类型。
      • limit_flag:是否开启每日调用次数限制,默认值为 False
      • limit:每日调用次数限制,默认值为 0
      • config:服务提供方的配置,类型为字典,可选。
      • models:服务提供方下的模型列表,类型为 List[LLMModelCreateReq],默认值为空列表。
  3. LLM 模型信息

    class LLMModelInfo(LLMModelBase):
        id: Optional[int]
    
    • 用途:表示 LLM 模型的详细信息,继承自 LLMModelBase(数据库模型)。
    • 字段:
      • id:模型的唯一标识符,可选。
  4. LLM 服务器信息

    class LLMServerInfo(LLMServerBase):
        id: Optional[int]
        models: List[LLMModelInfo] = Field(default=[], description="模型列表")
    
    • 用途:表示 LLM 服务器的详细信息,继承自 LLMServerBase(数据库模型)。
    • 字段:
      • id:服务器的唯一标识符,可选。
      • models:服务器下的模型列表,类型为 List[LLMModelInfo],默认值为空列表。
知识库配置模型
  1. 知识库 LLM 配置

    class KnowledgeLLMConfig(BaseModel):
        embedding_model_id: Optional[int] = Field(description="知识库默认embedding模型的ID")
        source_model_id: Optional[int] = Field(description="知识库溯源模型的ID")
        extract_title_model_id: Optional[int] = Field(description="文档知识库提取标题模型的ID")
        qa_similar_model_id: Optional[int] = Field(description="QA知识库相似问模型的ID")
    
    • 用途:表示知识库的 LLM 配置。
    • 字段:
      • embedding_model_id:默认的 embedding 模型 ID,可选。
      • source_model_id:溯源模型 ID,可选。
      • extract_title_model_id:文档知识库提取标题模型的 ID,可选。
      • qa_similar_model_id:QA 知识库相似问模型的 ID,可选。
  2. 助手 LLM 项目

    class AssistantLLMItem(BaseModel):
        model_id: Optional[int] = Field(description="模型的ID")
        agent_executor_type: Optional[str] = Field(default="ReAct", description="执行模式。function call 或者 ReAct")
        knowledge_max_content: Optional[int] = Field(default=15000, description="知识库检索最大字符串数")
        knowledge_sort_index: Optional[bool] = Field(default=False, description="知识库检索后是否重排")
        streaming: Optional[bool] = Field(default=True, description="是否开启流式")
        default: Optional[bool] = Field(default=False, description="是否为默认模型")
    
    • 用途:表示助手的单个 LLM 配置项。
    • 字段:
      • model_id:模型的唯一标识符,可选。
      • agent_executor_type:执行模式,默认值为 "ReAct",可选,其他值如 "function call"
      • knowledge_max_content:知识库检索的最大字符串数,默认值为 15000
      • knowledge_sort_index:知识库检索后是否重排,默认值为 False
      • streaming:是否开启流式回复,默认值为 True
      • default:是否为默认模型,默认值为 False
  3. 助手 LLM 配置

    class AssistantLLMConfig(BaseModel):
        llm_list: Optional[List[AssistantLLMItem]] = Field(default=[], description="助手可选的LLM列表")
        auto_llm: Optional[AssistantLLMItem] = Field(description="助手画像自动优化模型的配置")
    
    • 用途:表示助手的 LLM 配置。
    • 字段:
      • llm_list:助手可选的 LLM 列表,类型为 List[AssistantLLMItem],默认值为空列表。
      • auto_llm:助手画像自动优化模型的配置,类型为 AssistantLLMItem,可选。
  4. 评测 LLM 配置

    class EvaluationLLMConfig(BaseModel):
        model_id: Optional[int] = Field(description="评测功能默认模型的ID")
    
    • 用途:表示评测功能的 LLM 配置。
    • 字段:
      • model_id:评测功能默认模型的 ID,可选。
文件处理模型
  1. 文件处理基础请求

    class FileProcessBase(BaseModel):
        knowledge_id: int = Field(..., description="知识库ID")
        separator: Optional[List[str]] = Field(default=['\n\n'], description="切分文本规则, 不传则为默认")
        separator_rule: Optional[List[str]] = Field(default=['after'],
                                                    description="切分规则前还是后进行切分;before/after")
        chunk_size: Optional[int] = Field(default=1000, description="切分文本长度,不传则为默认")
        chunk_overlap: Optional[int] = Field(default=100, description="切分文本重叠长度,不传则为默认")
    
        @root_validator
        def check_separator_rule(cls, values):
            if values['separator'] is None:
                values['separator'] = ['\n\n']
            if values['separator_rule'] is None:
                values['separator_rule'] = ['after' for _ in values["separator"]]
            if values['chunk_size'] is None:
                values['chunk_size'] = 1000
            if values['chunk_overlap'] is None:
                values['chunk_overlap'] = 100
            return values
    
    • 用途:表示文件处理的基础请求参数,用于文件切分。
    • 字段
      • knowledge_id:知识库的唯一标识符。
      • separator:切分文本的规则,默认值为 ['\n\n'],可选。
      • separator_rule:切分规则前还是后进行切分,默认值为 ['after'],可选。
      • chunk_size:切分文本的长度,默认值为 1000,可选。
      • chunk_overlap:切分文本的重叠长度,默认值为 100,可选。
    • 验证器
      • check_separator_rule:确保所有可选字段在未提供时设置为默认值。
  2. 文件分块元数据

    class FileChunkMetadata(BaseModel):
        source: str = Field(default='', description="源文件名")
        title: str = Field(default='', description="源文件内容总结的标题")
        chunk_index: int = Field(default=0, description="文本块索引")
        bbox: str = Field(default='', description="文本块bbox信息")
        page: int = Field(default=0, description="文本块所在页码")
        extra: str = Field(default='', description="文本块额外信息")
        file_id: int = Field(default=0, description="文本块所属文件ID")
    
    • 用途:表示文件分块的元数据信息。
    • 字段:
      • source:源文件名,默认值为空字符串。
      • title:源文件内容总结的标题,默认值为空字符串。
      • chunk_index:文本块的索引,默认值为 0
      • bbox:文本块的 bounding box 信息,默认值为空字符串。
      • page:文本块所在的页码,默认值为 0
      • extra:文本块的额外信息,默认值为空字符串。
      • file_id:文本块所属文件的唯一标识符,默认值为 0
  3. 文件分块

    class FileChunk(BaseModel):
        text: str = Field(..., description="文本块内容")
        parse_type: Optional[str] = Field(default=None, description="文本所属的文件解析类型")
        metadata: FileChunkMetadata = Field(..., description="文本块元数据")
    
    • 用途:表示文件的一个分块。
    • 字段:
      • text:文本块的内容。
      • parse_type:文本所属的文件解析类型,可选。
      • metadata:文本块的元数据,类型为 FileChunkMetadata
  4. 预览文件分块请求

    class PreviewFileChunk(FileProcessBase):
        file_path: str = Field(..., description="文件路径")
        cache: bool = Field(default=True, description="是否从缓存获取")
    
    • 用途:表示预览文件分块内容的请求。
    • 字段:
      • file_path:文件的路径。
      • cache:是否从缓存获取,默认值为 True
  5. 更新预览文件分块

    class UpdatePreviewFileChunk(BaseModel):
        knowledge_id: int = Field(..., description="知识库ID")
        file_path: str = Field(..., description="文件路径")
        text: str = Field(..., description="文本块内容")
        chunk_index: int = Field(..., description="文本块索引, 在metadata里")
        bbox: Optional[str] = Field(default='', description="文本块bbox信息")
    
    • 用途:表示更新预览文件分块内容的请求。
    • 字段:
      • knowledge_id:知识库的唯一标识符。
      • file_path:文件的路径。
      • text:文本块的内容。
      • chunk_index:文本块的索引。
      • bbox:文本块的 bounding box 信息,可选,默认值为空字符串。
  6. 知识库文件单项

    class KnowledgeFileOne(BaseModel):
        file_path: str = Field(..., description="文件路径")
    
    • 用途:表示知识库中单个文件的路径。
    • 字段:
      • file_path:文件的路径。
  7. 知识库文件处理

    class KnowledgeFileProcess(FileProcessBase):
        file_list: List[KnowledgeFileOne] = Field(..., description="文件列表")
        callback_url: Optional[str] = Field(default=None, description="异步任务回调地址")
        extra: Optional[str] = Field(default=None, description="附加信息")
    
    • 用途:表示知识库文件处理的请求。
    • 字段:
      • file_list:文件列表,类型为 List[KnowledgeFileOne]
      • callback_url:异步任务的回调地址,可选。
      • extra:附加信息,可选。

总结

这段代码定义了一系列用于数据验证和序列化的模型,涵盖了用户认证、聊天管理、知识库处理、模型和服务器配置等多个方面。以下是关键点总结:

  1. 数据验证和序列化:通过 PydanticBaseModel,确保 API 接收和返回的数据结构符合预期,提升代码的可靠性和可维护性。
  2. 通用响应结构:使用泛型的 UnifiedResponseModel 和辅助函数 resp_200resp_500,统一 API 的响应格式,便于前端处理和错误管理。
  3. 业务逻辑覆盖
    • 用户认证:如 CaptchaInputCreateUserReq
    • 聊天管理:如 ChatMessageChatResponseAddChatMessages
    • 知识库处理:如 KnowledgeFileProcessFileChunk
    • 模型和服务器配置:如 LLMModelCreateReqLLMServerCreateReq
    • 流程管理:如 FlowListCreateFlowCompareReq
  4. 验证器:通过 @validator@root_validator,进一步确保数据的合法性和完整性。
  5. 类型注解:广泛使用类型注解(如 OptionalListDict 等),提高代码的可读性和静态分析工具的效果。

base.py

import uuid
from contextlib import contextmanager

from bisheng.database.service import DatabaseService
from bisheng.settings import settings
from bisheng.utils.logger import logger
from sqlmodel import Session

db_service: 'DatabaseService' = DatabaseService(settings.database_url)


@contextmanager
def session_getter() -> Session:
    """轻量级session context"""
    try:
        session = Session(db_service.engine)
        yield session
    except Exception as e:
        logger.info('Session rollback because of exception:{}', e)
        session.rollback()
        raise
    finally:
        session.close()


def generate_uuid() -> str:
    """
    生成uuid的字符串
    """
    return uuid.uuid4().hex

1. 导入部分

标准库导入

import uuid
from contextlib import contextmanager
  • uuid
    • 用途:用于生成唯一标识符(UUID)。
    • 常用功能:
      • uuid.uuid4():生成一个随机的 UUID。
      • uuid.uuid4().hex:获取 UUID 的 32 位十六进制字符串表示。
  • contextmanager
    • 用途:用于创建上下文管理器,通常与 with 语句一起使用。
    • 功能:简化上下文管理器的创建,避免手动编写类。

项目内部模块导入

from bisheng.database.service import DatabaseService
from bisheng.settings import settings
from bisheng.utils.logger import logger
  • bisheng.database.service.DatabaseService
    • 用途:数据库服务类,通常封装了与数据库的连接和操作逻辑。
    • 假设功能:
      • 初始化数据库引擎。
      • 提供数据库会话管理。
      • 可能包含其他数据库相关的服务方法。
  • bisheng.settings.settings
    • 用途:项目的配置对象,包含了各种配置参数,如数据库 URL、API 密钥等。
    • 功能:通过 settings.database_url 访问数据库连接字符串。
  • bisheng.utils.logger.logger
    • 用途:日志记录器,用于记录日志信息。
    • 功能:
      • 记录信息、警告、错误等不同级别的日志。
      • 便于调试和监控应用程序运行状态。

第三方库导入

from sqlmodel import Session
  • sqlmodel.Session
    • 用途:用于与数据库交互的会话对象。
    • 功能:
      • 提供 CRUD(创建、读取、更新、删除)操作。
      • 管理数据库事务。
      • 处理数据库连接的打开和关闭。

2. 数据库服务实例化

db_service: 'DatabaseService' = DatabaseService(settings.database_url)
  • db_service
    • 类型注解'DatabaseService'(使用字符串注解以避免循环导入问题)。
    • 用途:创建一个 DatabaseService 的实例,用于管理与数据库的连接。
    • 初始化:
      • 参数settings.database_url,即数据库的连接字符串。
      • 功能:内部会根据 database_url 配置数据库引擎(如 SQLAlchemy 引擎)。

示例

假设 settings.database_url"postgresql://user:password@localhost/dbname",那么 DatabaseService 会使用这个 URL 初始化与 PostgreSQL 数据库的连接。

3. 上下文管理器:session_getter

@contextmanager
def session_getter() -> Session:
    """轻量级session context"""
    try:
        session = Session(db_service.engine)
        yield session
    except Exception as e:
        logger.info('Session rollback because of exception:{}', e)
        session.rollback()
        raise
    finally:
        session.close()

功能和用途

  • 目的:提供一个安全且便捷的方式来管理数据库会话,确保会话在使用后正确关闭,并在发生异常时回滚事务。
  • 使用场景:
    • 在执行数据库操作时使用 with session_getter() as session:,确保会话的正确管理。
    • 处理事务,确保数据一致性和完整性。

工作原理

  1. 创建会话

    session = Session(db_service.engine)
    
    • db_service.engine:获取数据库引擎,用于创建会话。
  2. 使用 yield 传递会话对象

    yield session
    
    • 允许 with 语句块内的代码使用该会话对象。
  3. 异常处理

    except Exception as e:
        logger.info('Session rollback because of exception:{}', e)
        session.rollback()
        raise
    
    • 记录日志:使用 logger 记录异常信息。
    • 事务回滚:确保在发生异常时,事务不会部分提交,保持数据一致性。
    • 重新抛出异常:让上层调用者知道发生了错误。
  4. 关闭会话

    finally:
        session.close()
    
    • 确保资源释放:无论是否发生异常,都会关闭会话,释放数据库连接资源。

示例用法

from bisheng.database.utils import session_getter

def get_user(user_id: int):
    with session_getter() as session:
        user = session.query(User).filter(User.id == user_id).first()
        return user

在这个示例中:

  • with session_getter() as session:获取一个数据库会话。
  • 执行查询:通过会话对象执行数据库查询。
  • 自动管理:退出 with 块时,会自动关闭会话,确保资源释放。

4. UUID 生成函数:generate_uuid

def generate_uuid() -> str:
    """
    生成uuid的字符串
    """
    return uuid.uuid4().hex

功能和用途

  • 目的:生成一个唯一的 UUID 字符串,用于标识对象或记录。
  • 用途:
    • 为数据库记录生成唯一标识符。
    • 生成会话 ID、交易 ID 等。
    • 确保标识符的唯一性,避免冲突。

工作原理

  1. 生成 UUID

    uuid.uuid4()
    
    • uuid4:生成一个随机的 UUID(基于随机数)。
  2. 获取十六进制表示

    .hex
    
    • hex 属性:将 UUID 对象转换为 32 位十六进制字符串,没有连字符(-)。

示例

unique_id = generate_uuid()
print(unique_id)  # 输出示例:e4eaaaf2dcd24d5f8e4e8c1b0b3f1f6a
  • 优势

    • 唯一性:UUID 生成的概率几乎为零,确保每个 ID 的唯一性。
    • 不可预测性:基于随机数生成,难以预测下一个 UUID 的值。

注意事项

  • 长度:UUID 的 hex 表示长度为 32 个字符。
  • 存储:在数据库中存储时,考虑使用合适的数据类型(如 VARCHAR(32)CHAR(32))。
  • 性能:UUID 的随机性可能导致数据库索引性能下降,特别是在高频率写入的情况下。可以考虑使用有序 UUID(如 UUID1)以优化索引性能。

5. 总结

这段代码主要涵盖了以下几个关键方面:

  1. 数据库服务初始化
    • 使用 DatabaseService 类初始化数据库连接,基于项目配置中的 database_url
  2. 会话管理
    • 定义了一个上下文管理器 session_getter,用于安全地管理数据库会话,确保会话在使用后正确关闭,并在异常发生时回滚事务。
  3. UUID 生成
    • 提供了一个简单的函数 generate_uuid,用于生成唯一的 UUID 字符串,确保标识符的唯一性和不可预测性。

如何在项目中使用

  • 数据库操作

    • 使用 session_getter 上下文管理器来执行数据库操作,确保会话的正确管理和资源的释放。
    from bisheng.database.utils import session_getter, generate_uuid
    
    def create_user(name: str, email: str):
        user_id = generate_uuid()
        with session_getter() as session:
            new_user = User(id=user_id, name=name, email=email)
            session.add(new_user)
            session.commit()
        return new_user
    
  • 生成唯一标识符

    • 使用 generate_uuid 函数为新记录生成唯一的 ID。
    unique_id = generate_uuid()
    print(f"Generated UUID: {unique_id}")
    

进一步阅读

  • DatabaseService
    • 查看 bisheng.database.service 模块中的 DatabaseService 类实现,了解其具体功能和方法。
  • 项目配置
    • 查看 bisheng.settings 模块,了解项目的配置管理,尤其是 database_url 的设置方式。
  • 日志记录器
    • 查看 bisheng.utils.logger 模块,了解日志记录器的配置和使用方法,以便更好地调试和监控应用程序。
  • sqlmodel
    • sqlmodel 是一个用于与 SQL 数据库交互的库,结合了 SQLAlchemy 和 Pydantic 的特性。建议阅读 SQLModel 官方文档 以深入理解其用法和最佳实践。

常见问题

  1. 为什么使用 contextmanager 装饰器创建 session_getter
    • 使用 contextmanager 可以简化上下文管理器的创建,使代码更加简洁和易读。它允许使用 with 语句块自动管理资源,如数据库会话的打开和关闭。
  2. 为什么在异常处理中回滚事务?
    • 当数据库操作过程中发生异常时,事务可能处于不一致状态。回滚事务可以确保数据库不会部分提交数据,保持数据的一致性和完整性。
  3. UUID 为什么使用 .hex 属性而不是直接使用 uuid4() 生成的对象?
    • 使用 .hex 属性可以获得一个标准的 32 位十六进制字符串表示,便于存储和传输。相比直接使用 UUID 对象,字符串表示更容易与其他系统集成和调试。

UserService.py

这段代码主要涉及用户身份验证、权限管理、用户创建以及与数据库交互的服务逻辑。代码使用了 FastAPI 框架,并结合了 PydanticJWTRedis 等技术。以下是对每一部分的详细解析:

import functools
import json
from base64 import b64decode
from typing import List

import rsa
from bisheng.api.errcode.base import UnAuthorizedError
from bisheng.api.errcode.user import (UserLoginOfflineError, UserNameAlreadyExistError,
                                      UserNeedGroupAndRoleError)
from bisheng.api.JWT import ACCESS_TOKEN_EXPIRE_TIME
from bisheng.api.utils import md5_hash
from bisheng.api.v1.schemas import CreateUserReq
from bisheng.cache.redis import redis_client
from bisheng.database.models.assistant import Assistant, AssistantDao
from bisheng.database.models.flow import Flow, FlowDao, FlowRead
from bisheng.database.models.knowledge import Knowledge, KnowledgeDao, KnowledgeRead
from bisheng.database.models.role import AdminRole
from bisheng.database.models.role_access import AccessType, RoleAccessDao
from bisheng.database.models.user import User, UserDao
from bisheng.database.models.user_group import UserGroupDao
from bisheng.database.models.user_role import UserRoleDao
from bisheng.settings import settings
from bisheng.utils.constants import RSA_KEY, USER_CURRENT_SESSION
from fastapi import Depends, HTTPException, Request
from fastapi_jwt_auth import AuthJWT

1. 导入部分

标准库导入

import functools
import json
from base64 import b64decode
from typing import List
  • functools

    • 用途:用于高级函数操作,如装饰器。

    • 常用功能

      • functools.wraps:用于装饰器,保持原函数的元数据。
  • json

    • 用途:用于 JSON 数据的序列化和反序列化。

    • 常用功能

      • json.loads:将 JSON 字符串转换为 Python 对象。
      • json.dumps:将 Python 对象转换为 JSON 字符串。
  • base64.b64decode

    • 用途:用于解码 Base64 编码的数据。
    • 应用场景:常用于处理加密数据或传输二进制数据。
  • typing.List

    • 用途:用于类型注解,表示列表类型。
    • 示例List[int] 表示整数列表。

第三方库导入

import rsa
from fastapi import Depends, HTTPException, Request
from fastapi_jwt_auth import AuthJWT
  • rsa

    • 用途:用于 RSA 加密和解密。

    • 常用功能

      • rsa.decrypt:使用私钥解密数据。
      • rsa.encrypt:使用公钥加密数据。
  • fastapi

    • 用途:用于构建 Web API。

    • 常用功能

      • Depends:用于依赖注入。
      • HTTPException:用于抛出 HTTP 异常。
      • Request:表示请求对象。
  • fastapi_jwt_auth.AuthJWT

    • 用途:用于处理 JWT(JSON Web Token)认证。

    • 常用功能

      • create_access_token:创建访问令牌。
      • create_refresh_token:创建刷新令牌。
      • jwt_required:装饰器,用于保护路由,需要 JWT 认证。

项目内部模块导入

from bisheng.api.errcode.base import UnAuthorizedError
from bisheng.api.errcode.user import (UserLoginOfflineError, UserNameAlreadyExistError,
                                      UserNeedGroupAndRoleError)
from bisheng.api.JWT import ACCESS_TOKEN_EXPIRE_TIME
from bisheng.api.utils import md5_hash
from bisheng.api.v1.schemas import CreateUserReq
from bisheng.cache.redis import redis_client
from bisheng.database.models.assistant import Assistant, AssistantDao
from bisheng.database.models.flow import Flow, FlowDao, FlowRead
from bisheng.database.models.knowledge import Knowledge, KnowledgeDao, KnowledgeRead
from bisheng.database.models.role import AdminRole
from bisheng.database.models.role_access import AccessType, RoleAccessDao
from bisheng.database.models.user import User, UserDao
from bisheng.database.models.user_group import UserGroupDao
from bisheng.database.models.user_role import UserRoleDao
from bisheng.settings import settings
from bisheng.utils.constants import RSA_KEY, USER_CURRENT_SESSION
  • bisheng.api.errcode.\*

    • 用途:定义自定义的 API 错误码和异常类,用于统一错误处理。

    • 常用异常

      • UnAuthorizedError:未授权错误。
      • UserLoginOfflineError:用户登录离线错误。
      • UserNameAlreadyExistError:用户名已存在错误。
      • UserNeedGroupAndRoleError:用户需要组和角色错误。
  • bisheng.api.JWT

    • 用途:存储 JWT 相关的配置常量,如访问令牌过期时间。

    • 字段

      • ACCESS_TOKEN_EXPIRE_TIME:访问令牌的过期时间。
  • bisheng.api.utils

    • 用途:存储各种实用函数。

    • 常用函数

      • md5_hash:生成字符串的 MD5 哈希值。
  • bisheng.api.v1.schemas.CreateUserReq

    • 用途:定义创建用户请求的数据模型,通常使用 Pydantic BaseModel
  • bisheng.cache.redis.redis_client

    • 用途:Redis 客户端实例,用于与 Redis 缓存交互。
  • bisheng.database.models.\*

    • 用途:导入数据库模型和数据访问对象(DAO)。

    • 常用模型和 DAO

      • Assistant, AssistantDao:助手模型和 DAO。
      • Flow, FlowDao, FlowRead:流程模型、DAO 和读取模型。
      • Knowledge, KnowledgeDao, KnowledgeRead:知识库模型、DAO 和读取模型。
      • AdminRole:管理员角色模型。
      • AccessType, RoleAccessDao:访问类型和角色访问 DAO。
      • User, UserDao:用户模型和 DAO。
      • UserGroupDao:用户组 DAO。
      • UserRoleDao:用户角色 DAO。
  • bisheng.settings.settings

    • 用途:项目的配置对象,包含了各种配置参数,如数据库 URL、RSA 密钥等。
  • bisheng.utils.constants

    • 用途:存储项目中使用的常量。

    • 字段

      • RSA_KEY:RSA 密钥的常量键。
      • USER_CURRENT_SESSION:用户当前会话的 Redis 键模板。

2. UserPayload

class UserPayload:

    def __init__(self, **kwargs):
        self.user_id = kwargs.get('user_id')
        self.user_role = kwargs.get('role')
        if self.user_role != 'admin':  # 非管理员用户,需要获取他的角色列表
            roles = UserRoleDao.get_user_roles(self.user_id)
            self.user_role = [one.role_id for one in roles]
        self.user_name = kwargs.get('user_name')

    def is_admin(self):
        if self.user_role == 'admin':
            return True
        if isinstance(self.user_role, list):
            for one in self.user_role:
                if one == AdminRole:
                    return True
        return False

    @staticmethod
    def wrapper_access_check(func):
        """
        权限检查的装饰器
        如果是admin用户则不执行后续具体的检查逻辑
        """

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if args[0].is_admin():
                return True
            return func(*args, **kwargs)

        return wrapper

    @wrapper_access_check
    def access_check(self, owner_user_id: int, target_id: str, access_type: AccessType) -> bool:
        """
            检查用户是否有某个资源的权限
        """
        # 判断是否属于本人资源
        if self.user_id == owner_user_id:
            return True
        # 判断授权
        if RoleAccessDao.judge_role_access(self.user_role, target_id, access_type):
            return True
        return False

    @wrapper_access_check
    def check_group_admin(self, group_id: int) -> bool:
        """
            检查用户是否是某个组的管理员
        """
        # 判断是否是用户组的管理员
        user_group = UserGroupDao.get_user_admin_group(self.user_id)
        if not user_group:
            return False
        for one in user_group:
            if one.group_id == group_id:
                return True
        return False

    @wrapper_access_check
    def check_groups_admin(self, group_ids: List[int]) -> bool:
        """
        检查用户是否是用户组列表中的管理员,有一个就是true
        """
        user_groups = UserGroupDao.get_user_admin_group(self.user_id)
        for one in user_groups:
            if one.is_group_admin and one.group_id in group_ids:
                return True
        return False

功能和用途

UserPayload 类用于封装当前登录用户的相关信息,并提供权限检查的方法。这在构建基于角色的访问控制(RBAC)系统中非常常见。

详细解析

构造函数 __init__
def __init__(self, **kwargs):
    self.user_id = kwargs.get('user_id')
    self.user_role = kwargs.get('role')
    if self.user_role != 'admin':  # 非管理员用户,需要获取他的角色列表
        roles = UserRoleDao.get_user_roles(self.user_id)
        self.user_role = [one.role_id for one in roles]
    self.user_name = kwargs.get('user_name')
  • 功能:初始化用户的基本信息,包括 user_iduser_roleuser_name

  • 逻辑

    • kwargs 获取 user_idroleuser_name
    • 如果用户不是管理员 (self.user_role != 'admin'),则从数据库获取该用户的所有角色,并将其存储为角色 ID 的列表。
方法 is_admin
def is_admin(self):
    if self.user_role == 'admin':
        return True
    if isinstance(self.user_role, list):
        for one in self.user_role:
            if one == AdminRole:
                return True
    return False
  • 功能:判断用户是否具有管理员权限。

  • 逻辑

    • 如果 user_role'admin',返回 True
    • 如果 user_role 是一个列表,遍历列表,检查是否包含 AdminRole,如果包含,返回 True
    • 否则,返回 False
装饰器 wrapper_access_check
@staticmethod
def wrapper_access_check(func):
    """
    权限检查的装饰器
    如果是admin用户则不执行后续具体的检查逻辑
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if args[0].is_admin():
            return True
        return func(*args, **kwargs)

    return wrapper
  • 功能:定义一个装饰器,用于权限检查。

  • 逻辑

    • 如果用户是管理员,直接返回 True,跳过后续的具体权限检查逻辑。
    • 否则,执行被装饰的函数。
方法 access_check
@wrapper_access_check
def access_check(self, owner_user_id: int, target_id: str, access_type: AccessType) -> bool:
    """
        检查用户是否有某个资源的权限
    """
    # 判断是否属于本人资源
    if self.user_id == owner_user_id:
        return True
    # 判断授权
    if RoleAccessDao.judge_role_access(self.user_role, target_id, access_type):
        return True
    return False
  • 功能:检查用户是否有访问某个资源的权限。

  • 逻辑

    • 如果资源的拥有者是当前用户,返回 True
    • 否则,调用 RoleAccessDao.judge_role_access 方法检查用户角色是否有访问该资源的权限。
    • 返回相应的布尔值。
方法 check_group_admin
@wrapper_access_check
def check_group_admin(self, group_id: int) -> bool:
    """
        检查用户是否是某个组的管理员
    """
    # 判断是否是用户组的管理员
    user_group = UserGroupDao.get_user_admin_group(self.user_id)
    if not user_group:
        return False
    for one in user_group:
        if one.group_id == group_id:
            return True
    return False
  • 功能:检查用户是否是某个用户组的管理员。

  • 逻辑

    • 调用 UserGroupDao.get_user_admin_group 获取用户所属的管理员组。
    • 如果用户不属于任何管理员组,返回 False
    • 遍历管理员组,检查是否包含指定的 group_id,如果包含,返回 True
    • 否则,返回 False
方法 check_groups_admin
@wrapper_access_check
def check_groups_admin(self, group_ids: List[int]) -> bool:
    """
    检查用户是否是用户组列表中的管理员,有一个就是true
    """
    user_groups = UserGroupDao.get_user_admin_group(self.user_id)
    for one in user_groups:
        if one.is_group_admin and one.group_id in group_ids:
            return True
    return False
  • 功能:检查用户是否是给定用户组列表中的任意一个用户组的管理员。

  • 逻辑

    • 调用 UserGroupDao.get_user_admin_group 获取用户所属的管理员组。
    • 遍历管理员组,检查是否有任何一个用户组的 group_id 在给定的 group_ids 列表中,并且用户是该组的管理员 (one.is_group_admin)。
    • 如果找到符合条件的用户组,返回 True
    • 否则,返回 False

3. UserService

class UserService:

    @classmethod
    def decrypt_md5_password(cls, password: str):
        if value := redis_client.get(RSA_KEY):
            private_key = value[1]
            password = md5_hash(rsa.decrypt(b64decode(password), private_key).decode('utf-8'))
        else:
            password = md5_hash(password)
        return password

    @classmethod
    def create_user(cls, request: Request, login_user: UserPayload, req_data: CreateUserReq):
        """
        创建用户
        """
        exists_user = UserDao.get_user_by_username(req_data.user_name)
        if exists_user:
            # 抛出异常
            raise UserNameAlreadyExistError.http_exception()
        user = User(
            user_name=req_data.user_name,
            password=cls.decrypt_md5_password(req_data.password),
        )
        group_ids = []
        role_ids = []
        for one in req_data.group_roles:
            group_ids.append(one.group_id)
            role_ids.extend(one.role_ids)
        if not group_ids or not role_ids:
            raise UserNeedGroupAndRoleError.http_exception()
        user = UserDao.add_user_with_groups_and_roles(user, group_ids, role_ids)
        return user

功能和用途

UserService 类提供了与用户相关的业务逻辑,如密码解密和用户创建。这种类通常包含静态方法或类方法,用于处理特定的业务操作。

详细解析

方法 decrypt_md5_password
@classmethod
def decrypt_md5_password(cls, password: str):
    if value := redis_client.get(RSA_KEY):
        private_key = value[1]
        password = md5_hash(rsa.decrypt(b64decode(password), private_key).decode('utf-8'))
    else:
        password = md5_hash(password)
    return password
  • 功能:解密密码并生成 MD5 哈希值。
  • 逻辑
    • 尝试从 Redis 获取 RSA 密钥 (RSA_KEY)。
    • 如果存在 RSA 密钥:
      • 使用 RSA 私钥解密 Base64 编码的密码。
      • 将解密后的密码转换为 UTF-8 字符串。
      • 对解密后的密码进行 MD5 哈希。
    • 如果不存在 RSA 密钥:
      • 直接对原始密码进行 MD5 哈希。
    • 返回哈希后的密码。
  • 应用场景
    • 用户注册或密码更新时,将用户输入的密码加密存储到数据库。
    • 增强密码的安全性,通过加密和哈希减少明文密码的存储风险。
方法 create_user
@classmethod
def create_user(cls, request: Request, login_user: UserPayload, req_data: CreateUserReq):
    """
    创建用户
    """
    exists_user = UserDao.get_user_by_username(req_data.user_name)
    if exists_user:
        # 抛出异常
        raise UserNameAlreadyExistError.http_exception()
    user = User(
        user_name=req_data.user_name,
        password=cls.decrypt_md5_password(req_data.password),
    )
    group_ids = []
    role_ids = []
    for one in req_data.group_roles:
        group_ids.append(one.group_id)
        role_ids.extend(one.role_ids)
    if not group_ids or not role_ids:
        raise UserNeedGroupAndRoleError.http_exception()
    user = UserDao.add_user_with_groups_and_roles(user, group_ids, role_ids)
    return user
  • 功能:创建新用户,并将其分配到指定的用户组和角色中。
  • 参数
    • request:请求对象,包含请求相关的信息。
    • login_user:当前登录用户的信息,类型为 UserPayload
    • req_data:创建用户的请求数据,类型为 CreateUserReq
  • 逻辑
    1. 检查用户名是否存在
      • 调用 UserDao.get_user_by_username 检查请求中的用户名是否已经存在。
      • 如果用户已存在,抛出 UserNameAlreadyExistError 异常。
    2. 创建用户实例
      • 使用 CreateUserReq 中的 user_name 和解密后的密码创建一个新的 User 实例。
    3. 收集用户组和角色 ID
      • 遍历 req_data.group_roles,收集所有的 group_idrole_id
    4. 检查用户组和角色
      • 如果用户没有指定任何用户组或角色,抛出 UserNeedGroupAndRoleError 异常。
    5. 添加用户并关联组和角色
      • 调用 UserDao.add_user_with_groups_and_roles 方法,将用户添加到数据库,并关联指定的用户组和角色。
    6. 返回创建的用户
      • 返回新创建的 User 实例。
  • 应用场景
    • 用户注册功能,允许管理员或系统自动创建新用户,并分配相应的权限和组。

4. 辅助函数

函数 sso_login

def sso_login():
    pass
  • 功能:此函数目前为空,实现细节未提供。
  • 用途:通常用于单点登录(SSO)功能,允许用户通过单一认证源登录多个系统。

函数 gen_user_role

def gen_user_role(db_user: User):
    # 查询用户的角色列表
    db_user_role = UserRoleDao.get_user_roles(db_user.user_id)
    role = ''
    role_ids = []
    for user_role in db_user_role:
        if user_role.role_id == 1:
            # 是管理员,忽略其他的角色
            role = 'admin'
        else:
            role_ids.append(user_role.role_id)
    if role != 'admin':
        # 判断是否是用户组管理员
        db_user_groups = UserGroupDao.get_user_admin_group(db_user.user_id)
        if len(db_user_groups) > 0:
            role = 'group_admin'
        else:
            role = role_ids
    # 获取用户的菜单栏权限列表
    web_menu = RoleAccessDao.get_role_access(role_ids, AccessType.WEB_MENU)
    web_menu = list(set([one.third_id for one in web_menu]))
    return role, web_menu
  • 功能:生成用户的角色和相应的菜单权限。
  • 参数
    • db_user:数据库中的 User 实例。
  • 逻辑
    1. 查询用户角色
      • 调用 UserRoleDao.get_user_roles 获取用户的角色列表。
    2. 确定角色
      • 遍历用户角色列表:
        • 如果用户具有角色 ID 为 1,将角色设置为 'admin',并忽略其他角色。
        • 否则,将角色 ID 添加到 role_ids 列表中。
    3. 检查用户组管理员身份
      • 如果用户不是管理员,调用 UserGroupDao.get_user_admin_group 检查用户是否是某个用户组的管理员。
      • 如果是用户组管理员,设置角色为 'group_admin'
      • 否则,角色保持为角色 ID 的列表。
    4. 获取菜单权限
      • 调用 RoleAccessDao.get_role_access 获取用户角色对应的菜单权限。
      • 提取并去重 third_id,形成菜单权限列表 web_menu
    5. 返回角色和菜单权限
      • 返回用户的角色(字符串或列表)和菜单权限列表。
  • 应用场景
    • 在用户登录后,生成用户的权限信息,用于前端展示和后续的权限校验。

函数 gen_user_jwt

def gen_user_jwt(db_user: User):
    if 1 == db_user.delete:
        raise HTTPException(status_code=500, detail='该账号已被禁用,请联系管理员')
    # 查询角色
    role, web_menu = gen_user_role(db_user)
    # 生成JWT令牌
    payload = {'user_name': db_user.user_name, 'user_id': db_user.user_id, 'role': role}
    # Create the tokens and passing to set_access_cookies or set_refresh_cookies
    access_token = AuthJWT().create_access_token(subject=json.dumps(payload),
                                                 expires_time=ACCESS_TOKEN_EXPIRE_TIME)

    refresh_token = AuthJWT().create_refresh_token(subject=db_user.user_name)

    # Set the JWT cookies in the response
    return access_token, refresh_token, role, web_menu
  • 功能:为用户生成 JWT 访问令牌和刷新令牌。
  • 参数
    • db_user:数据库中的 User 实例。
  • 逻辑
    1. 检查用户状态
      • 如果 db_user.delete 等于 1,表示账号被禁用,抛出 HTTPException,状态码为 500,提示用户联系管理员。
    2. 获取用户角色和菜单权限
      • 调用 gen_user_role 函数,获取用户的角色和菜单权限。
    3. 生成 JWT 载荷
      • 创建一个包含 user_nameuser_idrole 的字典 payload
    4. 生成访问令牌和刷新令牌
      • 调用 AuthJWT().create_access_token 生成访问令牌,主题为 payload 的 JSON 字符串,过期时间为 ACCESS_TOKEN_EXPIRE_TIME
      • 调用 AuthJWT().create_refresh_token 生成刷新令牌,主题为用户的用户名。
    5. 返回令牌和权限信息
      • 返回访问令牌、刷新令牌、角色和菜单权限。
  • 应用场景
    • 在用户成功登录后,生成并返回 JWT 令牌,用于后续的认证和授权。

函数 get_knowledge_list_by_access

def get_knowledge_list_by_access(role_id: int, name: str, page_num: int, page_size: int):
    count_filter = []
    if name:
        count_filter.append(Knowledge.name.like('%{}%'.format(name)))

    db_role_access = KnowledgeDao.get_knowledge_by_access(role_id, page_num, page_size)
    total_count = KnowledgeDao.get_count_by_filter(count_filter)
    # 补充用户名
    user_ids = [access[0].user_id for access in db_role_access]
    db_users = UserDao.get_user_by_ids(user_ids)
    user_dict = {user.user_id: user.user_name for user in db_users}

    return {
        'data': [
            KnowledgeRead.validate({
                'name': access[0].name,
                'user_name': user_dict.get(access[0].user_id),
                'user_id': access[0].user_id,
                'update_time': access[0].update_time,
                'id': access[0].id
            }) for access in db_role_access
        ],
        'total':
            total_count
    }
  • 功能:根据用户角色获取有访问权限的知识库列表,并支持分页和名称过滤。
  • 参数
    • role_id:用户的角色 ID。
    • name:知识库名称的搜索关键字。
    • page_num:分页的页码。
    • page_size:每页显示的记录数。
  • 逻辑
    1. 构建过滤条件
      • 如果提供了 name,则添加一个 SQL LIKE 过滤条件,用于模糊匹配知识库名称。
    2. 查询有访问权限的知识库
      • 调用 KnowledgeDao.get_knowledge_by_access,传入 role_idpage_numpage_size,获取用户有权限访问的知识库列表。
    3. 查询总数
      • 调用 KnowledgeDao.get_count_by_filter,传入过滤条件,获取满足条件的知识库总数。
    4. 补充用户名
      • 从知识库访问列表中提取所有的 user_id
      • 调用 UserDao.get_user_by_ids 获取对应的用户信息。
      • 构建一个字典 user_dict,将 user_id 映射到 user_name
    5. 构建响应数据
      • 遍历 db_role_access,为每个知识库构建一个包含 nameuser_nameuser_idupdate_timeid 的字典,并通过 KnowledgeRead.validate 进行数据验证和序列化。
      • 返回包含 datatotal 的字典。
  • 应用场景
    • 用于前端页面展示用户有权限访问的知识库列表,并支持搜索和分页功能。

函数 get_flow_list_by_access

def get_flow_list_by_access(role_id: int, name: str, page_num: int, page_size: int):
    count_filter = []
    if name:
        count_filter.append(Flow.name.like('%{}%'.format(name)))

    db_role_access = FlowDao.get_flow_by_access(role_id, name, page_num, page_size)
    total_count = FlowDao.get_count_by_filters(count_filter)
    # 补充用户名
    user_ids = [access[0].user_id for access in db_role_access]
    db_users = UserDao.get_user_by_ids(user_ids)
    user_dict = {user.user_id: user.user_name for user in db_users}

    return {
        'data': [
            FlowRead.validate({
                'name': access[0].name,
                'user_name': user_dict.get(access[0].user_id),
                'user_id': access[0].user_id,
                'update_time': access[0].update_time,
                'id': access[0].id
            }) for access in db_role_access
        ],
        'total':
            total_count
    }
  • 功能:根据用户角色获取有访问权限的流程列表,并支持分页和名称过滤。
  • 参数
    • role_id:用户的角色 ID。
    • name:流程名称的搜索关键字。
    • page_num:分页的页码。
    • page_size:每页显示的记录数。
  • 逻辑
    1. 构建过滤条件
      • 如果提供了 name,则添加一个 SQL LIKE 过滤条件,用于模糊匹配流程名称。
    2. 查询有访问权限的流程
      • 调用 FlowDao.get_flow_by_access,传入 role_idnamepage_numpage_size,获取用户有权限访问的流程列表。
    3. 查询总数
      • 调用 FlowDao.get_count_by_filters,传入过滤条件,获取满足条件的流程总数。
    4. 补充用户名
      • 从流程访问列表中提取所有的 user_id
      • 调用 UserDao.get_user_by_ids 获取对应的用户信息。
      • 构建一个字典 user_dict,将 user_id 映射到 user_name
    5. 构建响应数据
      • 遍历 db_role_access,为每个流程构建一个包含 nameuser_nameuser_idupdate_timeid 的字典,并通过 FlowRead.validate 进行数据验证和序列化。
      • 返回包含 datatotal 的字典。
  • 应用场景
    • 用于前端页面展示用户有权限访问的流程列表,并支持搜索和分页功能。

函数 get_assistant_list_by_access

def get_assistant_list_by_access(role_id: int, name: str, page_num: int, page_size: int):
    count_filter = []
    if name:
        count_filter.append(Assistant.name.like('%{}%'.format(name)))

    db_role_access = AssistantDao.get_assistants_by_access(role_id, name, page_size, page_num)
    total_count = AssistantDao.get_count_by_filters(count_filter)
    # 补充用户名
    user_ids = [access[0].user_id for access in db_role_access]
    db_users = UserDao.get_user_by_ids(user_ids)
    user_dict = {user.user_id: user.user_name for user in db_users}

    return {
        'data': [{
            'name': access[0].name,
            'user_name': user_dict.get(access[0].user_id),
            'user_id': access[0].user_id,
            'update_time': access[0].update_time,
            'id': access[0].id
        } for access in db_role_access],
        'total':
            total_count
    }
  • 功能:根据用户角色获取有访问权限的助手列表,并支持分页和名称过滤。
  • 参数
    • role_id:用户的角色 ID。
    • name:助手名称的搜索关键字。
    • page_num:分页的页码。
    • page_size:每页显示的记录数。
  • 逻辑
    1. 构建过滤条件
      • 如果提供了 name,则添加一个 SQL LIKE 过滤条件,用于模糊匹配助手名称。
    2. 查询有访问权限的助手
      • 调用 AssistantDao.get_assistants_by_access,传入 role_idnamepage_sizepage_num,获取用户有权限访问的助手列表。
    3. 查询总数
      • 调用 AssistantDao.get_count_by_filters,传入过滤条件,获取满足条件的助手总数。
    4. 补充用户名
      • 从助手访问列表中提取所有的 user_id
      • 调用 UserDao.get_user_by_ids 获取对应的用户信息。
      • 构建一个字典 user_dict,将 user_id 映射到 user_name
    5. 构建响应数据
      • 遍历 db_role_access,为每个助手构建一个包含 nameuser_nameuser_idupdate_timeid 的字典。
      • 返回包含 datatotal 的字典。
  • 应用场景
    • 用于前端页面展示用户有权限访问的助手列表,并支持搜索和分页功能。

5. 获取当前登录用户的依赖项

函数 get_login_user

async def get_login_user(authorize: AuthJWT = Depends()) -> UserPayload:
    """
    获取当前登录的用户
    """
    # 校验是否过期,过期则直接返回http 状态码的 401
    authorize.jwt_required()

    current_user = json.loads(authorize.get_jwt_subject())
    user = UserPayload(**current_user)

    # 判断是否允许多点登录
    if not settings.get_system_login_method().allow_multi_login:
        # 获取access_token
        current_token = redis_client.get(USER_CURRENT_SESSION.format(user.user_id))
        # 登录被挤下线了,http状态码是200, status_code是特殊code
        if current_token != authorize._token:
            raise UserLoginOfflineError.http_exception()
    return user
  • 功能:获取当前登录的用户信息,并进行相关的登录状态检查。
  • 参数
    • authorizeAuthJWT 实例,通过 FastAPI 的 Depends 注入。
  • 逻辑
    1. JWT 验证
      • 调用 authorize.jwt_required(),确保请求包含有效的 JWT 令牌。
      • 如果 JWT 过期或无效,自动抛出 401 Unauthorized 异常。
    2. 获取 JWT 载荷
      • 调用 authorize.get_jwt_subject() 获取 JWT 的主题部分(通常是用户信息的 JSON 字符串)。
      • 使用 json.loads 将 JSON 字符串转换为 Python 字典。
      • 使用 UserPayload 类初始化用户信息对象。
    3. 检查多点登录
      • 调用 settings.get_system_login_method().allow_multi_login 判断系统是否允许多点登录。
      • 如果不允许多点登录:
        • 从 Redis 获取当前用户的 access_token,键格式为 USER_CURRENT_SESSION.format(user.user_id)
        • 比较 Redis 中存储的 current_token 与当前请求中的 _token
        • 如果两者不匹配,说明用户在其他地方登录,抛出 UserLoginOfflineError 异常。
    4. 返回用户信息
      • 返回 UserPayload 实例,包含用户的 ID、角色和用户名。
  • 应用场景
    • 用于保护需要登录才能访问的 API 路由,确保只有认证用户可以访问。
    • 结合 FastAPI 的依赖注入机制,简化路由函数中的用户信息获取。

函数 get_admin_user

async def get_admin_user(authorize: AuthJWT = Depends()) -> UserPayload:
    """
    获取超级管理账号,非超级管理员用户,抛出异常
    """
    login_user = await get_login_user(authorize)
    if not login_user.is_admin():
        raise UnAuthorizedError.http_exception()
    return login_user
  • 功能:获取当前登录的超级管理员用户,如果用户不是管理员,则抛出未授权异常。
  • 参数
    • authorizeAuthJWT 实例,通过 FastAPI 的 Depends 注入。
  • 逻辑
    1. 获取当前登录用户
      • 调用 get_login_user 函数,获取当前登录用户的 UserPayload 实例。
    2. 检查管理员权限
      • 调用 login_user.is_admin() 判断用户是否是管理员。
      • 如果用户不是管理员,抛出 UnAuthorizedError 异常。
    3. 返回管理员用户信息
      • 如果用户是管理员,返回 UserPayload 实例。
  • 应用场景
    • 用于保护需要管理员权限才能访问的 API 路由,确保只有管理员用户可以访问。
    • 结合 FastAPI 的依赖注入机制,简化路由函数中的管理员用户信息获取和权限校验。

6. 辅助函数 gen_user_rolegen_user_jwtget_*_list_by_access

函数 gen_user_role

def gen_user_role(db_user: User):
    # 查询用户的角色列表
    db_user_role = UserRoleDao.get_user_roles(db_user.user_id)
    role = ''
    role_ids = []
    for user_role in db_user_role:
        if user_role.role_id == 1:
            # 是管理员,忽略其他的角色
            role = 'admin'
        else:
            role_ids.append(user_role.role_id)
    if role != 'admin':
        # 判断是否是用户组管理员
        db_user_groups = UserGroupDao.get_user_admin_group(db_user.user_id)
        if len(db_user_groups) > 0:
            role = 'group_admin'
        else:
            role = role_ids
    # 获取用户的菜单栏权限列表
    web_menu = RoleAccessDao.get_role_access(role_ids, AccessType.WEB_MENU)
    web_menu = list(set([one.third_id for one in web_menu]))
    return role, web_menu
  • 功能:生成用户的角色和菜单权限。
  • 参数
    • db_user:数据库中的 User 实例。
  • 逻辑
    1. 查询用户角色
      • 调用 UserRoleDao.get_user_roles 获取用户的角色列表。
    2. 确定用户角色
      • 遍历用户的角色列表:
        • 如果用户拥有角色 ID 为 1,则将角色设置为 'admin',并忽略其他角色。
        • 否则,将角色 ID 添加到 role_ids 列表中。
    3. 检查用户组管理员身份
      • 如果用户不是管理员,调用 UserGroupDao.get_user_admin_group 检查用户是否是某个用户组的管理员。
      • 如果是用户组管理员,设置角色为 'group_admin'
      • 否则,角色保持为角色 ID 的列表。
    4. 获取菜单权限
      • 调用 RoleAccessDao.get_role_access,传入 role_idsAccessType.WEB_MENU,获取用户的菜单权限。
      • 提取并去重 third_id,形成 web_menu 列表。
    5. 返回角色和菜单权限
      • 返回用户的角色(字符串或列表)和菜单权限列表。
  • 应用场景
    • 在用户登录后,生成用户的角色和权限信息,用于前端展示和权限校验。

函数 gen_user_jwt

def gen_user_jwt(db_user: User):
    if 1 == db_user.delete:
        raise HTTPException(status_code=500, detail='该账号已被禁用,请联系管理员')
    # 查询角色
    role, web_menu = gen_user_role(db_user)
    # 生成JWT令牌
    payload = {'user_name': db_user.user_name, 'user_id': db_user.user_id, 'role': role}
    # Create the tokens and passing to set_access_cookies or set_refresh_cookies
    access_token = AuthJWT().create_access_token(subject=json.dumps(payload),
                                                 expires_time=ACCESS_TOKEN_EXPIRE_TIME)

    refresh_token = AuthJWT().create_refresh_token(subject=db_user.user_name)

    # Set the JWT cookies in the response
    return access_token, refresh_token, role, web_menu
  • 功能:为用户生成 JWT 访问令牌和刷新令牌,并返回角色和菜单权限信息。
  • 参数
    • db_user:数据库中的 User 实例。
  • 逻辑
    1. 检查用户状态
      • 如果 db_user.delete 等于 1,表示账号被禁用,抛出 HTTPException,状态码为 500,提示用户联系管理员。
    2. 获取用户角色和菜单权限
      • 调用 gen_user_role 函数,获取用户的角色和菜单权限。
    3. 生成 JWT 载荷
      • 创建一个包含 user_nameuser_idrole 的字典 payload
    4. 生成访问令牌和刷新令牌
      • 调用 AuthJWT().create_access_token 生成访问令牌,主题为 payload 的 JSON 字符串,过期时间为 ACCESS_TOKEN_EXPIRE_TIME
      • 调用 AuthJWT().create_refresh_token 生成刷新令牌,主题为用户的用户名。
    5. 返回令牌和权限信息
      • 返回访问令牌、刷新令牌、角色和菜单权限。
  • 应用场景
    • 在用户成功登录后,生成并返回 JWT 令牌,用于后续的认证和授权。

函数 get_knowledge_list_by_access

def get_knowledge_list_by_access(role_id: int, name: str, page_num: int, page_size: int):
    count_filter = []
    if name:
        count_filter.append(Knowledge.name.like('%{}%'.format(name)))

    db_role_access = KnowledgeDao.get_knowledge_by_access(role_id, page_num, page_size)
    total_count = KnowledgeDao.get_count_by_filter(count_filter)
    # 补充用户名
    user_ids = [access[0].user_id for access in db_role_access]
    db_users = UserDao.get_user_by_ids(user_ids)
    user_dict = {user.user_id: user.user_name for user in db_users}

    return {
        'data': [
            KnowledgeRead.validate({
                'name': access[0].name,
                'user_name': user_dict.get(access[0].user_id),
                'user_id': access[0].user_id,
                'update_time': access[0].update_time,
                'id': access[0].id
            }) for access in db_role_access
        ],
        'total':
            total_count
    }
  • 功能:根据用户角色获取有访问权限的知识库列表,并支持分页和名称过滤。
  • 参数
    • role_id:用户的角色 ID。
    • name:知识库名称的搜索关键字。
    • page_num:分页的页码。
    • page_size:每页显示的记录数。
  • 逻辑
    1. 构建过滤条件
      • 如果提供了 name,则添加一个 SQL LIKE 过滤条件,用于模糊匹配知识库名称。
    2. 查询有访问权限的知识库
      • 调用 KnowledgeDao.get_knowledge_by_access,传入 role_idpage_numpage_size,获取用户有权限访问的知识库列表。
    3. 查询总数
      • 调用 KnowledgeDao.get_count_by_filter,传入过滤条件,获取满足条件的知识库总数。
    4. 补充用户名
      • 从知识库访问列表中提取所有的 user_id
      • 调用 UserDao.get_user_by_ids 获取对应的用户信息。
      • 构建一个字典 user_dict,将 user_id 映射到 user_name
    5. 构建响应数据
      • 遍历 db_role_access,为每个知识库构建一个包含 nameuser_nameuser_idupdate_timeid 的字典,并通过 KnowledgeRead.validate 进行数据验证和序列化。
      • 返回包含 datatotal 的字典。
  • 应用场景
    • 用于前端页面展示用户有权限访问的知识库列表,并支持搜索和分页功能。

函数 get_flow_list_by_access

def get_flow_list_by_access(role_id: int, name: str, page_num: int, page_size: int):
    count_filter = []
    if name:
        count_filter.append(Flow.name.like('%{}%'.format(name)))

    db_role_access = FlowDao.get_flow_by_access(role_id, name, page_num, page_size)
    total_count = FlowDao.get_count_by_filters(count_filter)
    # 补充用户名
    user_ids = [access[0].user_id for access in db_role_access]
    db_users = UserDao.get_user_by_ids(user_ids)
    user_dict = {user.user_id: user.user_name for user in db_users}

    return {
        'data': [
            FlowRead.validate({
                'name': access[0].name,
                'user_name': user_dict.get(access[0].user_id),
                'user_id': access[0].user_id,
                'update_time': access[0].update_time,
                'id': access[0].id
            }) for access in db_role_access
        ],
        'total':
            total_count
    }
  • 功能:根据用户角色获取有访问权限的流程列表,并支持分页和名称过滤。
  • 参数
    • role_id:用户的角色 ID。
    • name:流程名称的搜索关键字。
    • page_num:分页的页码。
    • page_size:每页显示的记录数。
  • 逻辑
    1. 构建过滤条件
      • 如果提供了 name,则添加一个 SQL LIKE 过滤条件,用于模糊匹配流程名称。
    2. 查询有访问权限的流程
      • 调用 FlowDao.get_flow_by_access,传入 role_idnamepage_numpage_size,获取用户有权限访问的流程列表。
    3. 查询总数
      • 调用 FlowDao.get_count_by_filters,传入过滤条件,获取满足条件的流程总数。
    4. 补充用户名
      • 从流程访问列表中提取所有的 user_id
      • 调用 UserDao.get_user_by_ids 获取对应的用户信息。
      • 构建一个字典 user_dict,将 user_id 映射到 user_name
    5. 构建响应数据
      • 遍历 db_role_access,为每个流程构建一个包含 nameuser_nameuser_idupdate_timeid 的字典,并通过 FlowRead.validate 进行数据验证和序列化。
      • 返回包含 datatotal 的字典。
  • 应用场景
    • 用于前端页面展示用户有权限访问的流程列表,并支持搜索和分页功能。

函数 get_assistant_list_by_access

def get_assistant_list_by_access(role_id: int, name: str, page_num: int, page_size: int):
    count_filter = []
    if name:
        count_filter.append(Assistant.name.like('%{}%'.format(name)))

    db_role_access = AssistantDao.get_assistants_by_access(role_id, name, page_size, page_num)
    total_count = AssistantDao.get_count_by_filters(count_filter)
    # 补充用户名
    user_ids = [access[0].user_id for access in db_role_access]
    db_users = UserDao.get_user_by_ids(user_ids)
    user_dict = {user.user_id: user.user_name for user in db_users}

    return {
        'data': [{
            'name': access[0].name,
            'user_name': user_dict.get(access[0].user_id),
            'user_id': access[0].user_id,
            'update_time': access[0].update_time,
            'id': access[0].id
        } for access in db_role_access],
        'total':
            total_count
    }
  • 功能:根据用户角色获取有访问权限的助手列表,并支持分页和名称过滤。
  • 参数
    • role_id:用户的角色 ID。
    • name:助手名称的搜索关键字。
    • page_num:分页的页码。
    • page_size:每页显示的记录数。
  • 逻辑
    1. 构建过滤条件
      • 如果提供了 name,则添加一个 SQL LIKE 过滤条件,用于模糊匹配助手名称。
    2. 查询有访问权限的助手
      • 调用 AssistantDao.get_assistants_by_access,传入 role_idnamepage_sizepage_num,获取用户有权限访问的助手列表。
    3. 查询总数
      • 调用 AssistantDao.get_count_by_filters,传入过滤条件,获取满足条件的助手总数。
    4. 补充用户名
      • 从助手访问列表中提取所有的 user_id
      • 调用 UserDao.get_user_by_ids 获取对应的用户信息。
      • 构建一个字典 user_dict,将 user_id 映射到 user_name
    5. 构建响应数据
      • 遍历 db_role_access,为每个助手构建一个包含 nameuser_nameuser_idupdate_timeid 的字典。
      • 返回包含 datatotal 的字典。
  • 应用场景
    • 用于前端页面展示用户有权限访问的助手列表,并支持搜索和分页功能。

7. 认证和权限获取

函数 get_login_user

async def get_login_user(authorize: AuthJWT = Depends()) -> UserPayload:
    """
    获取当前登录的用户
    """
    # 校验是否过期,过期则直接返回http 状态码的 401
    authorize.jwt_required()

    current_user = json.loads(authorize.get_jwt_subject())
    user = UserPayload(**current_user)

    # 判断是否允许多点登录
    if not settings.get_system_login_method().allow_multi_login:
        # 获取access_token
        current_token = redis_client.get(USER_CURRENT_SESSION.format(user.user_id))
        # 登录被挤下线了,http状态码是200, status_code是特殊code
        if current_token != authorize._token:
            raise UserLoginOfflineError.http_exception()
    return user
  • 功能:获取当前登录的用户信息,并进行多点登录状态检查。
  • 参数
    • authorizeAuthJWT 实例,通过 FastAPI 的 Depends 注入。
  • 逻辑
    1. JWT 验证
      • 调用 authorize.jwt_required(),确保请求包含有效的 JWT 令牌。
      • 如果 JWT 过期或无效,自动抛出 401 Unauthorized 异常。
    2. 获取 JWT 载荷
      • 调用 authorize.get_jwt_subject() 获取 JWT 的主题部分(通常是用户信息的 JSON 字符串)。
      • 使用 json.loads 将 JSON 字符串转换为 Python 字典。
      • 使用 UserPayload 类初始化用户信息对象。
    3. 检查多点登录
      • 调用 settings.get_system_login_method().allow_multi_login 判断系统是否允许多点登录。
      • 如果不允许多点登录:
        • 从 Redis 获取当前用户的 access_token,键格式为 USER_CURRENT_SESSION.format(user.user_id)
        • 比较 Redis 中存储的 current_token 与当前请求中的 _token
        • 如果两者不匹配,说明用户在其他地方登录,抛出 UserLoginOfflineError 异常。
    4. 返回用户信息
      • 返回 UserPayload 实例,包含用户的 ID、角色和用户名。
  • 应用场景
    • 用于保护需要登录才能访问的 API 路由,确保只有认证用户可以访问。
    • 结合 FastAPI 的依赖注入机制,简化路由函数中的用户信息获取。

函数 get_admin_user

async def get_admin_user(authorize: AuthJWT = Depends()) -> UserPayload:
    """
    获取超级管理账号,非超级管理员用户,抛出异常
    """
    login_user = await get_login_user(authorize)
    if not login_user.is_admin():
        raise UnAuthorizedError.http_exception()
    return login_user
  • 功能:获取当前登录的超级管理员用户,如果用户不是管理员,则抛出未授权异常。
  • 参数
    • authorizeAuthJWT 实例,通过 FastAPI 的 Depends 注入。
  • 逻辑
    1. 获取当前登录用户
      • 调用 get_login_user 函数,获取当前登录用户的 UserPayload 实例。
    2. 检查管理员权限
      • 调用 login_user.is_admin() 判断用户是否是管理员。
      • 如果用户不是管理员,抛出 UnAuthorizedError 异常。
    3. 返回管理员用户信息
      • 如果用户是管理员,返回 UserPayload 实例。
  • 应用场景
    • 用于保护需要管理员权限才能访问的 API 路由,确保只有管理员用户可以访问。
    • 结合 FastAPI 的依赖注入机制,简化路由函数中的管理员用户信息获取和权限校验。

8. 其他辅助函数

函数 gen_user_role

此函数已在前面详细讲解,此处不再重复。

函数 gen_user_jwt

此函数已在前面详细讲解,此处不再重复。

函数 get_knowledge_list_by_access

此函数已在前面详细讲解,此处不再重复。

函数 get_flow_list_by_access

此函数已在前面详细讲解,此处不再重复。

函数 get_assistant_list_by_access

此函数已在前面详细讲解,此处不再重复。

9. 总结

这段代码主要涵盖了以下几个关键方面:

  1. 用户身份验证
    • 使用 JWT 进行用户认证,确保每个请求都携带有效的令牌。
    • 提供了 get_login_userget_admin_user 函数,用于获取当前登录用户的信息,并进行权限校验。
  2. 权限管理
    • UserPayload 类封装了用户的角色和权限信息,提供了权限检查的方法。
    • 通过装饰器 wrapper_access_check 简化权限检查逻辑,提高代码复用性。
  3. 用户管理
    • UserService 类提供了用户创建和密码解密的功能。
    • 处理用户注册时的逻辑,如检查用户名是否存在、分配用户组和角色等。
  4. 与数据库交互
    • 使用 DAO(数据访问对象)模式,通过 UserDaoFlowDaoKnowledgeDao 等类与数据库进行交互。
    • 通过 UserRoleDaoRoleAccessDao 等类管理用户角色和权限。
  5. 缓存和配置管理
    • 使用 Redis 作为缓存,存储和获取用户会话信息和 RSA 密钥。
    • 通过 settings 对象管理项目的配置参数,如数据库连接字符串、是否允许多点登录等。
  6. 异常处理
    • 定义了自定义的异常类,用于统一错误响应和处理。
    • 在关键业务逻辑中抛出自定义异常,确保错误信息的一致性和可读性。
  7. 实用功能
    • 提供了 generate_uuid 函数,用于生成唯一的标识符。
    • 提供了 md5_hash 和 RSA 解密功能,增强密码的安全性。

如何在项目中使用

  • 保护路由

    • 使用 Depends(get_login_user)Depends(get_admin_user) 作为路由的依赖项,确保只有经过认证的用户或管理员可以访问特定的路由。
    from fastapi import APIRouter, Depends
    
    router = APIRouter()
    
    @router.get("/protected-route")
    async def protected_route(user: UserPayload = Depends(get_login_user)):
        return {"message": f"Hello, {user.user_name}!"}
    
    @router.get("/admin-route")
    async def admin_route(user: UserPayload = Depends(get_admin_user)):
        return {"message": f"Welcome, admin {user.user_name}!"}
    
  • 创建用户

    • 使用 UserService.create_user 方法处理用户创建逻辑。
    @router.post("/create-user")
    async def create_user(request: Request, req_data: CreateUserReq, user: UserPayload = Depends(get_admin_user)):
        new_user = UserService.create_user(request, user, req_data)
        return {"user_id": new_user.user_id, "user_name": new_user.user_name}
    
  • 生成和返回 JWT

    • 在用户登录成功后,调用 gen_user_jwt 函数生成 JWT 令牌,并将其返回给前端。
    @router.post("/login")
    async def login(credentials: LoginSchema):
        db_user = UserDao.authenticate(credentials.username, credentials.password)
        if not db_user:
            raise HTTPException(status_code=401, detail="Invalid credentials")
        access_token, refresh_token, role, web_menu = gen_user_jwt(db_user)
        return {"access_token": access_token, "refresh_token": refresh_token, "role": role, "web_menu": web_menu}
    

进一步阅读

  • FastAPI 文档
    • FastAPI 官方文档 提供了详细的框架使用指南和最佳实践。
  • Pydantic 文档
    • Pydantic 官方文档 详细介绍了数据模型和验证功能。
  • JWT 认证
    • 学习 JWT(JSON Web Token) 的基本概念和使用方法,了解如何在 API 中实现安全的认证机制。
  • Redis 缓存
    • Redis 官方文档 了解如何使用 Redis 进行缓存和会话管理。
  • RSA 加密
    • 学习 RSA 加密算法 的基本原理和应用场景,了解如何在 Python 中使用 rsa 库进行加密和解密操作。

常见问题

  1. 为什么使用装饰器进行权限检查?
    • 使用装饰器可以将权限检查逻辑与业务逻辑分离,提高代码的可读性和可维护性。通过装饰器,可以复用权限检查逻辑,避免在每个方法中重复编写相同的代码。
  2. 如何确保 JWT 的安全性?
    • 确保 JWT 使用强密码进行签名,避免暴露私钥。
    • 设置合理的过期时间,防止令牌被滥用。
    • 使用 HTTPS 保护令牌在传输过程中的安全性。
  3. 为什么需要检查多点登录?
    • 检查多点登录可以防止用户账号被多人同时使用,提升账号的安全性。
    • 如果系统不允许多点登录,当用户在其他设备登录时,可以自动将之前的会话踢下线。
  4. 如何处理 RSA 密钥管理?
    • 确保 RSA 密钥安全存储,不要将私钥暴露在代码库中。
    • 使用环境变量或安全的密钥管理服务(如 AWS KMS)来存储和管理 RSA 密钥。
  5. 如何优化数据库查询性能?
    • 确保数据库表中有适当的索引,尤其是在常用的查询字段上。
    • 使用分页查询减少一次查询返回的数据量。
    • 避免 N+1 查询问题,通过使用 JOIN 或预加载相关数据提高查询效率。

http://www.kler.cn/a/414527.html

相关文章:

  • 【漏洞复现】|百易云资产管理运营系统/mobilefront/c/2.php前台文件上传
  • 【机器学习算法】Adaboost原理及实现
  • 操作系统 内存管理——针对实习面试
  • 全景图像(Panorama Image)向透视图像(Perspective Image)的跨视图转化(Cross-view)
  • Taro 鸿蒙技术内幕系列(三) - 多语言场景下的通用事件系统设计
  • 1138:将字符串中的小写字母转换成大写字母
  • 霍夫变换:原理剖析与 OpenCV 应用实例
  • Leetcode:349. 两个数组的交集
  • 大数据挖掘实战-PyODPS基础操作
  • 相机学习笔记——工业相机的基本参数
  • 详解 PyTorch 图像预处理:使用 torchvision.transforms 实现有效的数据处理
  • 如何利用Java爬虫获取店铺详情:一篇详尽指南
  • C++算法练习-day47——450.删除二叉搜索树中的节点
  • 我们项目要升级到flutter架构的几点原因
  • elasticsearch集群部署及加密通讯
  • 架构-微服务-环境搭建
  • ubuntu连接副屏显示器不量的一系列踩坑记录
  • 【PGCCC】Postgresql BRIN 索引原理
  • Jenkins Nginx Vue项目自动化部署
  • faiss库中ivf-sq(ScalarQuantizer,标量量化)代码解读-2
  • 淘宝关键词挖掘:Python爬虫技术在电商领域的应用
  • 虚拟现实(VR)与增强现实(AR)有什么区别?
  • 【k8s深入理解之 Scheme 补充-6】理解资源外部版本之间的优先级
  • TypeScript中function和const定义函数的区别
  • java 排序 详解
  • 【Unity基础】初识Unity中的渲染管线