Chainlit: Quickly Building an AI Chat App and Persisting Chat History to a MySQL Relational Database
Overview
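By default, Chainlit apps do not persist the chats and elements they generate. As soon as the page is refreshed, all chat history on the page disappears. However, being able to store and reuse this data can be an important part of your project or organization.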
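I previously wrote an article, "Chainlit: Quickly Building an AI Chat App and Persisting Chat Data to a Local SQLite Database". That approach has an advantage: you do not need to install a database or create table structures yourself. Its drawback is that it only suits a small number of users. A MySQL database can handle chat-history access for a medium-sized user base.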
Tutorial
1. Install the Chainlit dependencies
pip install chainlit aiomysql pymysql cryptography sqlalchemy
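- aiomysql: asynchronous MySQL driver (used by SQLAlchemy in the data layer through the mysql+aiomysql URL)
- pymysql: synchronous MySQL driver
- sqlalchemy: SQL toolkit and object-relational mapping (ORM) library
- cryptography: an open-source Python package that provides easy-to-use cryptographic tools and algorithms; PyMySQL needs it for the sha256_password and caching_sha2_password authentication methods (the latter is the default in MySQL 8)
- chainlit: an open-source framework for quickly building and deploying conversational applications such as chatbots and virtual assistants
- The code below also imports openai, aiofiles and aiohttp; install them too if they are not already present in your environment.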
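To confirm the packages above are importable in the current environment before going further, a small check like the following can help. It is a sketch that uses only the standard library; the module names are the import names of the packages listed above, and the helper name `missing_modules` is made up for illustration.

```python
import importlib.util
from typing import Iterable, List

# Import names of the packages installed above (import name == package name here)
REQUIRED_MODULES = ["chainlit", "aiomysql", "pymysql", "cryptography", "sqlalchemy"]


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are installed.")
```

`find_spec` only locates the module and does not import it, so the check is cheap and has no side effects.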
2. Configure environment variables
In the project root directory, create a .env file with the following content:
OPENAI_BASE_URL="https://dashscope.aliyuncs.com/compatible-mode/v1"
OPENAI_API_KEY="your api_key"
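- Because OpenAI's ChatGPT API cannot be reached directly from mainland China, OPENAI_BASE_URL must point to a proxy address. If you use a domestic LLM service instead, set it to that service's OpenAI-compatible endpoint (the example above uses Alibaba Cloud DashScope's compatible mode).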
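The OpenAI Python SDK sends chat requests to `{OPENAI_BASE_URL}/chat/completions`, which is why swapping the base URL for an OpenAI-compatible endpoint is enough to switch providers. The sketch below only illustrates that URL composition with a hypothetical helper `chat_completions_url`; it does not call the SDK or the network.

```python
import os

# The SDK's default base URL when OPENAI_BASE_URL is not set
DEFAULT_BASE_URL = "https://api.openai.com/v1"


def chat_completions_url(base_url: str) -> str:
    """Join a base URL and the chat completions path, tolerating a trailing slash."""
    return base_url.rstrip("/") + "/chat/completions"


if __name__ == "__main__":
    base = os.environ.get("OPENAI_BASE_URL", DEFAULT_BASE_URL)
    print(chat_completions_url(base))
```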
Install the MySQL database
You can follow the article "MySQL Installation and Configuration Tutorial | MySQL Basics", or search online for installation instructions.
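After installing MySQL, use a database management tool such as Navicat to create a database, for example one named chain_lit (any name works), and then import the following SQL to create the table structure: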
SET NAMES utf8mb4;
SET FOREIGN_KEY_CHECKS = 0;
-- ----------------------------
-- Table structure for elements
-- ----------------------------
DROP TABLE IF EXISTS `elements`;
CREATE TABLE `elements` (
`id` char(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,
`threadId` char(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL DEFAULT NULL,
`type` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL DEFAULT NULL,
`url` text CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL,
`chainlitKey` text CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL,
`name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,
`display` text CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL,
`objectKey` text CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL,
`size` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL DEFAULT NULL,
`page` int NULL DEFAULT NULL,
`language` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL DEFAULT NULL,
`forId` char(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL DEFAULT NULL,
`mime` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL DEFAULT NULL,
PRIMARY KEY (`id`) USING BTREE
) ENGINE = InnoDB CHARACTER SET = utf8mb4 COLLATE = utf8mb4_bin ROW_FORMAT = Dynamic;
-- ----------------------------
-- Table structure for feedbacks
-- ----------------------------
DROP TABLE IF EXISTS `feedbacks`;
CREATE TABLE `feedbacks` (
`id` char(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,
`forId` char(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,
`threadId` char(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,
`value` int NOT NULL,
`comment` text CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL,
PRIMARY KEY (`id`) USING BTREE
) ENGINE = InnoDB CHARACTER SET = utf8mb4 COLLATE = utf8mb4_bin ROW_FORMAT = Dynamic;
-- ----------------------------
-- Table structure for steps
-- ----------------------------
DROP TABLE IF EXISTS `steps`;
CREATE TABLE `steps` (
`id` char(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,
`name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,
`type` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,
`threadId` char(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,
`parentId` char(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL DEFAULT NULL,
`disableFeedback` tinyint(1) NOT NULL DEFAULT 1,
`streaming` tinyint(1) NOT NULL,
`waitForAnswer` tinyint(1) NULL DEFAULT NULL,
`isError` tinyint(1) NULL DEFAULT NULL,
`metadata` json NULL,
`tags` json NULL,
`input` text CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL,
`output` text CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL,
`createdAt` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL DEFAULT NULL,
`start` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL DEFAULT NULL,
`end` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL DEFAULT NULL,
`generation` json NULL,
`showInput` text CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL,
`language` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL DEFAULT NULL,
`indent` int NULL DEFAULT NULL,
PRIMARY KEY (`id`) USING BTREE
) ENGINE = InnoDB CHARACTER SET = utf8mb4 COLLATE = utf8mb4_bin ROW_FORMAT = Dynamic;
-- ----------------------------
-- Table structure for threads
-- ----------------------------
DROP TABLE IF EXISTS `threads`;
CREATE TABLE `threads` (
`id` char(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,
`createdAt` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL DEFAULT NULL,
`name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL DEFAULT NULL,
`userId` char(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL DEFAULT NULL,
`userIdentifier` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL DEFAULT NULL,
`tags` json NULL,
`metadata` json NULL,
PRIMARY KEY (`id`) USING BTREE,
INDEX `userId`(`userId` ASC) USING BTREE,
CONSTRAINT `threads_ibfk_1` FOREIGN KEY (`userId`) REFERENCES `users` (`id`) ON DELETE CASCADE ON UPDATE RESTRICT
) ENGINE = InnoDB CHARACTER SET = utf8mb4 COLLATE = utf8mb4_bin ROW_FORMAT = Dynamic;
-- ----------------------------
-- Table structure for users
-- ----------------------------
DROP TABLE IF EXISTS `users`;
CREATE TABLE `users` (
`id` char(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,
`identifier` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NOT NULL,
`metadata` json NOT NULL,
`createdAt` varchar(50) CHARACTER SET utf8mb4 COLLATE utf8mb4_bin NULL DEFAULT NULL,
PRIMARY KEY (`id`) USING BTREE,
UNIQUE INDEX `identifier`(`identifier` ASC) USING BTREE
) ENGINE = InnoDB CHARACTER SET = utf8mb4 COLLATE = utf8mb4_bin ROW_FORMAT = Dynamic;
SET FOREIGN_KEY_CHECKS = 1;
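After importing the script, you can check that all five tables exist. The sketch below compares the rows returned by `SHOW TABLES` with the tables the data layer uses. The commented-out connection part assumes PyMySQL and the example credentials used later in app.py, so adjust them to your setup.

```python
from typing import Iterable, List, Tuple

# Tables created by the SQL script above and used by MysqlDataLayer
REQUIRED_TABLES = {"users", "threads", "steps", "elements", "feedbacks"}


def missing_tables(show_tables_rows: Iterable[Tuple[str, ...]]) -> List[str]:
    """Given rows returned by `SHOW TABLES`, return the required tables that are absent."""
    existing = {row[0] for row in show_tables_rows}
    return sorted(REQUIRED_TABLES - existing)


# Example usage against a real server (requires PyMySQL and a running MySQL):
# import pymysql
# conn = pymysql.connect(host="127.0.0.1", user="root", password="123456", database="chain_lit")
# with conn.cursor() as cur:
#     cur.execute("SHOW TABLES")
#     print(missing_tables(cur.fetchall()))
```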
3. Create the code
In the project root directory, create a mysql_client.py file with the following code:
from typing import Optional

import pymysql
from chainlit.data import BaseStorageClient
from chainlit.logger import logger


class MysqlStorageClient(BaseStorageClient):
    """
    Class to enable storage in a MySQL database.

    Params:
        host: Hostname or IP address of the MySQL server.
        dbname: Name of the database to connect to.
        user: User name used to authenticate.
        password: Password used to authenticate.
        port: Port number to connect to (default: 3306).

    Note: this class only opens a synchronous PyMySQL connection. It does not
    implement upload_file(), so file elements cannot actually be uploaded through it.
    """

    def __init__(self, host: str, dbname: str, user: str, password: str, port: int = 3306):
        self.conn: Optional[pymysql.connections.Connection] = None
        try:
            # utf8mb4 matches the character set of the tables created above
            self.conn = pymysql.connect(
                host=host,
                port=port,
                user=user,
                password=password,
                database=dbname,
                charset="utf8mb4",
            )
            logger.info("MysqlStorageClient initialized")
        except Exception as e:
            logger.warning(f"MysqlStorageClient initialization error: {e}")
In the project root directory, create a mysql_data.py file with the following code:
import json
import ssl
import uuid
from dataclasses import asdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import aiofiles
import aiohttp
from chainlit.context import context
from chainlit.data import BaseDataLayer, BaseStorageClient, queue_until_user_message
from chainlit.element import ElementDict  # needed at runtime: ElementDict(...) is called below
from chainlit.logger import logger
from chainlit.step import StepDict
from chainlit.types import (
    Feedback,
    FeedbackDict,
    PageInfo,
    PaginatedResponse,
    Pagination,
    ThreadDict,
    ThreadFilter,
)
from chainlit.user import PersistedUser, User
from literalai.helper import utc_now
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

if TYPE_CHECKING:
    from chainlit.element import Element


class MysqlDataLayer(BaseDataLayer):
    def __init__(
        self,
        conninfo: str,
        ssl_require: bool = False,
        storage_provider: Optional[BaseStorageClient] = None,
        user_thread_limit: Optional[int] = 1000,
        show_logger: Optional[bool] = False,
    ):
        self._conninfo = conninfo
        self.user_thread_limit = user_thread_limit
        self.show_logger = show_logger
        ssl_args = {}
        if ssl_require:
            # Create an SSL context to require an SSL connection
            ssl_context = ssl.create_default_context()
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE
            ssl_args["ssl"] = ssl_context
        self.engine: AsyncEngine = create_async_engine(
            self._conninfo, connect_args=ssl_args
        )
        self.async_session = sessionmaker(
            bind=self.engine, expire_on_commit=False, class_=AsyncSession
        )  # type: ignore
        if storage_provider:
            self.storage_provider: Optional[BaseStorageClient] = storage_provider
            if self.show_logger:
                logger.info("SQLAlchemyDataLayer storage client initialized")
        else:
            self.storage_provider = None
            logger.warning(
                "SQLAlchemyDataLayer storage client is not initialized and elements will not be persisted!"
            )

    async def build_debug_url(self) -> str:
        return ""

    ###### SQL Helpers ######
    async def execute_sql(
        self, query: str, parameters: dict
    ) -> Union[List[Dict[str, Any]], int, None]:
        parameterized_query = text(query)
        async with self.async_session() as session:
            try:
                await session.begin()
                result = await session.execute(parameterized_query, parameters)
                await session.commit()
                if result.returns_rows:
                    json_result = [dict(row._mapping) for row in result.fetchall()]
                    clean_json_result = self.clean_result(json_result)
                    return clean_json_result
                else:
                    return result.rowcount
            except SQLAlchemyError as e:
                await session.rollback()
                logger.warning(f"An error occurred: {e}")
                return None
            except Exception as e:
                await session.rollback()
                logger.warning(f"An unexpected error occurred: {e}")
                return None

    async def get_current_timestamp(self) -> str:
        return utc_now()

    def clean_result(self, obj):
        """Recursively change UUID -> str and serialize dictionaries"""
        if isinstance(obj, dict):
            return {k: self.clean_result(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self.clean_result(item) for item in obj]
        elif isinstance(obj, uuid.UUID):
            return str(obj)
        return obj

    @staticmethod
    def load_json(value: Any) -> Any:
        """MySQL JSON columns come back as strings; decode them when possible."""
        if isinstance(value, str):
            try:
                return json.loads(value)
            except ValueError:
                return value
        return value

    ###### User ######
    async def get_user(self, identifier: str) -> Optional[PersistedUser]:
        if self.show_logger:
            logger.info(f"SQLAlchemy: get_user, identifier={identifier}")
        query = "SELECT * FROM users WHERE identifier = :identifier"
        parameters = {"identifier": identifier}
        result = await self.execute_sql(query=query, parameters=parameters)
        if result and isinstance(result, list):
            user_data = result[0]
            user_data["metadata"] = self.load_json(user_data["metadata"])
            return PersistedUser(**user_data)
        return None

    async def create_user(self, user: User) -> Optional[PersistedUser]:
        if self.show_logger:
            logger.info(f"SQLAlchemy: create_user, user_identifier={user.identifier}")
        existing_user: Optional["PersistedUser"] = await self.get_user(user.identifier)
        user_dict: Dict[str, Any] = {
            "identifier": str(user.identifier),
            "metadata": json.dumps(user.metadata or {}),
        }
        if not existing_user:  # create the user
            if self.show_logger:
                logger.info("SQLAlchemy: create_user, creating the user")
            user_dict["id"] = str(uuid.uuid4())
            user_dict["createdAt"] = await self.get_current_timestamp()
            query = "INSERT INTO users (`id`, `identifier`, `createdAt`, `metadata`) VALUES (:id, :identifier, :createdAt, :metadata)"
            await self.execute_sql(query=query, parameters=user_dict)
        else:  # update the user
            if self.show_logger:
                logger.info("SQLAlchemy: update user metadata")
            query = "UPDATE users SET metadata = :metadata WHERE identifier = :identifier"
            await self.execute_sql(
                query=query, parameters=user_dict
            )  # We want to update the metadata
        return await self.get_user(user.identifier)

    ###### Threads ######
    async def get_thread_author(self, thread_id: str) -> str:
        if self.show_logger:
            logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}")
        query = """SELECT userIdentifier FROM threads WHERE id = :id"""
        parameters = {"id": thread_id}
        result = await self.execute_sql(query=query, parameters=parameters)
        if isinstance(result, list) and result:
            author_identifier = result[0].get("userIdentifier")
            if author_identifier is not None:
                return author_identifier
        raise ValueError(f"Author not found for thread_id {thread_id}")

    async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
        if self.show_logger:
            logger.info(f"SQLAlchemy: get_thread, thread_id={thread_id}")
        user_threads: Optional[List[ThreadDict]] = await self.get_all_user_threads(
            thread_id=thread_id
        )
        if user_threads:
            return user_threads[0]
        else:
            return None

    async def update_thread(
        self,
        thread_id: str,
        name: Optional[str] = None,
        user_id: Optional[str] = None,
        metadata: Optional[Dict] = None,
        tags: Optional[List[str]] = None,
    ):
        if self.show_logger:
            logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
        if context.session.user is not None:
            user_identifier = context.session.user.identifier
        else:
            raise ValueError("User not found in session context")
        data = {
            "id": thread_id,
            "createdAt": (
                await self.get_current_timestamp() if metadata is None else None
            ),
            "name": (
                name
                if name is not None
                else (metadata.get("name") if metadata and "name" in metadata else None)
            ),
            "userId": user_id,
            "userIdentifier": user_identifier,
            # JSON columns must receive serialized strings, not Python lists/dicts
            "tags": json.dumps(tags) if tags is not None else None,
            "metadata": json.dumps(metadata) if metadata else None,
        }
        parameters = {
            key: value for key, value in data.items() if value is not None
        }  # Remove keys with None values
        columns = ", ".join(f"`{key}`" for key in parameters.keys())
        values = ", ".join(f":{key}" for key in parameters.keys())
        updates = ", ".join(
            f"`{key}` = VALUES(`{key}`)" for key in parameters.keys() if key != "id"
        )
        query = f"""
            INSERT INTO threads ({columns})
            VALUES ({values})
            ON DUPLICATE KEY UPDATE
            {updates};
        """
        await self.execute_sql(query=query, parameters=parameters)

    async def delete_thread(self, thread_id: str):
        if self.show_logger:
            logger.info(f"SQLAlchemy: delete_thread, thread_id={thread_id}")
        # Delete feedbacks/elements/steps/thread
        feedbacks_query = "DELETE FROM feedbacks WHERE forId IN (SELECT id FROM steps WHERE threadId = :id)"
        elements_query = "DELETE FROM elements WHERE threadId = :id"
        steps_query = "DELETE FROM steps WHERE threadId = :id"
        thread_query = "DELETE FROM threads WHERE id = :id"
        parameters = {"id": thread_id}
        await self.execute_sql(query=feedbacks_query, parameters=parameters)
        await self.execute_sql(query=elements_query, parameters=parameters)
        await self.execute_sql(query=steps_query, parameters=parameters)
        await self.execute_sql(query=thread_query, parameters=parameters)

    async def list_threads(
        self, pagination: Pagination, filters: ThreadFilter
    ) -> PaginatedResponse:
        if self.show_logger:
            logger.info(
                f"SQLAlchemy: list_threads, pagination={pagination}, filters={filters}"
            )
        if not filters.userId:
            raise ValueError("userId is required")
        all_user_threads: List[ThreadDict] = (
            await self.get_all_user_threads(user_id=filters.userId) or []
        )
        search_keyword = filters.search.lower() if filters.search else None
        feedback_value = int(filters.feedback) if filters.feedback else None
        filtered_threads = []
        for thread in all_user_threads:
            keyword_match = True
            feedback_match = True
            if search_keyword or feedback_value is not None:
                if search_keyword:
                    # output may be NULL in the database, so fall back to ""
                    keyword_match = any(
                        search_keyword in (step.get("output") or "").lower()
                        for step in thread["steps"]
                    )
                if feedback_value is not None:
                    feedback_match = False  # Assume no match until found
                    for step in thread["steps"]:
                        feedback = step.get("feedback")
                        if feedback and feedback.get("value") == feedback_value:
                            feedback_match = True
                            break
            if keyword_match and feedback_match:
                filtered_threads.append(thread)
        start = 0
        if pagination.cursor:
            for i, thread in enumerate(filtered_threads):
                if (
                    thread["id"] == pagination.cursor
                ):  # Find the start index using pagination.cursor
                    start = i + 1
                    break
        end = start + pagination.first
        paginated_threads = filtered_threads[start:end] or []
        has_next_page = len(filtered_threads) > end
        start_cursor = paginated_threads[0]["id"] if paginated_threads else None
        end_cursor = paginated_threads[-1]["id"] if paginated_threads else None
        return PaginatedResponse(
            pageInfo=PageInfo(
                hasNextPage=has_next_page,
                startCursor=start_cursor,
                endCursor=end_cursor,
            ),
            data=paginated_threads,
        )

    ###### Steps ######
    @queue_until_user_message()
    async def create_step(self, step_dict: "StepDict"):
        if self.show_logger:
            logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")
        if not getattr(context.session.user, "id", None):
            raise ValueError("No authenticated user in context")
        step_dict["showInput"] = (
            str(step_dict.get("showInput", "")).lower()
            if "showInput" in step_dict
            else None
        )
        parameters = {
            key: value
            for key, value in step_dict.items()
            if value is not None and not (isinstance(value, dict) and not value)
        }
        # The steps table has no feedback column; feedback lives in the feedbacks table
        parameters.pop("feedback", None)
        parameters["metadata"] = json.dumps(step_dict.get("metadata", {}))
        parameters["generation"] = json.dumps(step_dict.get("generation", {}))
        if "tags" in parameters:
            parameters["tags"] = json.dumps(parameters["tags"])
        columns = ", ".join(f"`{key}`" for key in parameters.keys())
        values = ", ".join(f":{key}" for key in parameters.keys())
        updates = ", ".join(
            f"`{key}` = :{key}" for key in parameters.keys() if key != "id"
        )
        query = f"""
            INSERT INTO steps ({columns})
            VALUES ({values})
            ON DUPLICATE KEY UPDATE
            {updates};
        """
        await self.execute_sql(query=query, parameters=parameters)

    @queue_until_user_message()
    async def update_step(self, step_dict: "StepDict"):
        if self.show_logger:
            logger.info(f"SQLAlchemy: update_step, step_id={step_dict.get('id')}")
        await self.create_step(step_dict)

    @queue_until_user_message()
    async def delete_step(self, step_id: str):
        if self.show_logger:
            logger.info(f"SQLAlchemy: delete_step, step_id={step_id}")
        # Delete feedbacks/elements/steps
        feedbacks_query = "DELETE FROM feedbacks WHERE forId = :id"
        elements_query = "DELETE FROM elements WHERE forId = :id"
        steps_query = "DELETE FROM steps WHERE id = :id"
        parameters = {"id": step_id}
        await self.execute_sql(query=feedbacks_query, parameters=parameters)
        await self.execute_sql(query=elements_query, parameters=parameters)
        await self.execute_sql(query=steps_query, parameters=parameters)

    ###### Feedback ######
    async def upsert_feedback(self, feedback: Feedback) -> str:
        if self.show_logger:
            logger.info(f"SQLAlchemy: upsert_feedback, feedback_id={feedback.id}")
        feedback.id = feedback.id or str(uuid.uuid4())
        feedback_dict = asdict(feedback)
        parameters = {
            key: value for key, value in feedback_dict.items() if value is not None
        }
        columns = ", ".join(f"`{key}`" for key in parameters.keys())
        values = ", ".join(f":{key}" for key in parameters.keys())
        updates = ", ".join(
            f"`{key}` = :{key}" for key in parameters.keys() if key != "id"
        )
        query = f"""
            INSERT INTO feedbacks ({columns})
            VALUES ({values})
            ON DUPLICATE KEY UPDATE
            {updates};
        """
        await self.execute_sql(query=query, parameters=parameters)
        return feedback.id

    async def delete_feedback(self, feedback_id: str) -> bool:
        if self.show_logger:
            logger.info(f"SQLAlchemy: delete_feedback, feedback_id={feedback_id}")
        query = "DELETE FROM feedbacks WHERE id = :feedback_id"
        parameters = {"feedback_id": feedback_id}
        await self.execute_sql(query=query, parameters=parameters)
        return True

    ###### Elements ######
    @queue_until_user_message()
    async def create_element(self, element: "Element"):
        if self.show_logger:
            logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")
        if not getattr(context.session.user, "id", None):
            raise ValueError("No authenticated user in context")
        if not self.storage_provider:
            logger.warning(
                "SQLAlchemy: create_element error. No blob_storage_client is configured!"
            )
            return
        if not element.for_id:
            return
        content: Optional[Union[bytes, str]] = None
        if element.path:
            async with aiofiles.open(element.path, "rb") as f:
                content = await f.read()
        elif element.url:
            async with aiohttp.ClientSession() as session:
                async with session.get(element.url) as response:
                    if response.status == 200:
                        content = await response.read()
                    else:
                        content = None
        elif element.content:
            content = element.content
        else:
            raise ValueError("Element url, path or content must be provided")
        if content is None:
            raise ValueError("Content is None, cannot upload file")
        context_user = context.session.user
        user_folder = getattr(context_user, "id", "unknown")
        file_object_key = f"{user_folder}/{element.id}" + (
            f"/{element.name}" if element.name else ""
        )
        if not element.mime:
            element.mime = "application/octet-stream"
        uploaded_file = await self.storage_provider.upload_file(
            object_key=file_object_key, data=content, mime=element.mime, overwrite=True
        )
        if not uploaded_file:
            raise ValueError(
                "SQLAlchemy Error: create_element, Failed to persist data in storage_provider"
            )
        element_dict: ElementDict = element.to_dict()
        element_dict["url"] = uploaded_file.get("url")
        element_dict["objectKey"] = uploaded_file.get("object_key")
        element_dict_cleaned = {k: v for k, v in element_dict.items() if v is not None}
        columns = ", ".join(f"`{column}`" for column in element_dict_cleaned.keys())
        placeholders = ", ".join(f":{column}" for column in element_dict_cleaned.keys())
        query = f"INSERT INTO elements ({columns}) VALUES ({placeholders})"
        await self.execute_sql(query=query, parameters=element_dict_cleaned)

    @queue_until_user_message()
    async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
        if self.show_logger:
            logger.info(f"SQLAlchemy: delete_element, element_id={element_id}")
        query = "DELETE FROM elements WHERE id = :id"
        parameters = {"id": element_id}
        await self.execute_sql(query=query, parameters=parameters)

    async def delete_user_session(self, id: str) -> bool:
        return False  # Not sure why documentation wants this

    async def get_all_user_threads(
        self, user_id: Optional[str] = None, thread_id: Optional[str] = None
    ) -> Optional[List[ThreadDict]]:
        """Fetch all user threads up to self.user_thread_limit, or one thread by id if thread_id is provided."""
        if self.show_logger:
            logger.info("SQLAlchemy: get_all_user_threads")
        user_threads_query = """
            SELECT
                id AS thread_id,
                createdAt AS thread_createdat,
                name AS thread_name,
                userId AS user_id,
                userIdentifier AS user_identifier,
                tags AS thread_tags,
                metadata AS thread_metadata
            FROM threads
            WHERE userId = :user_id OR id = :thread_id
            ORDER BY createdAt DESC
            LIMIT :limit
        """
        user_threads = await self.execute_sql(
            query=user_threads_query,
            parameters={
                "user_id": user_id,
                "limit": self.user_thread_limit,
                "thread_id": thread_id,
            },
        )
        if not isinstance(user_threads, list):
            return None
        if not user_threads:
            return []
        else:
            # Thread ids come from the database (UUID strings), so they are inlined here
            thread_ids = (
                "('"
                + "','".join(map(str, [thread["thread_id"] for thread in user_threads]))
                + "')"
            )
        steps_feedbacks_query = f"""
            SELECT
                s.id AS step_id,
                s.name AS step_name,
                s.type AS step_type,
                s.threadId AS step_threadid,
                s.parentId AS step_parentid,
                s.streaming AS step_streaming,
                s.waitForAnswer AS step_waitforanswer,
                s.isError AS step_iserror,
                s.metadata AS step_metadata,
                s.tags AS step_tags,
                s.input AS step_input,
                s.output AS step_output,
                s.createdAt AS step_createdat,
                s.start AS step_start,
                s.end AS step_end,
                s.generation AS step_generation,
                s.showInput AS step_showinput,
                s.language AS step_language,
                s.indent AS step_indent,
                f.id AS feedback_id,
                f.value AS feedback_value,
                f.comment AS feedback_comment
            FROM steps s LEFT JOIN feedbacks f ON s.id = f.forId
            WHERE s.threadId IN {thread_ids}
            ORDER BY s.createdAt ASC
        """
        steps_feedbacks = await self.execute_sql(
            query=steps_feedbacks_query, parameters={}
        )
        elements_query = f"""
            SELECT
                e.id AS element_id,
                e.threadId as element_threadid,
                e.type AS element_type,
                e.chainlitKey AS element_chainlitkey,
                e.url AS element_url,
                e.objectKey as element_objectkey,
                e.name AS element_name,
                e.display AS element_display,
                e.size AS element_size,
                e.language AS element_language,
                e.page AS element_page,
                e.forId AS element_forid,
                e.mime AS element_mime
            FROM elements e
            WHERE e.threadId IN {thread_ids}
        """
        elements = await self.execute_sql(query=elements_query, parameters={})
        thread_dicts = {}
        for thread in user_threads:
            thread_id = thread["thread_id"]
            if thread_id is not None:
                thread_dicts[thread_id] = ThreadDict(
                    id=thread_id,
                    createdAt=thread["thread_createdat"],
                    name=thread["thread_name"],
                    userId=thread["user_id"],
                    userIdentifier=thread["user_identifier"],
                    tags=self.load_json(thread["thread_tags"]),
                    metadata=self.load_json(thread["thread_metadata"]),
                    steps=[],
                    elements=[],
                )
        # Process steps_feedbacks to populate the steps in the corresponding ThreadDict
        if isinstance(steps_feedbacks, list):
            for step_feedback in steps_feedbacks:
                thread_id = step_feedback["step_threadid"]
                if thread_id is not None:
                    feedback = None
                    if step_feedback["feedback_value"] is not None:
                        feedback = FeedbackDict(
                            forId=step_feedback["step_id"],
                            id=step_feedback.get("feedback_id"),
                            value=step_feedback["feedback_value"],
                            comment=step_feedback.get("feedback_comment"),
                        )
                    step_dict = StepDict(
                        id=step_feedback["step_id"],
                        name=step_feedback["step_name"],
                        type=step_feedback["step_type"],
                        threadId=thread_id,
                        parentId=step_feedback.get("step_parentid"),
                        streaming=step_feedback.get("step_streaming", False),
                        waitForAnswer=step_feedback.get("step_waitforanswer"),
                        isError=step_feedback.get("step_iserror"),
                        metadata=self.load_json(step_feedback.get("step_metadata")) or {},
                        tags=self.load_json(step_feedback.get("step_tags")),
                        input=(
                            step_feedback.get("step_input", "")
                            if step_feedback["step_showinput"] == "true"
                            else None
                        ),
                        output=step_feedback.get("step_output", ""),
                        createdAt=step_feedback.get("step_createdat"),
                        start=step_feedback.get("step_start"),
                        end=step_feedback.get("step_end"),
                        generation=self.load_json(step_feedback.get("step_generation")),
                        showInput=step_feedback.get("step_showinput"),
                        language=step_feedback.get("step_language"),
                        indent=step_feedback.get("step_indent"),
                        feedback=feedback,
                    )
                    # Append the step to the steps list of the corresponding ThreadDict
                    thread_dicts[thread_id]["steps"].append(step_dict)
        if isinstance(elements, list):
            for element in elements:
                thread_id = element["element_threadid"]
                if thread_id is not None:
                    element_dict = ElementDict(
                        id=element["element_id"],
                        threadId=thread_id,
                        type=element["element_type"],
                        chainlitKey=element.get("element_chainlitkey"),
                        url=element.get("element_url"),
                        objectKey=element.get("element_objectkey"),
                        name=element["element_name"],
                        display=element["element_display"],
                        size=element.get("element_size"),
                        language=element.get("element_language"),
                        autoPlay=element.get("element_autoPlay"),
                        playerConfig=element.get("element_playerconfig"),
                        page=element.get("element_page"),
                        forId=element.get("element_forid"),
                        mime=element.get("element_mime"),
                    )
                    thread_dicts[thread_id]["elements"].append(element_dict)  # type: ignore
        return list(thread_dicts.values())
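Several methods above (update_thread, create_step, upsert_feedback) build the same kind of MySQL upsert from a dict of non-None values. The standalone sketch below reproduces that pattern so the generated SQL can be inspected without a database. The function name `build_upsert` is hypothetical; it uses the `:name` placeholder style that SQLAlchemy's `text()` expects.

```python
from typing import Any, Dict, Tuple


def build_upsert(table: str, data: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Build an INSERT ... ON DUPLICATE KEY UPDATE statement for MySQL.

    None values are dropped, mirroring the data layer. `id` is the primary key,
    so it is excluded from the UPDATE clause; at least one other non-None column
    is assumed, otherwise the UPDATE clause would be empty.
    """
    params = {k: v for k, v in data.items() if v is not None}
    columns = ", ".join(f"`{k}`" for k in params)
    values = ", ".join(f":{k}" for k in params)
    updates = ", ".join(f"`{k}` = :{k}" for k in params if k != "id")
    query = f"INSERT INTO {table} ({columns}) VALUES ({values}) ON DUPLICATE KEY UPDATE {updates}"
    return query, params
```

Because the same `:name` placeholder appears in both the VALUES list and the UPDATE clause, one parameter dict serves both the insert and the update path.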
In the project root directory, create an app.py file with the following code:
from typing import Optional

import chainlit as cl
import chainlit.data as cl_data
from openai import AsyncOpenAI

from mysql_client import MysqlStorageClient
from mysql_data import MysqlDataLayer

# AsyncOpenAI reads OPENAI_API_KEY and OPENAI_BASE_URL from the environment (.env)
client = AsyncOpenAI()

# Replace these MySQL connection details with your own
storage_client = MysqlStorageClient(
    host="127.0.0.1",
    dbname="chain_lit",
    port=3306,
    user="root",
    password="123456",
)

# Register the custom data layer so Chainlit persists threads, steps and feedback
cl_data._data_layer = MysqlDataLayer(
    conninfo="mysql+aiomysql://root:123456@127.0.0.1:3306/chain_lit",
    storage_provider=storage_client,
)


@cl.on_chat_start
async def main():
    content = "Hello, I am the Taishan AI customer service assistant. How can I help you?"
    await cl.Message(content).send()


@cl.on_message
async def handle_message():
    # Wait for queue to be flushed
    await cl.sleep(1)
    msg = cl.Message(content="")
    await msg.send()
    stream = await client.chat.completions.create(
        model="qwen-turbo", messages=cl.chat_context.to_openai(), stream=True
    )
    async for part in stream:
        if token := part.choices[0].delta.content or "":
            await msg.stream_token(token)
    await msg.update()


@cl.password_auth_callback
def auth_callback(username: str, password: str) -> Optional[cl.User]:
    # Demo credentials only; persistence needs an authenticated user
    if (username, password) == ("admin", "admin"):
        return cl.User(identifier="admin")
    else:
        return None


@cl.on_chat_resume
async def on_chat_resume():
    # Defining this hook lets users reopen past conversations from the history sidebar
    pass
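The `handle_message` loop streams tokens as they arrive and skips chunks whose `delta.content` is `None` (for example a final chunk that only carries a finish reason). The sketch below replays that filtering on hand-made chunk objects built with `SimpleNamespace`, so it runs without an API key; the fake chunks only imitate the `choices[0].delta.content` shape of OpenAI streaming chunks, and the helper names are hypothetical.

```python
import asyncio
from types import SimpleNamespace
from typing import AsyncIterator, List, Optional


def make_chunk(content: Optional[str]) -> SimpleNamespace:
    """Build an object shaped like an OpenAI streaming chunk."""
    delta = SimpleNamespace(content=content)
    return SimpleNamespace(choices=[SimpleNamespace(delta=delta)])


async def fake_stream(pieces: List[Optional[str]]) -> AsyncIterator[SimpleNamespace]:
    """Yield fake chunks one by one, like an async OpenAI stream."""
    for piece in pieces:
        yield make_chunk(piece)


async def collect(stream: AsyncIterator[SimpleNamespace]) -> str:
    """Same filtering as handle_message: keep only non-empty token strings."""
    tokens = []
    async for part in stream:
        if token := part.choices[0].delta.content or "":
            tokens.append(token)  # in the app this is: await msg.stream_token(token)
    return "".join(tokens)


if __name__ == "__main__":
    print(asyncio.run(collect(fake_stream(["Hel", "lo", None, "!"]))))
```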
- Replace the MySQL connection information in the code (both the MysqlStorageClient arguments and the conninfo URL) with your own.
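Instead of hard-coding the same credentials twice, one option is to read them from environment variables and derive both the MysqlStorageClient arguments and the SQLAlchemy URL from one place. The variable names `MYSQL_HOST`, `MYSQL_PORT`, `MYSQL_USER`, `MYSQL_PASSWORD` and `MYSQL_DB` below are hypothetical, as are the helper names. The user and password are percent-encoded because characters such as `@` or `/` would otherwise break the URL.

```python
import os
from typing import Any, Dict, Mapping
from urllib.parse import quote_plus


def mysql_settings(env: Mapping[str, str] = os.environ) -> Dict[str, Any]:
    """Read MySQL settings from (hypothetical) MYSQL_* environment variables."""
    return {
        "host": env.get("MYSQL_HOST", "127.0.0.1"),
        "port": int(env.get("MYSQL_PORT", "3306")),
        "user": env.get("MYSQL_USER", "root"),
        "password": env.get("MYSQL_PASSWORD", ""),
        "dbname": env.get("MYSQL_DB", "chain_lit"),
    }


def sqlalchemy_url(s: Dict[str, Any]) -> str:
    """Build the mysql+aiomysql URL used as MysqlDataLayer's conninfo."""
    user = quote_plus(s["user"])
    password = quote_plus(s["password"])
    return f"mysql+aiomysql://{user}:{password}@{s['host']}:{s['port']}/{s['dbname']}"
```

In app.py this could be used as `settings = mysql_settings()`, then `MysqlStorageClient(**settings)` (the keys match its constructor parameters) and `MysqlDataLayer(conninfo=sqlalchemy_url(settings), storage_provider=storage_client)`.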
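4. Run the following command to create the CHAINLIT_AUTH_SECRET used for authentication (chat history is stored per logged-in user, so authentication is required):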
chainlit create-secret
Copy the last line of the output into the .env file, which then looks like this:
CHAINLIT_AUTH_SECRET="$b?/v0NeJlAU~I5As1WSCa,j8wJ3w%agTyIFlUt4408?mfC*,/wovlfA%3O/751U"
OPENAI_BASE_URL="https://dashscope.aliyuncs.com/compatible-mode/v1"
OPENAI_API_KEY=""
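Before starting the app, a quick check of the .env file can catch a missing or empty key. The sketch below is a deliberately simple `KEY="value"` parser, not the loader Chainlit itself uses: it skips blank lines and `#` comments but does not handle multi-line values or escapes. It reports which of the three variables above are absent or empty.

```python
from typing import Dict, List

REQUIRED_KEYS = ["CHAINLIT_AUTH_SECRET", "OPENAI_BASE_URL", "OPENAI_API_KEY"]


def parse_env(text: str) -> Dict[str, str]:
    """Parse simple KEY="value" / KEY=value lines, skipping blanks and comments."""
    result: Dict[str, str] = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        value = value.strip()
        # Strip one pair of matching surrounding quotes
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        result[key.strip()] = value
    return result


def missing_keys(env: Dict[str, str]) -> List[str]:
    """Return required keys that are absent or empty."""
    return [k for k in REQUIRED_KEYS if not env.get(k)]


if __name__ == "__main__":
    with open(".env", encoding="utf-8") as f:
        print(missing_keys(parse_env(f.read())))
```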
5. Start the service (the -w flag reloads the app automatically when source files change)
chainlit run app.py -w
6. Result after startup
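- Log in with username admin and password admin (as defined in auth_callback). Chat history is now stored in the MySQL database, so it is no longer lost when the page is refreshed, and it remains available after the Chainlit service restarts.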