Text-to-SQL将自然语言转换为数据库查询语句
有关Text-To-SQL方法,可以查阅我的另一篇文章,Text-to-SQL方法研究
直接与数据库对话-text2sql
Text2sql就是把文本转换为sql语言,这段时间公司有这方面的需求,调研了一下市面上text2sql的方法,比如阿里的Chat2DB,以MIT协议开源的Vanna。试验了一下,最终还是决定自研,基于Vanna的思想,RAG+大模型。
使用开源的Vanna实现text2sql比较方便,Vanna可以直接连接数据库,但是当用户权限能访问多个数据库的时候,就比较麻烦了,而且Vanna向量化存储之后,新的question作对比时没有区分数据库。因此自己实现了一下text2sql,仍然采用Vanna的思想,提前训练DDL,Sqlques,和数据库document。
这里简单做一下记录,以供后续学习使用。
基本思路
1、数据库DDL语句,SQL-Question,Document信息获取
2、基于用户提问question和数据库Document锁定要分析的数据库
3、模型训练:借助数据库的DDL语句、元数据(描述数据库自身数据的信息)、相关文档说明、参考样例SQL等,训练一个RAG“模型”。
这一模型结合了embedding技术和向量数据库,使得数据库的结构和内容能够被高效地索引和检索。
4、语义检索: 当用户输入自然语言描述的问题时,①会从向量库里面检索,迅速找出与问题相关的内容;②进行BM25算法文本召回,找到与问题最相关的内容;③分别使用RRF算法和Re-ranking重排序算法,锁定最相关内容
语义匹配:使用算法(如BERT等)来理解查询和文档的语义相似性
文本召回匹配:BM25算法文本召回,找到与问题最相关的内容
rerank结果重排序:对搜索结果进行排序。
5、Prompt构建: 检索到的相关信息会被组装进Prompt中,形成一个结构化的查询描述。这一Prompt随后会被传递给LLM(大型语言模型)用于生成准确的SQL查询。
实现逻辑图
实现架构图:
具体实现方式如下所示:
1.数据库的选择
class DataBaseSearch(object):
    """Select the database(s) whose description best matches a question.

    On construction the instance downloads one record per database from a
    remote service, embeds the concatenation of that database's example
    SQL/questions and its introduction text, and stores the vectors in a
    flat faiss L2 index. ``find_vec_database`` then embeds a question and
    returns the ids of the closest databases.
    """

    def __init__(self, _model):
        """Build the index.

        :param _model: embedding model exposing
            ``encode(list_of_str, normalize_embeddings=True)``; its output
            width must equal ``self.SIZE``.
        """
        self.name = 'DataBaseSearch'
        self.model = _model
        # Prefix prepended to every query before it is embedded.
        self.instruction = "为这段内容生成表示以用于匹配文本描述:"
        self.SIZE = 1024
        self.index = faiss.IndexFlatL2(self.SIZE)
        self.textdata = []     # (dataSetID, ddl, exp, Intro) tuples
        self.subdata = {}
        self.i2key = {}        # row number in the faiss index -> dataSetID
        self.id2ddls = {}      # dataSetID -> raw "ddl" payload
        self.id2sqlques = {}   # dataSetID -> raw "exp" payload
        self.id2docs = {}      # dataSetID -> raw "Intro" payload
        self.strtexts = {}     # dataSetID -> text that was embedded
        self.load_textdata()       # fetch the raw records
        self.load_textdata_vec()   # embed them and fill the index

    def load_textdata(self):
        """Fetch all database records and append them to ``self.textdata``.

        Any failure (network, JSON, missing key) is printed and otherwise
        ignored, so the instance may end up with fewer records than the
        service holds, or none at all.
        """
        # NOTE(review): the request has no timeout and disables TLS
        # verification; confirm both are intended for this endpoint.
        try:
            response = requests.post(
                url="xxx",
                verify=False)
            print(response.text)
            jsonobj = json.loads(response.text)
            for textdata in jsonobj["data"]:
                cid = textdata["dataSetID"]
                cddls = textdata["ddl"]
                csql_ques = textdata["exp"]
                cdocuments = textdata["Intro"]
                self.textdata.append((cid, cddls, csql_ques, cdocuments))
        except Exception as e:
            print(e)

    def load_textdata_vec(self):
        """Embed every loaded record and register it in the faiss index."""
        for row, (_id, _ddls, _sql_ques, _documents) in enumerate(self.textdata):
            # The DDL text is deliberately left out of the embedded string.
            _strtexts = str(_sql_ques) + str(_documents)
            text_embeddings = self.model.encode([_strtexts], normalize_embeddings=True)
            self.index.add(text_embeddings)
            self.i2key[row] = _id
            self.strtexts[_id] = _strtexts
            self.id2ddls[_id] = _ddls
            self.id2sqlques[_id] = _sql_ques
            self.id2docs[_id] = _documents

    def calculate_score(self, score, question, kws):
        """Placeholder for a custom scoring function (not implemented)."""
        pass

    def find_vec_database(self, question, k, theata):
        """Return the databases closest to ``question``.

        :param question: natural-language question.
        :param k: number of nearest neighbours requested from the index.
        :param theata: exclusive upper bound on the integer score
            (distance * 1000); hits scoring at or above it are dropped.
        :return: list of ``{"score": int, "dataSetID": id}`` in the order
            the index returned them.
        """
        q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)
        D, I = self.index.search(q_embeddings, k)
        result = []
        for sim_v, sim_i in zip(D[0], I[0]):
            # When the index holds fewer than k vectors the search pads the
            # result with label -1 and a huge distance; converting that
            # distance to int can overflow, so skip padded slots entirely.
            if sim_i < 0:
                continue
            score = int(sim_v * 1000)
            if score < theata:
                result.append({
                    "score": score,
                    "dataSetID": self.i2key.get(int(sim_i), "none"),
                })
        return result
if __name__ == '__main__':
    # Demo: load the embedding model and look up the best-matching database.
    # The path must be a raw string: in a normal literal "\b" is a backspace
    # character, so the original path pointed at a non-existent directory.
    modelpath = r"E:\module\bge-large-zh-v1.5"
    model = SentenceTransformer(modelpath)
    vs = DataBaseSearch(model)
    result = vs.find_vec_database("查询济南市第三幼儿园所有小班班级?", 1, 2000)
    print(result)
2.sql_ques:sql问题训练
class SqlQuesSearch(object):
    """Retrieve stored question/SQL examples similar to a new question.

    Every example question of every database is embedded into one shared
    flat faiss L2 index; lookups are filtered afterwards by database id.
    """

    def __init__(self, _model):
        """Build the index.

        :param _model: embedding model exposing
            ``encode(list_of_str, normalize_embeddings=True)``; its output
            width must equal ``self.SIZE``.
        """
        self.name = "SqlQuesSearch"
        self.model = _model
        # Prefix prepended to every query before it is embedded.
        self.instruction = "为这段内容生成表示以用于匹配文本描述:"
        self.SIZE = 1024
        self.index = faiss.IndexFlatL2(self.SIZE)
        self.sqlquedata = []    # (dataSetID, list of example dicts)
        self.i2dbid = {}        # index row -> dataSetID
        self.i2sqlid = {}       # index row -> sql_id
        self.id2sqlque = {}
        self.id2que = {}        # sql_id -> question text
        self.id2sql = {}        # sql_id -> SQL text
        self.dbid2sqlques = {}
        self.load_textdata()       # fetch the raw records
        self.load_textdata_vec()   # embed them and fill the index

    def load_textdata(self):
        """Fetch the example question/SQL pairs of every database.

        Any failure is printed and otherwise ignored, which can leave
        ``self.sqlquedata`` partially filled or empty.
        """
        try:
            response = requests.post(
                url="xxx",
                verify=False)
            print(response.text)
            jsonobj = json.loads(response.text)
            for datadata in jsonobj["data"]:
                dbid = datadata["dataSetID"]
                sql_ques = datadata["exp"]
                self.sqlquedata.append((dbid, sql_ques))
        except Exception as e:
            print(e)

    def load_textdata_vec(self):
        """Embed each example question and record its row -> id mappings."""
        num0 = 0
        for db_id, sql_ques in self.sqlquedata:
            for sql_que in sql_ques:
                sql_id = sql_que["sql_id"]
                example_question = sql_que["question"]
                sql = sql_que["sql"]
                que_embeddings = self.model.encode([example_question], normalize_embeddings=True)
                self.index.add(que_embeddings)
                self.i2dbid[num0] = db_id
                self.i2sqlid[num0] = sql_id
                self.id2que[sql_id] = example_question
                self.id2sql[sql_id] = sql
                num0 += 1
        print("init sql-que vec", num0)

    @staticmethod
    def calculate_score(sim_v, question, sql_ques):
        """Placeholder for a custom scoring function (not implemented).

        Declared static because the original signature has no ``self``;
        this keeps both class-level and instance-level calls working.
        """
        pass

    def find_vec_sqlque(self, question, k, theta, dataSetID, number=None):
        """Return stored examples of one database that resemble ``question``.

        :param question: natural-language question.
        :param k: neighbours requested from the index (across all databases).
        :param theta: exclusive upper bound on the integer score
            (distance * 1000).
        :param dataSetID: only hits belonging to this database are kept.
        :param number: stop after this many accepted hits; ``None`` (the
            default) keeps every accepted hit among the ``k`` neighbours.
        :return: list of ``{"score", "question", "sql"}`` dicts.
        """
        q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)
        D, I = self.index.search(q_embeddings, k)
        result = []
        for sim_v, sim_i in zip(D[0], I[0]):
            # Label -1 marks a padded slot (index smaller than k); its
            # distance is a sentinel that must not be converted to int.
            if sim_i < 0:
                continue
            sim_i = int(sim_i)
            if self.i2dbid.get(sim_i, "none") != dataSetID:
                continue
            score = int(sim_v * 1000)
            if score >= theta:
                continue
            sqlid = self.i2sqlid.get(sim_i, "none")
            result.append({
                "score": score,
                "question": self.id2que.get(sqlid, "none"),
                "sql": self.id2sql.get(sqlid, "none"),
            })
            if number is not None and len(result) == number:
                break
        return result
if __name__ == '__main__':
    # Demo: retrieve up to three stored examples for database 111.
    # Raw string so that "\b" in the Windows path is not read as backspace.
    modelpath = r"E:\module\bge-large-zh-v1.5"
    model = SentenceTransformer(modelpath)
    vs = SqlQuesSearch(model)
    # `number` is passed explicitly; the original call omitted it.
    result = vs.find_vec_sqlque("查询7月18日所有的儿童观察记录?", 3, 2000, dataSetID=111, number=3)
    print(result)
3.数据库DDL训练
class DdlQuesSearch(object):
    """Retrieve the table DDL statements most relevant to a question.

    Every DDL statement of every database is embedded into one shared flat
    faiss L2 index; lookups are filtered afterwards by database id.
    """

    def __init__(self, _model):
        """Build the index.

        :param _model: embedding model exposing
            ``encode(list_of_str, normalize_embeddings=True)``; its output
            width must equal ``self.SIZE``.
        """
        self.name = "DdlQuesSearch"
        self.model = _model
        # Prefix prepended to every query before it is embedded.
        self.instruction = "为这段内容生成表示以用于匹配文本描述:"
        self.SIZE = 1024
        self.index = faiss.IndexFlatL2(self.SIZE)
        self.ddldata = []       # (dataSetID, list of ddl dicts)
        self.sqlques = {}
        self.i2dbid = {}        # index row -> dataSetID
        self.i2ddlid = {}       # index row -> ddl_id
        self.dbid2ddls = {}     # dataSetID -> {ddl_id: ddl text} of that db
        self.id2ddl = {}        # ddl_id -> ddl text (all databases)
        self.ddlid2dbid = {}    # ddl_id -> dataSetID
        self.load_ddldata()     # fetch the raw records
        self.load_ddl_vec()     # embed them and fill the index

    def load_ddldata(self):
        """Fetch the DDL list of every database into ``self.ddldata``.

        Any failure is printed and otherwise ignored, which can leave
        ``self.ddldata`` partially filled or empty.
        """
        try:
            response = requests.post(
                url="xxx",
                verify=False)
            print(response.text)
            jsonobj = json.loads(response.text)
            # The original iterated an undefined `databases`; the resulting
            # NameError was swallowed below and nothing was ever loaded.
            databases = jsonobj["data"]
            for database in databases:
                db_id = database["dataSetID"]
                ddls = database["ddl"]
                self.ddldata.append((db_id, ddls))
        except Exception as e:
            print(e)

    def load_ddl_vec(self):
        """Embed each DDL statement and record its lookup mappings."""
        num0 = 0
        for db_id, ddls in self.ddldata:
            # One dict per database, so dbid2ddls[db] lists only its own
            # tables instead of aliasing the global id2ddl mapping.
            db_ddls = self.dbid2ddls.setdefault(db_id, {})
            for entry in ddls:
                ddl_id = entry["ddl_id"]
                ddl_text = entry['ddl']
                ddl_embeddings = self.model.encode([ddl_text], normalize_embeddings=True)
                self.index.add(ddl_embeddings)
                self.i2dbid[num0] = db_id
                self.i2ddlid[num0] = ddl_id
                self.id2ddl[ddl_id] = ddl_text
                self.ddlid2dbid[ddl_id] = db_id
                db_ddls[ddl_id] = ddl_text
                num0 += 1
        print("init ddl vec", num0)

    def find_vec_ddl(self, question, k, theata, dataSetID, number=None):
        """Return DDL statements of one database that resemble ``question``.

        :param question: natural-language question or table description.
        :param k: neighbours requested from the index (across all databases).
        :param theata: exclusive upper bound on the integer score
            (distance * 1000).
        :param dataSetID: only hits belonging to this database are kept.
        :param number: stop after this many accepted hits; ``None`` (the
            default) keeps every accepted hit among the ``k`` neighbours.
        :return: list of ``{"score", "ddl"}`` dicts.
        """
        q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)
        D, I = self.index.search(q_embeddings, k)
        result = []
        for sim_v, sim_i in zip(D[0], I[0]):
            # Label -1 marks a padded slot (index smaller than k); its
            # distance is a sentinel that must not be converted to int.
            if sim_i < 0:
                continue
            sim_i = int(sim_i)
            if self.i2dbid.get(sim_i, "none") != dataSetID:
                continue
            score = int(sim_v * 1000)
            if score >= theata:
                continue
            ddlid = self.i2ddlid.get(sim_i, "none")
            result.append({
                "score": score,
                "ddl": self.id2ddl.get(ddlid, "none"),
            })
            if number is not None and len(result) == number:
                break
        return result
if __name__ == '__main__':
    # Demo: retrieve up to two DDL statements for database 111.
    # Raw string so that "\b" in the Windows path is not read as backspace.
    modelpath = r"E:\module\bge-large-zh-v1.5"
    model = SentenceTransformer(modelpath)
    vs = DdlQuesSearch(model)
    # The fifth argument (`number`) was missing from the original call.
    ss = vs.find_vec_ddl("定时任务执行记录表", 2, 2000, 111, 2)
    print(ss)
4.数据库document训练
class DocQuesSearch(object):
    """Look up the introduction document(s) stored for a database.

    No embedding is involved: documents are fetched once at construction
    time and later selected by exact database id.
    """

    def __init__(self):
        # Use the class's own name, in line with the sibling search classes
        # (the original value "TestDataSearch" was a copy-paste leftover).
        self.name = "DocQuesSearch"
        self.docdata = []   # (dataSetID, Intro) tuples
        self.load_doc_data()

    def load_doc_data(self):
        """Fetch ``(dataSetID, Intro)`` for every database.

        Any failure is printed and otherwise ignored, which can leave
        ``self.docdata`` partially filled or empty.
        """
        try:
            response = requests.post(
                url="xxx",
                verify=False)
            print(response.text)
            jsonobj = json.loads(response.text)
            for database in jsonobj["data"]:
                db_id = database["dataSetID"]
                doc = database["Intro"]
                self.docdata.append((db_id, doc))
        except Exception as e:
            print(e)

    def find_similar_doc(self, dataSetID):
        """Return every stored document whose database id equals ``dataSetID``.

        :param dataSetID: database id to match exactly.
        :return: list of documents in load order; empty when none match.
        """
        return [record[1] for record in self.docdata if record[0] == dataSetID]
if __name__ == '__main__':
    # Demo: print the introduction document(s) stored for database 222.
    searcher = DocQuesSearch()
    print(searcher.find_similar_doc(222))
5.生成sql语句,这里使用的qwen-max模型
import re
import random
import os, json
import dashscope
from dashscope.api_entities.dashscope_response import Message
from ddl_engine import DdlQuesSearch
from dashscope import Generation
from sqlques_engine import SqlQuesSearch
from sentence_transformers import SentenceTransformer
class Genarate(object):
    """Assemble a chat prompt from retrieved context and ask the LLM for SQL.

    The model name and API key are read from the ``model`` and ``api_key``
    environment variables.
    """

    def __init__(self):
        self.api_key = os.environ.get('api_key')
        self.model_name = os.environ.get('model')

    def system_message(self, message):
        """Wrap ``message`` as a system-role chat message."""
        return {'role': 'system', 'content': message}

    def user_message(self, message):
        """Wrap ``message`` as a user-role chat message."""
        return {'role': 'user', 'content': message}

    def assistant_message(self, message):
        """Wrap ``message`` as an assistant-role chat message."""
        return {'role': 'assistant', 'content': message}

    def submit_prompt(self, prompt):
        """Send the message list to the model.

        :param prompt: list of role/content dicts.
        :return: the reply text, or ``None`` when the status code is not 200.
        """
        resp = Generation.call(
            model=self.model_name,
            messages=prompt,
            seed=random.randint(1, 10000),
            result_format='message',
            api_key=self.api_key)
        if resp["status_code"] != 200:
            return None
        answer = resp.output.choices[0].message.content
        # Keep the last successful exchange in a module global for debugging.
        global DEBUG_INFO
        DEBUG_INFO = (prompt, answer)
        return answer

    def generate_sql(self, question, sql_result, ddl_result, doc_result):
        """Build the prompt, query the model and extract the SQL text.

        :param question: the user's natural-language question.
        :param sql_result: similar examples, dicts with "question"/"sql".
        :param ddl_result: retrieved tables, dicts with a "ddl" key.
        :param doc_result: iterable of free-text context documents.
        :return: extracted SQL (or the raw reply if no SQL pattern matched);
            ``None`` when the model call failed.
        """
        prompt = self.get_sql_prompt(
            question=question,
            sql_result=sql_result,
            ddl_result=ddl_result,
            doc_result=doc_result)
        print("SQL Prompt:", prompt)
        llm_response = self.submit_prompt(prompt)
        return self.extrat_sql(llm_response)

    def extrat_sql(self, llm_response):
        """Pull the SQL statement out of a model reply.

        Patterns are tried in order and the last match of the first pattern
        that matches wins: a ``WITH ...;`` statement, a ``SELECT ...;``
        statement, a fenced sql code block, then any fenced block. If none
        match, the reply is returned unchanged.

        :param llm_response: reply text, or ``None`` after a failed call.
        :return: extracted SQL, the unchanged reply, or ``None``/empty input
            passed straight through.
        """
        if not llm_response:
            # submit_prompt returns None on failure; re.findall would raise.
            return llm_response
        patterns = (
            r"WITH.*?;",
            r"SELECT.*?;",
            r"```sql\n(.*)```",
            r"```(.*)```",
        )
        for pattern in patterns:
            sqls = re.findall(pattern, llm_response, re.DOTALL)
            if sqls:
                return sqls[-1]
        return llm_response

    def get_sql_prompt(self, question, sql_result, ddl_result, doc_result):
        """Return the full message list sent to the model.

        The system message carries the instructions, the table DDL and the
        additional documents; the examples follow as user/assistant pairs
        and the real question comes last.
        """
        initial_prompt = (
            "You are a SQL expert. "
            "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
        )
        initial_prompt = self.add_ddl_to_prompt(initial_prompt, ddl_result)
        initial_prompt = self.add_documentation_to_prompt(initial_prompt, doc_result)
        initial_prompt += (
            "===Response Guidelines\n"
            "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question.\n"
            "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql\n"
            "3. If the provided context is insufficient, please explain why it can't be generated.\n"
            "4. Please use the most relevant table(s).\n"
            "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before.\n"
        )
        message_log = [self.system_message(initial_prompt)]
        message_log = self.add_sqlques_to_prompt(question, sql_result, message_log)
        return message_log

    def add_ddl_to_prompt(self, initial_prompt, ddl_result):
        """Append a ``===Tables`` section listing each retrieved DDL.

        :param initial_prompt: prompt text built so far.
        :param ddl_result: list of dicts with a "ddl" key (may be empty/None).
        :return: the extended prompt text.
        """
        ddl_list = [ddl_['ddl'] for ddl_ in (ddl_result or [])]
        if ddl_list:
            initial_prompt += "\n===Tables\n"
            for ddl in ddl_list:
                initial_prompt += f"{ddl}\n"
        return initial_prompt

    def add_sqlques_to_prompt(self, question, sql_result, message_log):
        """Append the examples as chat turns, then the real question.

        :param question: the user's question, appended last.
        :param sql_result: list of dicts; entries lacking "question" or
            "sql" are skipped.
        :param message_log: message list to extend in place.
        :return: the same ``message_log`` object.
        """
        for example in (sql_result or []):
            if example is not None and "question" in example and "sql" in example:
                message_log.append(self.user_message(example["question"]))
                message_log.append(self.assistant_message(example["sql"]))
        message_log.append(self.user_message(question))
        return message_log

    def add_documentation_to_prompt(self, initial_prompt, doc_result):
        """Append an ``===Additional Context`` section, one line per document.

        :param initial_prompt: prompt text built so far.
        :param doc_result: list of document strings (may be empty/None).
        :return: the extended prompt text.
        """
        if doc_result:
            initial_prompt += "\n===Additional Context\n"
            for doc in doc_result:
                initial_prompt += f"{doc}\n"
        return initial_prompt
if __name__ == '__main__':
    # Demo: fetch the single closest DDL statement for database 111.
    # Raw string so that "\b" in the Windows path is not read as backspace.
    modelpath = r"E:\module\bge-large-zh-v1.5"
    model = SentenceTransformer(modelpath)
    vs = DdlQuesSearch(model)
    # The fifth argument (`number`) was missing from the original call.
    ss = vs.find_vec_ddl("定时任务执行记录表", 1, 2000, 111, 1)
    print(ss)
6.执行结果显示
如图可以看到正确生成了sql,可以正常执行;因为表是刚拉取过来的,里面没有数据,所以查询结果为空。
需要源码的同学,可以留言。