当前位置: 首页 > article >正文

Spider 数据集上实现nlp2sql训练任务

NLP2SQL(自然语言处理到 SQL 查询的转换)是一个重要的自然语言处理(NLP)任务,其目标是将用户的自然语言问题转换为相应的 SQL 查询。这一任务在许多场景下具有广泛的应用,尤其是在与数据库交互的场景中,例如数据分析、业务智能和问答系统。

任务目标
  • 理解自然语言: 理解用户输入的自然语言问题,包括意图、实体和上下文。
  • 生成 SQL 查询: 将理解后的信息转换为正确的 SQL 查询,以从数据库中检索所需的数据。

例如

输入: 用户的自然语言问题,“获取 Gelderland 区的总人口。”

输出: 对应的 SQL 查询,SELECT population FROM districts WHERE name = 'Gelderland';

Spider 是一个难度最大数据集

耶鲁大学在2018年新提出的一个大规模的NL2SQL(Text-to-SQL)数据集。
该数据集包含了10,181条自然语言问句、分布在200个独立数据库中的5,693条SQL,内容覆盖了138个不同的领域。
涉及的SQL语法最全面,是目前难度最大的NL2SQL数据集。

下载查看spider数据集内容

Question 1: How many singers do we have ? ||| concert_singer
SQL: select count(*) from singer

Question 2: What is the total number of singers ? ||| concert_singer
SQL: select count(*) from singer

Question 3: Show name , country , age for all singers ordered by age from the oldest to the youngest . ||| concert_singer
SQL: select name , country , age from singer order by age desc

...

首先需要转换为Spider的标准格式(参考tables.jsontrain.json):

{
  "db_id": "concert_singer",
  "question": "Show name, country, age...",
  "query": "SELECT name, country, age FROM singer ORDER BY age DESC",
  "schema": {
    "table_names": ["singer"],
    "column_names": [
      [0, "name", "text"],
      [0, "country", "text"],
      [0, "age", "int"]
    ]
  }
}

拆分为table.json的原因可能涉及到数据组织和重用。每个数据库的结构(表、列、外键)在多个问题中都会被重复使用。如果每个问题都附带完整的schema信息,会导致数据冗余,增加存储和处理的开销。所以,将schema单独存储为table.json,可以让不同的数据条目引用同一个数据库模式,减少重复数据。拆分后的结构需要更高效的数据管理,例如在训练模型时,根据每个问题的db_id去table.json中查找对应的schema信息。这样做的好处是当多个问题属于同一个数据库时,不需要每次都重复加载schema提高了效率。

column_names 表示数据库表中每一列的详细信息。具体来说,column_names 是一个列表,其中每个元素都是一个包含三个部分的子列表:

  1. 表索引(0):表示该列属于哪个表。在这个例子中,所有列都属于第一个表(索引为 0)。
  2. 列名("name"、"country"、"age"):表示列的名称。
  3. 数据类型("text"、"int"):表示该列的数据类型,例如文本(text)或整数(int)。

实现下面逻辑转换原始数据

def extract_columns_from_sql(sql):
    # 使用正则表达式匹配 SELECT 语句中的列名
    match = re.search(r"SELECT\s+(.*?)\s+FROM", sql, re.IGNORECASE)
    if match:
        # 提取列名
        columns = match.group(1).split(",")
        # 构建 column_names 列表
        column_names = []
        for index, column in enumerate(columns):
            column = column.strip()  # 去除多余的空格
            data_type = "text"  # 默认数据类型为 text,可以根据需要修改
            # 添加到 column_names 列表,假设所有列类型为 text
            column_names.append([0, column, data_type])
        return column_names
    return []

# 从 dev.sql 文件读取数据
def load_sql_data(file_path):
    data_list = []
    with open(file_path, 'r', encoding='utf-8') as f:  # 指定编码为 UTF-8
        lines = f.readlines()
        for i in range(0, len(lines), 3):  # 每三行一组
            question_line = lines[i].strip()
            sql_line = lines[i + 1].strip()

            if not question_line or not sql_line:
                continue

            # 提取问题和 SQL
            question = question_line.split(': ', 1)[1].strip()  # 获取问题内容
            sql = sql_line.split(': ', 1)[1].strip()  # 获取 SQL 查询

            # 提取表名
            db_id = question_line.split('|||')[-1].strip()  # 从问题行获取表名
            question = question.split('|||')[0].strip()

            target_sql = preprocess(question, db_id, sql)

            data_list.append({
                "input_text": f"Translate to SQL: {question} [SEP] Tables: {db_id}",
                "target_sql": json.dumps(target_sql)  # 将目标 SQL 转换为 JSON 格式字符串
            })
    return data_list

选择Tokenizer.from_pretrained("t5-base") 是用于加载 T5(Text-to-Text Transfer Transformer)模型的分词器。T5 是一个强大的自然语言处理模型,能够处理各种文本任务(如翻译、摘要、问答等),并且将所有任务视为文本到文本的转换。

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")

def preprocess(question, db_id, sql):
    # 提取列名
    column_names = extract_columns_from_sql(sql)

    # 构建目标格式
    target_sql = {
        "db_id": db_id,
        "question": question,
        "query": sql,
        "schema": {
            "table_names": [db_id],
            "column_names": column_names
        }
    }
    return target_sql# 

示例数据
question = "Show name, country, age for all singers ordered by age from the oldest to the youngest."
schema = "singer(name, country, age)"
sql = "SELECT name, country, age FROM singer ORDER BY age DESC"

input_text, target_sql = preprocess(question, schema, sql)
# input_text = "Translate to SQL: Show name... [SEP] Tables: singer(name, country, age)"
# target_sql = "select name, country, age from singer order by age desc"
print('input_text', input_text)
print('target_sql', target_sql)

所有nlp任务都涉及的需要token化,使用t5-base 做tokenize

def tokenize_function(examples):
    model_inputs = tokenizer(
        examples["input_text"],
        max_length=512,
        truncation=True,
        padding="max_length"
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            examples["target_sql"],
            max_length=512,
            truncation=True,
            padding="max_length"
        )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

使用 tokenizer.as_target_tokenizer() 上下文管理器,确保目标文本(即 SQL 查询)被正确处理。目标文本也经过编码,转换为 token IDs,并同样进行填充和截断。将目标文本的编码结果(token IDs)存储在 model_inputs["labels"] 中。这是模型在训练时需要的输出,用于计算损失。最终返回一个字典 model_inputs,它包含了模型的输入和对应的标签。这种结构使得模型在训练时可以直接使用。

最后组织下训练代码

tokenized_datasets = dataset.map(tokenize_function, batched=True)

# 加载模型
model = T5ForConditionalGeneration.from_pretrained("t5-base")

# 训练参数
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=100,
    predict_with_generate=True,
    run_name="spider"
)

# 开始训练
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"] if 'train' in tokenized_datasets else tokenized_datasets,
    eval_dataset=tokenized_datasets["test"] if 'test' in tokenized_datasets else None,
    data_collator=DataCollatorForSeq2Seq(tokenizer)
)

trainer.train()

这里使用的是Seq2SeqTrainer, 它是 Hugging Face 的 transformers 库中用于序列到序列(Seq2Seq)任务的训练器。它为处理诸如翻译、文本生成和问答等任务提供了一个高层次的接口,简化了训练过程。以下是 Seq2SeqTrainer 的主要功能和特点:

  1. 简化训练流程Seq2SeqTrainer 封装了许多常见的训练步骤,如数据加载、模型训练、评估和预测,使得用户可以更专注于模型和数据,而不必处理繁琐的训练细节。

  2. 支持多种训练参数: 通过 Seq2SeqTrainingArguments 类,可以灵活配置训练参数,如学习率、批量大小、训练轮数、评估策略等。

  3. 自动处理填充和截断: 在处理输入和输出序列时,Seq2SeqTrainer 可以自动填充和截断序列,以确保它们适应模型的输入要求。

  4. 集成评估和监控: 支持在训练过程中进行模型评估,并可以根据评估指标(如损失)监控训练进度。用户可以设置评估频率和评估数据集

开始训练,进行100次epoch

训练监控在 Weights & Biases ,Seq2SeqTrainer 能够向 Weights & Biases (wandb) 传输训练监控数据,主要是因为它内置了与 wandb 的集成。以下是一些关键点,解释了这一过程:

  1. 自动集成:当你使用 Seq2SeqTrainer 时,它会自动检测 wandb 的安装并在初始化时配置相关设置。这意味着你无需手动设置 wandb。

  2. 回调功能Trainer 类提供了回调功能,可以在训练过程中记录各种指标(如损失、准确率等)。这些指标会被自动发送到 wandb。

  3. 配置管理training_args 中的参数可以指定 wandb 的项目名称、运行名称等,从而更好地组织和管理实验。

  4. 训练循环:在每个训练和评估周期结束时,Trainer 会调用相应的回调函数,将重要的训练信息(如损失、学习率等)记录到 wandb。

  5. 可视化:通过 wandb,你可以实时监控训练过程,包括损失曲线、模型性能等,帮助你更好地理解模型的训练动态。

多次试验还可以比较训练性能

训练结束, 损失收敛到0.05410315271151268

{'eval_loss': 0.008576861582696438, 'eval_runtime': 1.3883, 'eval_samples_per_second': 74.912, 'eval_steps_per_second': 5.042, 'epoch': 100.0}
{'train_runtime': 2914.0548, 'train_samples_per_second': 31.914, 'train_steps_per_second': 2.025, 'train_loss': 0.05410315271151268, 'epoch': 100.0}
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5900/5900 [48:31<00:00,  2.03it/s]
wandb:
wandb: 🚀 View run spider at: https://wandb.ai/chenruithinking-4th-paradigm/huggingface/runs/dkccvpp4
wandb: Find logs at: wandb/run-20250207_112702-dkccvpp4/logs

测试下预测能力

import os
from transformers import T5Tokenizer, T5ForConditionalGeneration

# 设置 NCCL 环境变量
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"

# 加载分词器
tokenizer = T5Tokenizer.from_pretrained("t5-base")


model = T5ForConditionalGeneration.from_pretrained("./results/t5-sql-model")
tokenizer.save_pretrained("./results/t5-sql-model")

def generate_sql(question, db_id):
    input_text = f"Translate to SQL: {question} [SEP] Tables: {db_id}"
    input_ids = tokenizer.encode(input_text, return_tensors="pt")  # 使▒~T▒ PyTorch ▒~Z~D▒| ▒~G~O▒| ▒▒~O
    output = model.generate(
        input_ids,
        max_length=512,
        num_beams=5,  # 或者尝试其他解码策略
        early_stopping=True
    )

    print('output', output)
    generated_sql = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_sql

question = "How many singers do we have ?"
db_id = "concert_singer"
evaluation_output = generate_sql(question, db_id)
print("evaluation_output:", evaluation_output)

输出结果

evaluation_output: "db_id": "concert_singer", "question": "How many singers do we have ?", "query": "select count(*) from singer", "schema": "table_names": ["concert_singer"], "column_names": [[0, "count(*)", "text"]]


http://www.kler.cn/a/536512.html

相关文章:

  • 2024~2025学年佛山市普通高中教学质量检测(一)【高三数学】
  • 判断您的Mac当前使用的是Zsh还是Bash:echo $SHELL、echo $0
  • Linux 命令行指南
  • 第6章《VTK与Qt集成》
  • 一个可以在浏览器console内运行的极简爬虫,可列出网页内指定关键词的所有句子。
  • tolua[一]框架搭建,运行example
  • SpringCloud面试题----SpringCloud和Dubbo有什么区别
  • Synchronized和ReentrantLock面试详解
  • 第4章 Jetpack Compose提供了一系列的布局组件
  • 【Elasticsearch】分桶聚合功能概述
  • Windows上工程组织方式 --- dll插件式
  • 本地缓存怎么保证数据一致性?
  • pikachu[皮卡丘] 靶场全级别通关教程答案 以及 学习方法 如何通过渗透测试靶场挑战「pikachu」来精通Web渗透技巧? 一篇文章搞完这些问题
  • 高级测试工程师,在数据安全方面,如何用AI提升?DeepSpeek的回答
  • iOS pod install一直失败,访问github超时记录
  • LabVIEW位移测量系统
  • 06vue3实战-----项目开发准备
  • windows部署本地deepseek
  • arkui-x 鼠标切换为键盘,焦点衔接问题
  • 【实战篇】DeepSeek + Cline 编程实战:从入门到“上头”
  • STM32上部署AI的两个实用软件——Nanoedge AI Studio和STM32Cube AI
  • 流媒体缓存管理策略
  • Python的那些事第十四篇:Flask与Django框架的趣味探索之旅
  • 阿里云cdn怎样设置图片压缩
  • 【Spring】_SpringBoot配置文件
  • Jetpack ViewModel