Vanna使用ollama分析本地MySQL数据库 加入redis保存训练记录
相关代码
from vanna.base.base import VannaBase
from vanna.chromadb import ChromaDB_VectorStore
from vanna.ollama import Ollama
import logging
import os
import requests
import json
import pandas as pd
import chromadb
import redis
import pickle
from IPython.display import display
logging.basicConfig(level=logging.INFO)
class MyVanna(ChromaDB_VectorStore, Ollama):
def __init__(self, config=None):
# 初始化配置
self.config = {
'model': 'llama2:latest',
'ollama_host': 'http://127.0.0.1:11434',
'verbose': True,
'temperature': 0.1,
'collection_name': 'my_vanna_collection',
'redis_host': '127.0.0.1',
'redis_port': 6379,
'redis_db': 5,
'redis_password': '123456',
'redis_key_prefix': 'vanna_training:'
}
if config:
self.config.update(config)
# 初始化 ChromaDB
self.chroma_client = chromadb.PersistentClient(path=self.config['chroma_db_path'])
try:
self._collection = self.chroma_client.get_collection(self.config['collection_name'])
logging.info(f"获取已存在的集合: {self.config['collection_name']}")
except:
self._collection = self.chroma_client.create_collection(self.config['collection_name'])
logging.info(f"创建新的集合: {self.config['collection_name']}")
# 初始化 Redis 连接
try:
self.redis_client = redis.Redis(
host=self.config['redis_host'],
port=self.config['redis_port'],
db=self.config['redis_db'],
password=self.config['redis_password'],
decode_responses=False,
socket_timeout=5,
retry_on_timeout=True
)
# 测试连接
self.redis_client.ping()
logging.info("Redis 连接成功")
except Exception as e:
logging.error(f"Redis 连接错误: {str(e)}")
raise
# 初始化父类
ChromaDB_VectorStore.__init__(self, config=self.config)
Ollama.__init__(self, config=self.config)
self._ddl = None
def submit_prompt(self, prompt, **kwargs):
"""重写 submit_prompt 方法"""
try:
url = f"{self.config['ollama_host']}/api/generate"
# 如果传入的是消息列表,则组合成单个提示词
if isinstance(prompt, list):
full_prompt = "\n".join([msg.get('content', '') for msg in prompt if isinstance(msg, dict)])
else:
full_prompt = prompt
data = {
"model": self.config['model'],
"prompt": full_prompt,
"stream": False
}
headers = {
"Content-Type": "application/json"
}
logging.info(f"发送请求到 Ollama: {url}")
logging.debug(f"请求数据: {json.dumps(data, ensure_ascii=False)}")
response = requests.post(url, json=data, headers=headers)
response.raise_for_status()
response_data = response.json()
logging.debug(f"Ollama 响应: {json.dumps(response_data, ensure_ascii=False)}")
if 'response' in response_data:
return response_data['response'].strip()
else:
logging.error(f"Ollama 响应格式错误: {response_data}")
raise ValueError("无效的 Ollama 响应格式")
except Exception as e:
logging.error(f"提交 prompt 错误: {str(e)}")
raise
def train(self, ddl=None, question=None, sql=None, documentation=None):
"""重写 train 方法,使用 Redis"""
try:
if ddl:
self._ddl = ddl
# 保存 DDL 到 Redis
self.redis_client.set(f"{self.config['redis_key_prefix']}ddl", ddl)
logging.info("DDL 已保存到 Redis")
if question and sql:
# 准备训练数据
data = {
'question': question,
'sql': sql,
'documentation': documentation or ''
}
# 生成唯一 ID
import hashlib
doc_id = hashlib.md5(json.dumps(data, ensure_ascii=False).encode()).hexdigest()
# 保存到 Redis
key = f"{self.config['redis_key_prefix']}data:{doc_id}"
self.redis_client.set(key, pickle.dumps(data))
# 将 ID 添加到训练数据集合中
self.redis_client.sadd(f"{self.config['redis_key_prefix']}data_ids", doc_id)
logging.info(f"训练数据已保存到 Redis: {data}")
return True
except Exception as e:
logging.error(f"训练错误: {str(e)}")
raise
def get_sql_prompt(self, question, ddl=None, similar_questions=None, similar_sql=None,
initial_prompt=None, question_sql_list=None, ddl_list=None, doc_list=None,
**kwargs):
"""重写 get_sql_prompt 方法"""
# 使用存储的 DDL
if not ddl and self._ddl:
ddl = self._ddl
# 构建提示词
prompt = "你是一个 SQL 专家。请根据以下信息生成 SQL 查询。\n\n"
prompt += "### 数据库结构:\n"
if ddl:
prompt += f"{ddl}\n\n"
# 添加文档说明
if doc_list:
prompt += "### 相关文档:\n"
for doc in doc_list:
prompt += f"{doc}\n"
prompt += "\n"
prompt += "### 问题:\n"
prompt += f"{question}\n\n"
if similar_questions and similar_sql:
prompt += "### 相似问题和对应的 SQL:\n"
for q, s in zip(similar_questions, similar_sql):
prompt += f"\n问题: {q}\nSQL: {s}\n"
prompt += "\n### 请生成对应的 SQL 查询 汉字转为简体:\n"
return prompt
def generate_sql(self, question, **kwargs):
try:
if self._ddl:
kwargs['ddl'] = self._ddl
return super().generate_sql(question, **kwargs)
except Exception as e:
logging.error(f"SQL 生成错误: {str(e)}")
raise
def get_related_ddl(self, question=None, **kwargs):
"""重写 get_related_ddl 方法,从 Redis 获取 DDL"""
try:
if self._ddl:
return self._ddl
# 从 Redis 获取 DDL
ddl = self.redis_client.get(f"{self.config['redis_key_prefix']}ddl")
if ddl:
self._ddl = ddl.decode()
return self._ddl
return None
except Exception as e:
logging.error(f"获取 DDL 错误: {str(e)}")
return None
def generate_plotly_code(self, question, sql_result=None, **kwargs):
"""重写 generate_plotly_code 方法"""
try:
# 构建提示词
prompt = self.get_plotly_prompt(question, sql_result=sql_result, **kwargs)
# 添加系统提示词
system_prompt = "你是一个数据可视化专家。请根据用户的需求生成 Plotly 图表代码。只返回 Python 代码,不需要其他解释。如果繁体转为简体。"
full_prompt = f"{system_prompt}\n\n{prompt}"
# 直接调用 submit_prompt
return self.submit_prompt(full_prompt, is_plotly=True)
except Exception as e:
logging.error(f"生成图表代码错误: {str(e)}")
raise
def get_plotly_prompt(self, question, sql=None, sql_result=None, **kwargs):
"""重写 get_plotly_prompt 方法"""
prompt = f"""请根据以下信息生成 Plotly 图表代码:
问题:{question}
SQL查询:{sql if sql else ''}
查询结果:{sql_result if sql_result else ''}
要求:
1. 使用 Plotly Express 生成图表
2. 只返回 Python 代码
3. 不要包含任何解释或说明
4. 确保代码的正确性
5. 如果繁体转为简体
"""
return prompt
def get_training_data(self):
"""重写 get_training_data 方法,使用 Redis"""
try:
# 获取所有训练数据 ID
data_ids = self.redis_client.smembers(f"{self.config['redis_key_prefix']}data_ids")
if not data_ids:
logging.info("Redis 中没有找到训练数据")
return pd.DataFrame(columns=['question', 'sql', 'documentation'])
# 获取所有训练数据
documents = []
for doc_id in data_ids:
try:
key = f"{self.config['redis_key_prefix']}data:{doc_id.decode()}"
data = self.redis_client.get(key)
if data:
doc_data = pickle.loads(data)
documents.append(doc_data)
logging.info(f"从 Redis 加载数据: {doc_data}")
except Exception as e:
logging.error(f"处理 Redis 数据时出错: {e}")
continue
# 创建 DataFrame
if documents:
df = pd.DataFrame(documents)
logging.info(f"已加载 {len(df)} 条训练数据")
return df
else:
logging.info("没有找到有效的训练数据")
return pd.DataFrame(columns=['question', 'sql', 'documentation'])
except Exception as e:
logging.error(f"获取训练数据错误: {str(e)}")
return pd.DataFrame(columns=['question', 'sql', 'documentation'])
def train_model(vn):
try:
# 训练 DDL
print("开始训练 DDL...")
ddl = """
CREATE TABLE `customer` (
`name` int NOT NULL COMMENT '姓名',
`gender` int DEFAULT NULL COMMENT '性别(男性=1/女性=2)',
`id_card` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '身份证',
`mobile` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '手机',
`nation` varchar(10) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '民族',
`residential_city` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '居住城市',
`age` int DEFAULT NULL COMMENT '岁数 年纪',
`salary` int NOT NULL COMMENT '薪水',
PRIMARY KEY (`name`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='customer'
"""
vn.train(ddl=ddl)
print("DDL 训练完成")
# 训练示例查询
examples = [
{
'question': "宁波有多少客户?",
'sql': "SELECT COUNT(*) as count FROM customer WHERE residential_city like '%宁波%'"
},
{
'question': "有多少女性客户?",
'sql': "SELECT COUNT(*) as count FROM customer WHERE gender = 2"
},
{
'question': "客户平均年龄是多少?",
'sql': "SELECT AVG(age) as average_age FROM customer"
},
{
'question': "客户平均薪水是多少?",
'sql': "SELECT AVG(salary) as average_salary FROM customer"
}
]
for example in examples:
print(f"\n训练示例: {example['question']}")
vn.train(question=example['question'], sql=example['sql'])
print("\n所有训练完成")
result = vn.ask("宁波有多少客户?")
print(f"\n查询问题: 宁波有多少客户?\n查询结果: {result}")
except Exception as e:
logging.error(f"训练错误: {str(e)}")
raise
if __name__ == "__main__":
try:
# 初始化 Vanna
vn = MyVanna()
# 连接数据库
vn.connect_to_mysql(
host='localhost',
dbname='test',
user='root',
password='123456',
port=3306
)
# 训练模型
train_model(vn)
# 启动 Flask 应用
from vanna.flask import VannaFlaskApp
app = VannaFlaskApp(vn)
app.run(host='0.0.0.0', port=7123)
except Exception as e:
logging.error(f"程序运行错误: {str(e)}")
CREATE TABLE `customer` (
`name` int NOT NULL COMMENT '姓名',
`gender` int DEFAULT NULL COMMENT '性别(男性=1/女性=2)',
`id_card` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '身份证',
`mobile` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '手机',
`nation` varchar(10) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '民族',
`residential_city` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT '居住城市',
`age` int DEFAULT NULL COMMENT '岁数 年纪',
`salary` int NOT NULL COMMENT '薪水',
`id` int NOT NULL AUTO_INCREMENT COMMENT 'id',
PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=21 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='customer';
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('1','1','330201199001011234','13800001111','汉族','宁波','27','5520','1');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('2','2','330201199102022345','13800002222','汉族','宁波','70','7042','2');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('3','1','330201199203033456','13800003333','回族','宁波','94','4119','3');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('4','2','330201199304044567','13800004444','汉族','宁波','60','4886','4');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('5','1','330201199405055678','13800005555','壮族','宁波','5','5762','5');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('6','1','110101199506066789','13800006666','汉族','北京','58','5515','6');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('7','2','310101199607077890','13800007777','汉族','上海','69','2927','7');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('8','1','440101199708088901','13800008888','满族','广州','90','5979','8');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('9','2','500101199809099012','13800009999','汉族','重庆','91','7256','9');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('10','1','610101199910101123','13800010000','回族','西安','28','4067','10');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('11','2','320101199001111234','13800011111','汉族','南京','13','1979','11');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('12','1','330101199002121345','13800012222','畲族','杭州','8','994','12');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('13','2','420101199003131456','13800013333','汉族','武汉','29','1073','13');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('14','1','510101199004141567','13800014444','彝族','成都','84','1441','14');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('15','2','350101199005151678','13800015555','汉族','福州','33','7725','15');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('16','1','370101199006161789','13800016666','汉族','济南','89','3821','16');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('17','2','430101199007171890','13800017777','苗族','长沙','86','3082','17');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('18','1','220101199008181901','13800018888','汉族','长春','48','4170','18');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('19','2','450101199009192012','13800019999','壮族','南宁','30','1498','19');
INSERT INTO `customer` (`name`,`gender`,`id_card`,`mobile`,`nation`,`residential_city`,`age`,`salary`,`id`) VALUES ('20','1','130101199010202123','13800020000','汉族','石家庄','54','941','20');
1. 系统概述
这是一个基于 Ollama 和 Redis 的智能 SQL 问答系统,可以将自然语言问题转换为 SQL 查询语句。系统具有以下主要特点:
- 基于 LLM (Large Language Model) 的自然语言转 SQL
- 支持训练数据的持久化存储
- 提供 REST API 接口
- 支持数据可视化生成
2. 核心组件
2.1 MyVanna 类
主要继承关系:
class MyVanna(ChromaDB_VectorStore, Ollama)
核心配置参数:
self.config = {
'model': 'llama2:latest', # LLM 模型
'ollama_host': 'http://127.0.0.1:11434', # Ollama 服务地址
'temperature': 0.1, # 生成温度
'redis_host': '127.0.0.1', # Redis 配置
'redis_port': 6379,
'redis_db': 5,
'redis_password': '123456',
'redis_key_prefix': 'vanna_training:'
}
3. 主要功能模块
3.1 提示词生成 (Prompt Engineering)
def get_sql_prompt(self, question, ddl=None, similar_questions=None, similar_sql=None, ...):
提示词结构:
- 角色定义
- 数据库结构说明
- 相关文档
- 用户问题
- 相似问题参考
- 输出要求
3.2 训练功能
def train(self, ddl=None, question=None, sql=None, documentation=None):
训练数据包含:
- DDL(数据库结构)
- 问题-SQL 对
- 相关文档
存储方式:
- 使用 Redis 持久化
- 使用 hash 作为唯一标识
- 支持批量训练
3.3 SQL 生成
def generate_sql(self, question, **kwargs):
工作流程:
- 获取相关 DDL
- 构建提示词
- 调用 LLM 生成 SQL
- 错误处理和日志记录
3.4 数据可视化
def generate_plotly_code(self, question, sql_result=None, **kwargs):
特点:
- 使用 Plotly 生成可视化代码
- 支持 SQL 结果的直接可视化
- 自动处理中文编码
4. 示例训练数据
examples = [
{
'question': "宁波有多少客户?",
'sql': "SELECT COUNT(*) as count FROM customer WHERE residential_city like '%宁波%'"
},
{
'question': "有多少女性客户?",
'sql': "SELECT COUNT(*) as count FROM customer WHERE gender = 2"
}
# ...
]
5. 部署和使用
5.1 环境要求
- Python 3.x
- Redis 服务
- MySQL 数据库
- Ollama 服务
5.2 启动服务
if __name__ == "__main__":
vn = MyVanna()
vn.connect_to_mysql(...)
train_model(vn)
app = VannaFlaskApp(vn)
app.run(host='0.0.0.0', port=7123)
6. 改进建议
-
错误处理优化
- 添加更详细的错误类型
- 实现错误重试机制
-
性能优化
- 添加缓存机制
- 实现批量处理
-
安全性增强
- 添加 SQL 注入防护
- 实现访问控制
-
功能扩展
- 支持更多数据库类型
- 添加更多可视化选项
- 实现对话历史记录
7. 总结
该系统通过结合 LLM 和传统数据库技术,实现了一个灵活的自然语言到 SQL 的转换系统。其模块化设计和可扩展性使其适合在实际业务场景中使用和扩展。
主要优势:
- 模块化设计
- 可扩展架构
- 完整的训练流程
- 持久化存储支持
潜在改进空间:
- 性能优化
- 安全性增强
- 功能扩展
- 错误处理完善