当前位置: 首页 > article >正文

Text-to-SQL将自然语言转换为数据库查询语句

有关Text-To-SQL方法,可以查阅我的另一篇文章,Text-to-SQL方法研究

直接与数据库对话-text2sql

Text2sql就是把文本转换为sql语言,这段时间公司有这方面的需求,调研了一下市面上text2sql的方法,比如阿里的Chat2DB,麻省理工开源的Vanna。试验了一下,最终还是决定自研,基于Vanna的思想,RAG+大模型。

使用开源的Vanna实现text2sql比较方便,Vanna可以直接连接数据库,但是当用户权限能访问多个数据库的时候,就比较麻烦了,而且Vanna向量化存储之后,新的question作对比时没有区分数据库。因此自己实现了一下text2sq,仍然采用Vanna的思想,提前训练DDL,Sqlques,和数据库document。

这里简单做一下记录,以供后续学习使用。

基本思路

1、数据库DDL语句,SQL-Question,Dcoument信息获取

2、基于用户提问question和数据库Document锁定要分析的数据库

3、模型训练:借助数据库的DDL语句、元数据(描述数据库自身数据的信息)、相关文档说明、参考样例SQL等,训练一个RAG“模型”。

这一模型结合了embedding技术和向量数据库,使得数据库的结构和内容能够被高效地索引和检索。

4、语义检索: 当用户输入自然语言描述的问题时,①会从向量库里面检索,迅速找出与问题相关的内容;②进行BM25算法文本召回,找到与问题 最相关的内容;③分别使用RRF算法和Re-ranking重排序算法,锁定最相关内容

语义匹配:使用算法(如BERT等)来理解查询和文档的语义相似性

文本召回匹配:BM25算法文本召回,找到与问题最相关的内容

rerank结果重排序:对搜索结果进行排序。

5、Prompt构建: 检索到的相关信息会被组装进Prompt中,形成一个结构化的查询描述。这一Prompt随后会被传递给LLM(大型语言模型)用于生成准确的SQL查询。

实现逻辑图

实现架构图:

具体实现方式如下所示:

1.数据库的选择
class DataBaseSearch(object):
    def __init__(self, _model):
        self.name = 'DataBaseSearch'
        self.model = _model
        self.instruction = "为这段内容生成表示以用于匹配文本描述:"
        self.SIZE = 1024
        self.index = faiss.IndexFlatL2(self.SIZE)

        self.textdata = []
        self.subdata = {}
        self.i2key = {}
        self.id2ddls = {}
        self.id2sqlques = {}
        self.id2docs = {}
        self.strtexts = {}

        # self.ddldata = []
        # self.sqlques_data = []
        # self.document_data = []
        self.load_textdata()         # 加载text数据
        self.load_textdata_vec()     # text数据向量化

    def load_textdata(self):
        try:
            response = requests.post(
                url="xxx",
                verify=False)
            print(response.text)
            jsonobj = json.loads(response.text)

            textdatas = jsonobj["data"]

            for textdata in textdatas:                                 # 提取每一个数据库内容
                cid = textdata["dataSetID"]
                cddls = textdata["ddl"]
                csql_ques = textdata["exp"]
                cdocuments = textdata["Intro"]

                self.textdata.append((cid, cddls, csql_ques, cdocuments))   # 整合所有数据
        except Exception as e:
            print(e)
        # print("load textdata ", self.textdata)


    def load_textdata_vec(self):
        num0 = 0
        for recode in self.textdata:
            _id = recode[0]
            _ddls = recode[1]
            _sql_ques = recode[2]
            _documents = recode[3]

            # _strtexts = str(_ddls) + str(_sql_ques) + str(_documents)
            _strtexts = str(_sql_ques) + str(_documents)
            text_embeddings = self.model.encode([_strtexts], normalize_embeddings=True)
            self.index.add(text_embeddings)

            self.i2key[num0] = _id
            self.strtexts[_id] = _strtexts
            self.id2ddls[_id] = _ddls
            self.id2sqlques[_id] = _sql_ques
            self.id2docs[_id] = _documents


            num0 += 1
        # print("init instruction vec", num0)

    def calculate_score(self, score, question, kws):
        pass

    def find_vec_database(self, question, k, theata):
        # print(question)
        q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)

        D, I = self.index.search(q_embeddings, k)

        result = []
        for i in range(k):
            sim_i = I[0][i]
            uuid = self.i2key.get(sim_i, "none")
            sim_v = D[0][i]
            database_texts = self.strtexts.get(uuid, "none")
            # score = self.calculate_score(sim_v, question, database_texts) # wait implement
            score = int(sim_v*1000)
            if score < theata:
                doc = {}
                doc["score"] = score
                doc["dataSetID"] = uuid
                result.append(doc)
        # print(result)
        return result

if __name__ == '__main__':

    modelpath = "E:\module\bge-large-zh-v1.5"
    model = SentenceTransformer(modelpath)

    vs = DataBaseSearch(model)

    result = vs.find_vec_database("查询济南市第三幼儿园所有小班班级?", 1, 2000)
    print(result)
2.sql_ques:sql问题训练
class SqlQuesSearch(object):
    def __init__(self, _model):
        self.name = "SqlQuesSearch"
        self.model = _model
        self.instruction = "为这段内容生成表示以用于匹配文本描述:"
        self.SIZE = 1024
        self.index = faiss.IndexFlatL2(self.SIZE)

        self.sqlquedata = []

        self.i2dbid = {}
        self.i2sqlid = {}
        self.id2sqlque = {}
        self.id2que = {}
        self.id2sql = {}
        self.dbid2sqlques = {}
        #
        # self.sqlques = {}
        #
        # self.i2key = {}
        #
        # self.id2sqlques = {}
        #
        # self.num2sqlque = {}



        # self.ddldata = []
        # self.sqlques_data = []
        # self.document_data = []

        self.load_textdata()  # 加载text数据

        self.load_textdata_vec()  # text数据向量化

    def load_textdata(self):
        try:
            response = requests.post(
                url="xxx",
                verify=False)
            print(response.text)
            jsonobj = json.loads(response.text)
           
            textdatas = jsonobj["data"]

datadatas = jsonobj["data"]

            for datadata in datadatas:  # 提取每一个数据库sql-ques内容
                dbid = datadata["dataSetID"]
                sql_ques = datadata["exp"]

                self.sqlquedata.append((dbid, sql_ques))  # 整合sql数据


        except Exception as e:
            print(e)
        # print("load textdata ", self.sqlquedata)

    def load_textdata_vec(self):

        num0 = 0
        for recode in self.sqlquedata:
            db_id = recode[0]
            sql_ques = recode[1]

            for sql_que in sql_ques:
                sql_id = sql_que["sql_id"]
                question = sql_que["question"]
                sql = sql_que["sql"]

                ddl_embeddings = self.model.encode([question], normalize_embeddings=True)
                self.index.add(ddl_embeddings)

                self.i2dbid[num0] = db_id
                self.i2sqlid[num0] = sql_id
                self.id2que[sql_id] = question
                self.id2sql[sql_id] = sql

                num0 += 1
        print("init sql-que vec", num0)

    def calculate_score(sim_v, question, sql_ques):
        pass

    def find_vec_sqlque(self, question, k, theta, dataSetID, number):
        q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)
        D, I = self.index.search(q_embeddings, k)

        result = []
        for i in range(k):
            sim_i = I[0][i]
            dbid = self.i2dbid.get(sim_i, "none")  # 获取数据库id
            sqlid = self.i2sqlid.get(sim_i, "none")
            question = self.id2que.get(sqlid, "none")
            sql = self.id2sql.get(sqlid, "none")
            if dbid == dataSetID:
                sim_v = D[0][i]
                score = int(sim_v * 1000)
                if score < theta:
                    doc = {}
                    doc["score"] = score
                    doc["question"] = question
                    doc["sql"] = sql
                    result.append(doc)
                    if len(result) == number:
                        break
        return result


if __name__ == '__main__':

    modelpath = "E:\module\bge-large-zh-v1.5"
    model = SentenceTransformer(modelpath)

    vs = SqlQuesSearch(model)

    result = vs.find_vec_sqlque("查询7月18日所有的儿童观察记录?", 3, 2000, dataSetID=111)

    print(result)
3.数据库DDL训练
class DdlQuesSearch(object):
    def __init__(self, _model):
        self.name = "DdlQuesSearch"
        self.model = _model
        self.instruction = "为这段内容生成表示以用于匹配文本描述:"
        self.SIZE = 1024
        self.index = faiss.IndexFlatL2(self.SIZE)

        self.ddldata = []

        self.sqlques = {}

        self.i2dbid = {}
        self.i2ddlid = {}


        self.dbid2ddls = {}
        self.id2ddl = {}
        self.ddlid2dbid = {}

        # self.ddldata = []
        # self.sqlques_data = []
        # self.document_data = []

        self.load_ddldata()  # 加载text数据

        self.load_ddl_vec()  # text数据向量化

    def load_ddldata(self):
        try:
            response = requests.post(
                url="xxx",
                verify=False)
            print(response.text)
            jsonobj = json.loads(response.text)

            for database in databases:
                db_id = database["dataSetID"]
                ddls = database["ddl"]

                self.ddldata.append((db_id, ddls))
                # print(db_id)

                # for ddl in database["ddl"]:
                #     ddl_id = ddl["ddl_id"]
                #     ddl = ddl['ddl']
                #
                #     self.id2ddl[ddl_id] = ddl
                # self.dbid2ddls[db_id] = self.id2ddl
        except Exception as e:
            print(e)
        # print("load textdata ", self.ddldata)

    def load_ddl_vec(self):

        num0 = 0
        for recode in self.ddldata:
            db_id = recode[0]
            ddls = recode[1]

            for ddl in ddls:
                ddl_id = ddl["ddl_id"]
                ddl_name = ddl["TABLE"]
                ddl = ddl['ddl']
                ddl_embeddings = self.model.encode([ddl], normalize_embeddings=True)
                self.index.add(ddl_embeddings)

                self.i2dbid[num0] = db_id
                self.i2ddlid[num0] = ddl_id
                self.id2ddl[ddl_id] = ddl
                self.ddlid2dbid[ddl_id] = db_id

                num0 += 1


            self.dbid2ddls[db_id] = self.id2ddl
        print("init ddl vec", num0)

    def find_vec_ddl(self, question, k, theata, dataSetID, number):       # dataSetID:数据库id
        # self.id2ddls.get(action_id)
        q_embeddings = self.model.encode([self.instruction + question], normalize_embeddings=True)
        D, I = self.index.search(q_embeddings, k)

        result = []
        for i in range(k):
            sim_i = I[0][i]
            dbid = self.i2dbid.get(sim_i, "none")         # 获取数据库id
            ddlid = self.i2ddlid.get(sim_i, "none")
            if dbid == dataSetID:
                sim_v = D[0][i]
                score = int(sim_v * 1000)

                if score < theata:
                    doc = {}
                    doc["score"] = score
                    doc["ddl"] = self.id2ddl.get(ddlid, "none")
                    result.append(doc)
                    if len(result) == number:
                        break
        return result


if __name__ == '__main__':
    modelpath = "E:\module\bge-large-zh-v1.5"
    model = SentenceTransformer(modelpath)

    vs = DdlQuesSearch(model)

    ss = vs.find_vec_ddl("定时任务执行记录表", 2, 2000, 111)
    print(ss)
4.数据库document训练
class DocQuesSearch(object):
    def __init__(self):
        self.name = "TestDataSearch"
        self.docdata = []

        self.load_doc_data()

    def load_doc_data(self):
        try:
            response = requests.post(
                url="xxx",
                verify=False)
            print(response.text)
            jsonobj = json.loads(response.text)

databases = jsonobj["data"]

            for database in databases:
                db_id = database["dataSetID"]
                doc = database["Intro"]
                self.docdata.append((db_id, doc))
        except Exception as e:
            print(e)
        # print("load ddldata ", self.docdata)

    def find_similar_doc(self, dataSetID):
        result = []
        for recode in self.docdata:
            dbid = recode[0]
            doc = recode[1]
            if dbid == dataSetID:
                result.append(doc)
        return result



if __name__ == '__main__':
    docques_search = DocQuesSearch()
    result = docques_search.find_similar_doc(222)
    print(result)
5.生成sql语句,这里使用的qwen-max模型
import re
import random
import os, json
import dashscope
from dashscope.api_entities.dashscope_response import Message
from ddl_engine import DdlQuesSearch
from dashscope import Generation
from sqlques_engine import SqlQuesSearch
from sentence_transformers import SentenceTransformer

class Genarate(object):
    def __init__(self):
        self.api_key = os.environ.get('api_key')
        self.model_name = os.environ.get('model')

    def system_message(self, message):
        return {'role': 'system', 'content': message}

    def user_message(self, message):
        return {'role': 'user', 'content': message}

    def assistant_message(self, message):
        return {'role': 'assistant', 'content': message}

    def submit_prompt(self, prompt):
        resp = Generation.call(
            model=self.model_name,
            messages=prompt,
            seed=random.randint(1, 10000),
            result_format='message',
            api_key=self.api_key)

        if resp["status_code"] == 200:
            answer = resp.output.choices[0].message.content
            global DEBUG_INFO
            DEBUG_INFO = (prompt, answer)
            return answer
        else:
            answer = None
            return answer


    def generate_sql(self, question, sql_result, ddl_result, doc_result):
        prompt = self.get_sql_prompt(
            question = question,
            sql_result = sql_result,
            ddl_result = ddl_result,
            doc_result = doc_result)

        print("SQL Prompt:",prompt)
        llm_response = self.submit_prompt(prompt)

        sql = self.extrat_sql(llm_response)

        return sql

    def extrat_sql(self, llm_response):
        sqls = re.findall(r"WITH.*?;", llm_response, re.DOTALL)
        if sqls:
            sql = sqls[-1]
            return sql

        sqls = re.findall(r"SELECT.*?;", llm_response, re.DOTALL)
        if sqls:
            sql = sqls[-1]
            return sql

        sqls = re.findall(r"```sql
(.*)```", llm_response, re.DOTALL)
        if sqls:
            sql = sqls[-1]
            return sql

        sqls = re.findall(r"```(.*)```", llm_response, re.DOTALL)
        if sqls:
            sql = sqls[-1]
            return sql

        return llm_response

    def get_sql_prompt(self, question, sql_result, ddl_result, doc_result):

        initial_prompt = "You are a SQL expert. " + 
                          "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "

        initial_prompt = self.add_ddl_to_prompt( initial_prompt, ddl_result)
        initial_prompt = self.add_documentation_to_prompt(initial_prompt, doc_result)
        initial_prompt += (
            "===Response Guidelines 
"
            "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. 
"
            "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql 
"
            "3. If the provided context is insufficient, please explain why it can't be generated. 
"
            "4. Please use the most relevant table(s). 
"
            "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. 
"
        )
        message_log = [self.system_message(initial_prompt)]

        message_log = self.add_sqlques_to_prompt(question, sql_result, message_log)

        return message_log

    def add_ddl_to_prompt(self, initial_prompt, ddl_result):
        """
        :param initial_prompt:
        :param ddl_result:
        :return:
        """

        ddl_list = [ ddl_['ddl'] for ddl_ in ddl_result]
        if len(ddl_list) > 0:
            initial_prompt += "
===Tables 
"
            for ddl in ddl_list:
                initial_prompt += f"{ddl}

"

        return initial_prompt

    def add_sqlques_to_prompt(self, question, sql_result, message_log):
        """
        :param sql_result:
        :return:
        """
        if len(sql_result) > 0:
            for example in sql_result:
                if example is not None and "question" in example and "sql" in example:
                    message_log.append(self.user_message(example["question"]))
                    message_log.append(self.assistant_message(example["sql"]))

            message_log.append(self.user_message(question))

        return message_log

    def add_documentation_to_prompt(self, initial_prompt, doc_result):
        if len(doc_result) > 0:
            initial_prompt += "
===Additional Context 

"

            for doc in doc_result:
                initial_prompt += f"{doc}

"
        return initial_prompt




if __name__ == '__main__':
    modelpath = "E:\module\bge-large-zh-v1.5"
    model = SentenceTransformer(modelpath)

    vs = DdlQuesSearch(model)

    ss = vs.find_vec_ddl("定时任务执行记录表", 1, 2000, 111)
    print(ss)
6.执行结果显示

如图可以看到正确生成了sql,可以正常执行,因为表是拉取到,没有数据,所以查询结果为空。

需要源码的同学,可以留言。


http://www.kler.cn/a/567610.html

相关文章:

  • pyside6学习专栏(八):在PySide6中使用matplotlib库绘制三维图形
  • Swan 表达式 - 选择表达式
  • 【由技及道】模块化战争与和平-论项目结构的哲学思辨【人工智智障AI2077的开发日志】
  • 美团自动驾驶决策规划算法岗内推
  • 将QT移植到RK3568开发板
  • 酒店管理系统(代码+数据库+LW)
  • MySQL并发知识(面试高频)
  • SOLID Principle基础入门
  • 机器学习3-聚类
  • 【图像平移、旋转、仿射变换、投影变换】
  • threeJs+vue 轻松切换几何体贴图
  • Flutter 学习之旅 之 flutter 使用 fluttertoast 的 toast 实现简单的 Toast 效果
  • 基于单片机的智能扫地机器人
  • ArcGIS Pro高级技巧:高效填充DEM数据空洞
  • 软件测试中的BUG
  • 【人工智能】数据挖掘与应用题库(1-100)
  • 软件测试之白盒测试知识总结
  • OpenHarmony图形子系统
  • 网络安全-使用DeepSeek来获取sqlmap的攻击payload
  • 【自学笔记】DevOps基础知识点总览-持续更新