Spider 数据集上实现nlp2sql训练任务
NLP2SQL(自然语言处理到 SQL 查询的转换)是一个重要的自然语言处理(NLP)任务,其目标是将用户的自然语言问题转换为相应的 SQL 查询。这一任务在许多场景下具有广泛的应用,尤其是在与数据库交互的场景中,例如数据分析、业务智能和问答系统。
任务目标
- 理解自然语言: 理解用户输入的自然语言问题,包括意图、实体和上下文。
- 生成 SQL 查询: 将理解后的信息转换为正确的 SQL 查询,以从数据库中检索所需的数据。
例如
输入: 用户的自然语言问题,“获取 Gelderland 区的总人口。”
输出: 对应的 SQL 查询,
SELECT population FROM districts WHERE name = 'Gelderland';
Spider 是一个难度最大数据集
耶鲁大学在2018年新提出的一个大规模的NL2SQL(Text-to-SQL)数据集。
该数据集包含了10,181条自然语言问句、分布在200个独立数据库中的5,693条SQL,内容覆盖了138个不同的领域。
涉及的SQL语法最全面,是目前难度最大的NL2SQL数据集。
下载查看spider数据集内容
Question 1: How many singers do we have ? ||| concert_singer
SQL: select count(*) from singer
Question 2: What is the total number of singers ? ||| concert_singer
SQL: select count(*) from singer
Question 3: Show name , country , age for all singers ordered by age from the oldest to the youngest . ||| concert_singer
SQL: select name , country , age from singer order by age desc...
首先需要转换为Spider的标准格式(参考tables.json
和train.json
):
{
"db_id": "concert_singer",
"question": "Show name, country, age...",
"query": "SELECT name, country, age FROM singer ORDER BY age DESC",
"schema": {
"table_names": ["singer"],
"column_names": [
[0, "name", "text"],
[0, "country", "text"],
[0, "age", "int"]
]
}
}
拆分为table.json的原因可能涉及到数据组织和重用。每个数据库的结构(表、列、外键)在多个问题中都会被重复使用。如果每个问题都附带完整的schema信息,会导致数据冗余,增加存储和处理的开销。所以,将schema单独存储为table.json,可以让不同的数据条目引用同一个数据库模式,减少重复数据。拆分后的结构需要更高效的数据管理,例如在训练模型时,根据每个问题的db_id去table.json中查找对应的schema信息。这样做的好处是当多个问题属于同一个数据库时,不需要每次都重复加载schema提高了效率。
column_names
表示数据库表中每一列的详细信息。具体来说,column_names
是一个列表,其中每个元素都是一个包含三个部分的子列表:
- 表索引(0):表示该列属于哪个表。在这个例子中,所有列都属于第一个表(索引为 0)。
- 列名("name"、"country"、"age"):表示列的名称。
- 数据类型("text"、"int"):表示该列的数据类型,例如文本(
text
)或整数(int
)。
实现下面逻辑转换原始数据
def extract_columns_from_sql(sql):
# 使用正则表达式匹配 SELECT 语句中的列名
match = re.search(r"SELECT\s+(.*?)\s+FROM", sql, re.IGNORECASE)
if match:
# 提取列名
columns = match.group(1).split(",")
# 构建 column_names 列表
column_names = []
for index, column in enumerate(columns):
column = column.strip() # 去除多余的空格
data_type = "text" # 默认数据类型为 text,可以根据需要修改
# 添加到 column_names 列表,假设所有列类型为 text
column_names.append([0, column, data_type])
return column_names
return []
# 从 dev.sql 文件读取数据
def load_sql_data(file_path):
data_list = []
with open(file_path, 'r', encoding='utf-8') as f: # 指定编码为 UTF-8
lines = f.readlines()
for i in range(0, len(lines), 3): # 每三行一组
question_line = lines[i].strip()
sql_line = lines[i + 1].strip()
if not question_line or not sql_line:
continue
# 提取问题和 SQL
question = question_line.split(': ', 1)[1].strip() # 获取问题内容
sql = sql_line.split(': ', 1)[1].strip() # 获取 SQL 查询
# 提取表名
db_id = question_line.split('|||')[-1].strip() # 从问题行获取表名
question = question.split('|||')[0].strip()
target_sql = preprocess(question, db_id, sql)
data_list.append({
"input_text": f"Translate to SQL: {question} [SEP] Tables: {db_id}",
"target_sql": json.dumps(target_sql) # 将目标 SQL 转换为 JSON 格式字符串
})
return data_list
选择Tokenizer.from_pretrained("t5-base")
是用于加载 T5(Text-to-Text Transfer Transformer)模型的分词器。T5 是一个强大的自然语言处理模型,能够处理各种文本任务(如翻译、摘要、问答等),并且将所有任务视为文本到文本的转换。
from transformers import T5Tokenizer
tokenizer = T5Tokenizer.from_pretrained("t5-base")
def preprocess(question, db_id, sql):
# 提取列名
column_names = extract_columns_from_sql(sql)
# 构建目标格式
target_sql = {
"db_id": db_id,
"question": question,
"query": sql,
"schema": {
"table_names": [db_id],
"column_names": column_names
}
}
return target_sql#
示例数据
question = "Show name, country, age for all singers ordered by age from the oldest to the youngest."
schema = "singer(name, country, age)"
sql = "SELECT name, country, age FROM singer ORDER BY age DESC"
input_text, target_sql = preprocess(question, schema, sql)
# input_text = "Translate to SQL: Show name... [SEP] Tables: singer(name, country, age)"
# target_sql = "select name, country, age from singer order by age desc"
print('input_text', input_text)
print('target_sql', target_sql)
所有nlp任务都涉及的需要token化,使用t5-base 做tokenize
def tokenize_function(examples):
model_inputs = tokenizer(
examples["input_text"],
max_length=512,
truncation=True,
padding="max_length"
)
with tokenizer.as_target_tokenizer():
labels = tokenizer(
examples["target_sql"],
max_length=512,
truncation=True,
padding="max_length"
)
model_inputs["labels"] = labels["input_ids"]
return model_inputs
使用 tokenizer.as_target_tokenizer()
上下文管理器,确保目标文本(即 SQL 查询)被正确处理。目标文本也经过编码,转换为 token IDs,并同样进行填充和截断。将目标文本的编码结果(token IDs)存储在 model_inputs["labels"]
中。这是模型在训练时需要的输出,用于计算损失。最终返回一个字典 model_inputs
,它包含了模型的输入和对应的标签。这种结构使得模型在训练时可以直接使用。
最后组织下训练代码
tokenized_datasets = dataset.map(tokenize_function, batched=True)
# 加载模型
model = T5ForConditionalGeneration.from_pretrained("t5-base")
# 训练参数
training_args = Seq2SeqTrainingArguments(
output_dir="./results",
evaluation_strategy="epoch",
learning_rate=3e-5,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
num_train_epochs=100,
predict_with_generate=True,
run_name="spider"
)
# 开始训练
trainer = Seq2SeqTrainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"] if 'train' in tokenized_datasets else tokenized_datasets,
eval_dataset=tokenized_datasets["test"] if 'test' in tokenized_datasets else None,
data_collator=DataCollatorForSeq2Seq(tokenizer)
)
trainer.train()
这里使用的是Seq2SeqTrainer, 它
是 Hugging Face 的 transformers
库中用于序列到序列(Seq2Seq)任务的训练器。它为处理诸如翻译、文本生成和问答等任务提供了一个高层次的接口,简化了训练过程。以下是 Seq2SeqTrainer
的主要功能和特点:
-
简化训练流程:
Seq2SeqTrainer
封装了许多常见的训练步骤,如数据加载、模型训练、评估和预测,使得用户可以更专注于模型和数据,而不必处理繁琐的训练细节。 -
支持多种训练参数: 通过
Seq2SeqTrainingArguments
类,可以灵活配置训练参数,如学习率、批量大小、训练轮数、评估策略等。 -
自动处理填充和截断: 在处理输入和输出序列时,
Seq2SeqTrainer
可以自动填充和截断序列,以确保它们适应模型的输入要求。 -
集成评估和监控: 支持在训练过程中进行模型评估,并可以根据评估指标(如损失)监控训练进度。用户可以设置评估频率和评估数据集
开始训练,进行100次epoch
训练监控在 Weights & Biases ,Seq2SeqTrainer
能够向 Weights & Biases (wandb) 传输训练监控数据,主要是因为它内置了与 wandb 的集成。以下是一些关键点,解释了这一过程:
-
自动集成:当你使用
Seq2SeqTrainer
时,它会自动检测 wandb 的安装并在初始化时配置相关设置。这意味着你无需手动设置 wandb。 -
回调功能:
Trainer
类提供了回调功能,可以在训练过程中记录各种指标(如损失、准确率等)。这些指标会被自动发送到 wandb。 -
配置管理:
training_args
中的参数可以指定 wandb 的项目名称、运行名称等,从而更好地组织和管理实验。 -
训练循环:在每个训练和评估周期结束时,
Trainer
会调用相应的回调函数,将重要的训练信息(如损失、学习率等)记录到 wandb。 -
可视化:通过 wandb,你可以实时监控训练过程,包括损失曲线、模型性能等,帮助你更好地理解模型的训练动态。
多次试验还可以比较训练性能
训练结束, 损失收敛到0.05410315271151268
{'eval_loss': 0.008576861582696438, 'eval_runtime': 1.3883, 'eval_samples_per_second': 74.912, 'eval_steps_per_second': 5.042, 'epoch': 100.0}
{'train_runtime': 2914.0548, 'train_samples_per_second': 31.914, 'train_steps_per_second': 2.025, 'train_loss': 0.05410315271151268, 'epoch': 100.0}
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5900/5900 [48:31<00:00, 2.03it/s]
wandb:
wandb: 🚀 View run spider at: https://wandb.ai/chenruithinking-4th-paradigm/huggingface/runs/dkccvpp4
wandb: Find logs at: wandb/run-20250207_112702-dkccvpp4/logs
测试下预测能力
import os
from transformers import T5Tokenizer, T5ForConditionalGeneration
# 设置 NCCL 环境变量
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"
# 加载分词器
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("./results/t5-sql-model")
tokenizer.save_pretrained("./results/t5-sql-model")
def generate_sql(question, db_id):
input_text = f"Translate to SQL: {question} [SEP] Tables: {db_id}"
input_ids = tokenizer.encode(input_text, return_tensors="pt") # 使▒~T▒ PyTorch ▒~Z~D▒| ▒~G~O▒| ▒▒~O
output = model.generate(
input_ids,
max_length=512,
num_beams=5, # 或者尝试其他解码策略
early_stopping=True
)
print('output', output)
generated_sql = tokenizer.decode(output[0], skip_special_tokens=True)
return generated_sql
question = "How many singers do we have ?"
db_id = "concert_singer"
evaluation_output = generate_sql(question, db_id)
print("evaluation_output:", evaluation_output)
输出结果
evaluation_output: "db_id": "concert_singer", "question": "How many singers do we have ?", "query": "select count(*) from singer", "schema": "table_names": ["concert_singer"], "column_names": [[0, "count(*)", "text"]]