当前位置: 首页 > article >正文

基于 Python 的自然语言处理系列(43):Question Answering

        在这篇文章中,我们将探讨 Question Answering(QA),即问答任务。这种任务有很多不同的形式,但我们将重点放在抽取式问答上。抽取式问答涉及根据给定的文档,提出问题并在文档中找到答案对应的文本片段。

        我们将基于 SQuAD 数据集对 BERT 模型进行微调。SQuAD 数据集包含 Wikipedia 文章的相关问题,这些问题是由众包工人提出的。

1. 加载数据

        在学术界,SQuAD 数据集是用于评估抽取式问答的基准数据集,因此我们在此使用该数据集。更具挑战性的 SQuAD v2 基准数据集还包括没有答案的问题。只要你有一个包含上下文(context)、问题(question)和答案(answers)列的数据集,你就可以按照本节中的步骤进行微调。

SQuAD 数据集

        像往常一样,我们可以通过 load_dataset() 下载并缓存数据集:

import os
os.environ['http_proxy']  = ''
os.environ['https_proxy'] = ''

from datasets import load_dataset
raw_datasets = load_dataset("squad")

        让我们看一下该数据集中的字段:

print("Context: ", raw_datasets["train"][0]["context"])
print("Question: ", raw_datasets["train"][0]["question"])
print("Answer: ", raw_datasets["train"][0]["answers"])

        这里我们可以看到 context(上下文)和 question(问题)的字段,而 answers 字段稍微复杂一些,因为它包含两个列表:一个是答案文本,另一个是答案在上下文中的起始位置。

        对于训练数据集,每个问题通常只有一个正确答案:

raw_datasets["train"].filter(lambda x: len(x["answers"]["text"]) != 1)

        在评估数据集中,每个问题可能有多个不同的答案:

print(raw_datasets["validation"][0]["answers"])
print(raw_datasets["validation"][2]["answers"])

2. 数据预处理

        接下来,我们将数据转换为模型可以理解的输入格式。首先,我们需要将问题和上下文合并在一起,并将它们传递给 tokenizer

from transformers import AutoTokenizer

model_checkpoint = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

context = raw_datasets["train"][0]["context"]
question = raw_datasets["train"][0]["question"]
inputs = tokenizer(question, context)
tokenizer.decode(inputs["input_ids"])

        我们使用滑动窗口来处理超出最大长度的上下文,并设置 stride 来确保上下文片段有足够的重叠:

 inputs = tokenizer(
    question,
    context,
    max_length=384,
    truncation="only_second",
    stride=128,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)

    offset_mapping 返回了每个输入 token 在原始文本中的字符范围。我们可以使用它来找到答案的起始和结束位置,生成训练标签:

answers = raw_datasets["train"][0]["answers"]
start_char = answers["answer_start"][0]
end_char = start_char + len(answers["text"][0])

        接下来,我们定义一个函数来对所有训练数据进行预处理:

def preprocess_training_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=384,
        truncation="only_second",
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    # 对 offset 和答案进行处理
    # 返回 inputs 并附加 start_positions 和 end_positions
    return inputs

train_dataset = raw_datasets["train"].map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
)

        验证集的处理与训练集类似,不同之处在于我们不需要生成标签:

def preprocess_validation_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=384,
        truncation="only_second",
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    return inputs

validation_dataset = raw_datasets["validation"].map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=raw_datasets["validation"].column_names,
)

3. 评估指标

        为了在评估过程中使用 SQuAD 的标准,我们需要加载 evaluate 库中的 squad 度量函数:

import evaluate

metric = evaluate.load("squad")

        在模型输出预测的 start_logitsend_logits 后,我们可以对它们进行后处理,找到最佳的答案 span:

n_best = 20
max_answer_length = 30
predicted_answers = []

for example in raw_datasets["validation"]:
    # 提取 start_logits 和 end_logits 并找到最佳答案 span
    predicted_answers.append({"id": example["id"], "prediction_text": best_answer})

        然后我们计算最终的评估指标:

theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in raw_datasets["validation"]]
metric.compute(predictions=predicted_answers, references=theoretical_answers)

4. 模型训练

        首先,构建数据加载器,并配置训练和评估集:

from torch.utils.data import DataLoader
from transformers import default_data_collator

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=default_data_collator,
    batch_size=8,
)
eval_dataloader = DataLoader(
    validation_dataset, collate_fn=default_data_collator, batch_size=8
)

        接着,我们加载预训练的 BERT 模型,并定义优化器和学习率调度器:

from transformers import AutoModelForQuestionAnswering, AdamW, get_scheduler

model = AutoModelForQuestionAnswering.from_pretrained("bert-base-cased")
optimizer = AdamW(model.parameters(), lr=2e-5)
scheduler = get_scheduler(
    "linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

        最后,我们使用 Accelerator 进行训练:

from accelerate import Accelerator
accelerator = Accelerator()

model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)

# 开始训练
for epoch in range(num_train_epochs):
    model.train()
    for step, batch in enumerate(train_dataloader):
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    # 评估模型
    model.eval()
    # 保存模型

5. 推理

        训练完成后,我们可以使用 Hugging Face 的 pipeline 进行推理:

from transformers import pipeline

question_answerer = pipeline("question-answering", model="你的模型名称")
context = "🤗 Transformers 支持三种流行的深度学习框架:Jax、PyTorch 和 TensorFlow。"
question = "🤗 Transformers 支持哪些深度学习框架?"
question_answerer(question=question, context=context)

结语

        通过本文的内容,我们展示了如何使用 Hugging Face 的 Transformers 进行抽取式问答模型的微调。SQuAD 数据集为我们提供了一个良好的基准,而 BERT 模型的强大性能也使得我们可以较好地处理问答任务。在实际项目中,你也可以轻松地将这类模型应用于类似的任务,只需根据你自己的数据集进行适当的调整即可。        

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!


http://www.kler.cn/news/364108.html

相关文章:

  • Docker快速安装Grafana
  • 解决:git SSL certificate problem: unable to get local issuer certificate
  • SIP 业务举例之 Call Forwarding - No Answer(无应答呼叫转移)
  • CentOS 7(Linux)详细安装教程
  • Redis Search系列 - 第四讲 支持中文
  • 工作使用的工具
  • 架构设计(17)大数据框架Hadoop与基础架构CDH
  • 又是一年 1024
  • Python酷库之旅-第三方库Pandas(167)
  • 鸿蒙原生 证书 打包到真机
  • 使用docker-compose部署一个springboot项目(包含Postgres\redis\Mongo\Nginx等环境)
  • STL标准容器库
  • 【华为HCIP实战课程十七】OSPF的4类及5类LSA详解,网络工程师
  • nginx------HTTP模块配置详解
  • 什么是虚拟线程?Java 中虚拟线程的介绍与案例演示
  • 【Unity实战笔记】第二一 · 基于状态模式的角色控制——以UnityChan为例
  • ArcGIS计算落入面图层中的线的长度或面的面积
  • 十七、行为型(命令模式)
  • 社区团购在一线城市的新机遇:定制开发小程序助力用户细分
  • Lua简介
  • 【CSS in Depth 2 精译_054】8.2 CSS 层叠图层(cascade layer)的推荐组织方案
  • Redis 安装部署与常用命令
  • 【H2O2|全栈】JS入门知识(八)DOM(2)
  • rabbitmq 使用注意事项
  • JVM 的定义、内部工作原理以及不同 JVM 实现的区别, Oracle JVM 、 OpenJ9、GraalVM对比。
  • 51 单片机[11]:蜂鸣器播放提示音和音乐