【text2sql】DB-GPT-Hub:text2sql的微调框架及基准测试套件
text2sql任务是将自然语言问题转换为SQL查询。使用大模型来进行 sql 生成的方式也越来越常见。根据大模型用于文本到SQL生成的方式,text2sql可以分为两种场景:零样本/少样本提示和微调。
零样本/少样本提示:在零样本场景中,不提供示例;而在少样本场景中,提供少量输入输出示例以提示大模型。形式上,给定一个由 θ \theta θ参数化的LLM,问题 q i q_i qi和 k k k个示例( k ≥ 0 k \geq 0 k≥0),目标是最大化从大型语言模型生成正确SQL s i s_i si的概率:
max s i P L L M θ ( s i ∣ σ ( q i , M ) ) , ∣ M ∣ = k \max_{s_{i}} P_{L L M_{\theta}}\left(s_{i}\mid\sigma\left(q_{i},\mathcal{M}\right)\right),\quad|\mathcal{M}|=k simaxPLLMθ(si∣σ(qi,M)),∣M∣=k
其中 Θ \Theta Θ和 σ ( q i , M ) 1 \sigma\left(q_i,\mathcal{M}\right)^1 σ(qi,M)1表示通过结合示例中的相关信息来表征目标问题 q i q_{i} qi的表示空间。
微调:微调过程涉及通过使用包含一系列序列化输入 q i q_i qi和相应SQL输出 s i s_i si对的文本到SQL数据集来调整预训练的 L L M θ LLM_{\theta} LLMθ,以生成SQL。微调的目标是最小化经验损失:
min θ L ( s ^ i ( L L M θ ) , s i ∣ σ ( q i ) ) \min_{\theta}\mathcal{L}\left(\widehat{s}_{i}\left(L L M_{\theta}\right), s_{i}\mid\sigma\left(q_{i}\right)\right) θminL(s i(LLMθ),si∣σ(qi))
其中 L \mathcal{L} L是衡量生成的SQL与真实值之间差异的损失函数。
尽管少样本提示的大模型取得了显著进展,但仅依靠其参数知识和提示来准确处理高度复杂的SQL查询对于预训练的大型语言模型来说仍然是一个巨大的挑战。DB-GPT-Hub侧重于微调更大的大型语言模型。
框架设计
DB-GPT 框架下提出了一个端到端大模型 Text2SQL 微调子框架 DB-GPT-Hub。在 DB-GPT 框架下,构架了 Text2SQL 领域下的数据预处理 - 模型微调 - 模型预测 - 模型验证 - 模型评估的全链路工作流程,如下图所示:
代码库设计
-
数据集构建:将原始文本到SQL数据加工成适合微调LLM的格式(例如,列表1中显示的TRF),这包括将模式和数据库描述整合到提示中作为指令,以及在训练和评估期间增强性能的各种问题表示。此外,我们还将选择不同的少样本策略,例如示例选择和组织,来构建评估数据集(Gao等人,2023)。
-
训练:支持使用PEFT策略对开源LLM进行微调。我们支持从小到大的模型规模的大多数公共架构,例如Qwen、Llama、Baichuan和ChatGLM。
-
预测:支持对开源LLM的微调版本以及闭源LLM进行SQL查询推理。我们支持少样本和零样本方法,以生成特定场景下的SQL。
-
评估:包含不同的指标(EX、EM、有效效率得分(VES)),以从不同角度评估生成的SQL的性能。
数据集
-
Spider是一个大规模跨域数据集,包含10,181个自然语言查询,5,693个独特的复杂SQL查询,涵盖200个数据库,覆盖138个领域。该数据集的标准协议将其分为8,659个训练样本和2,147个测试样本,分布在34个数据库中。SQL查询分为四个难度级别,即简单、中等、困难和超难。如下:
-
简单:
Question: What is the number of cars with more than 4 cylinders?
SQL:SELECT COUNT (*)FROM cars_dataWHERE cylinders > 4
-
中等:
Question: For each stadium, how many concerts are there?
SQL:SELECT T2.name, COUNT (*) FROM concert AS T1 JOIN stadium AS T2ON T1.stadium_id = T2.stadium_idGROUP BY T1.stadium_id
-
较难
Question: Which countries in Europe have at least 3 car manufacturers?
SQL:SELECT T1.country name FROM countries AS T1 JOIN continents AS T2 ON T1.continent T2.cont_id JOIN car makers AS T3 ON T1.country_id = T3.country WHERE T2.continent = ‘Europe’ GROUPBY T1.country_name HAVINGCOUNT (*) >= 3
-
极难
Question: What is the average life expectancy in the countries where English is not the official language?
SQL:SELECT AVG(life_expectancy) FROM country WHERE name NOT IN ( SELECT T1.name FROM country AS T1 JOIN country_language AS T2 ON T1.code = T2.country_code WHERE T2.language = “English” AND T2.is_official = “T”)
-
其他数据集:
-
WikiSQL: 一个大型的语义解析数据集,由80,654个自然语句表述和24,241张表格的sql标注构成。WikiSQL中每一个问句的查询范围仅限于同一张表,不包含排序、分组、子查询等复杂操作。
-
CHASE: 一个跨领域多轮交互text2sql中文数据集,包含5459个多轮问题组成的列表,一共17940个<query, SQL>二元组,涉及280个不同领域的数据库。
-
BIRD-SQL:数据集是一个英文的大规模跨领域文本到SQL基准测试,特别关注大型数据库内容。该数据集包含12,751对文本到SQL数据对和95个数据库,总大小为33.4GB,跨越37个职业领域。BIRD-SQL数据集通过探索三个额外的挑战,即处理大规模和混乱的数据库值、外部知识推理和优化SQL执行效率,缩小了文本到SQL研究与实际应用之间的差距。
-
CoSQL:是一个用于构建跨域对话文本到sql系统的语料库。它是Spider和SParC任务的对话版本。CoSQL由30k+回合和10k+带注释的SQL查询组成,这些查询来自Wizard-of-Oz的3k个对话集合,查询了跨越138个领域的200个复杂数据库。每个对话都模拟了一个真实的DB查询场景,其中一个工作人员作为用户探索数据库,一个SQL专家使用SQL检索答案,澄清模棱两可的问题,或者以其他方式通知。
按照NSQL的处理模板,对数据集做简单处理,共得到约20w条训练数据
为了充分利用数据库中的表和字段等相关信息,对 Spider 中的原始数据进行处理,用自然语言表示数据库包含的表结构以及表结构包含的字段以及相应的主键和外键等,经过数据预处理后,可以得到如下的数据格式:
{
"instruction": "concert_singer(数据库名) contains tables(表) such as stadium, singer, concert, singer_in_concert. Table stadium has columns(列) such as stadium_id, location, name, capacity, highest, lowest, average. stadium_id is the primary key(主键). Table singer has columns such as singer_id, name, country, song_name, song_release_year, age, is_male. singer_id is the primary key. Table concert has columns such as concert_id, concert_name, theme, stadium_id, year. concert_id is the primary key. Table singer_in_concert has columns such as concert_id, singer_id. concert_id is the primary key. The year of concert is the foreign key(外键)of location of stadium. The stadium_id of singer_in_concert is the foreign key of name of singer. The singer_id of singer_in_concert is the foreign key of concert_name of concert.",
"input": "How many singers do we have?",
"response": "select count(*) from singer"
}
{
"instruction": "concert_singer(数据库名)包含表(表),例如stadium, singer, concert, singer_in_concert。表体育场有列(列),如stadium_id、位置、名称、容量、最高、最低、平均。Stadium_id是主键(主键)。表singer有这样的列:singer_id、name、country、song_name、song_release_year、age、is_male。Singer_id为主键。表concert有如下列:concert_id、concert_name、theme、stadium_id、year。Concert_id是主键。表singer_in_concert有如下列:concert_id, singer_id。Concert_id是主键。演唱会年份是场馆位置的外键(外键)。singer_in_concert的stadium_id是歌手名的外键。singer_in_concert的singer_id是concert的concert_name的外键。",
"input": "我们有多少歌手?",
"response": "select count(*) from singer"
}
为了更好的利用大语言模型的理解能力,使用了 prompt dict 以优化输入,如下所示:
SQL_PROMPT_DICT = {
"prompt_input": (
"I want you to act as a SQL terminal in front of an example database. "
"Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\n"
"###Instruction:\n{instruction}\n\n###Input:\n{input}\n\n###Response: "
),
"prompt_no_input": (
"I want you to act as a SQL terminal in front of an example database. "
"Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\n"
"###Instruction:\n{instruction}\n\n### Response: "
),
}
指标:
使用两个常用的指标,精确集匹配准确率(EM)和执行准确率(EX),来评估所有模型的性能。EM衡量预测SQL查询与其对应真实值之间的匹配SQL关键字,而EX比较预测SQL查询的执行输出与某些数据库实例上真实SQL查询的执行输出。由于对于给定问题可能存在多个有效的SQL查询,EX提供了对模型性能更精确的估计。在这两种指标中,较高的值被认为更好。主要使用EX来评估论文中SQL的准确性。
LLM
为了确保公平比较,对所有LLM使用相同的最大上下文长度2048。在评估期间,留下512个token用于响应生成。将参数温度设置为0,以消除随机性的影响。
实验
参考文献
- paper:DB-GPT-Hub: Towards Open Benchmarking Text-to-SQL Empowered by Large Language Models,https://arxiv.org/abs/2406.11434
- code:https://github.com/eosphoros-ai/DB-GPT-Hub
- text2sql benchmark:https://github.com/eosphoros-ai/Awesome-Text2SQL