当前位置: 首页 > article >正文

transformers训练(NLP)阅读理解(多项选择)

简介

在阅读理解任务中,有一种通过多项选择其中一个答案来训练机器的阅读理解。比如:给定一个或多个文档h,以及一个问题S和对应的多个答案候选,输出问题S的答案E,E是答案候选中的某一个选项。
这样的目的就是通过文档,问题,多个答案选其中一个,就是让机器更好文档中的含义,提高机器的阅读理解。
是不是感觉似陈相识,这不就是语文考试的必考题,阅读理解吗。。。。。。。。。。。。。。

给机器训练的数据集示例

如:

  • Context:老师把一个大玻璃瓶子带到学校,瓶子里装着满满的石头、玻璃碎片和沙子。之后,老师请学生把瓶子里的东西都倒出来,然后再装进去,先从沙子开始。每个学生都试了试,最后都发现没有足够的空间装所有的石头。老师指导学生重新装这个瓶子。这次,先从石头开始,最后再装沙子。石头装进去后,沙子就沉积在石头的周围,最后,所有东西都装进瓶子里了。老师说:“如果我们先从小的东西开始,把小东西装进去之后,大的石头就放不进去了。生活也是如此,如果你的生活先被不重要的事挤满了,那你就无法再装进更大、更重要的事了。”
  • Question:正确的装法是,先装?
  • Choices / Candidates:[“小东西”,“大东西”,“轻的东西”,“软的东西” ]
  • Answer:大东西

技术实现思路

多项选择任务,技术实现,这里难点涉及数据处理和训练与推理。其实就是将数据处理好,喂给模型进行训练与推理,让其理解文本。
这里采用格式是:

[CLS] 文本内容 [SEP] 问题 答案 [SEP]

这里涉及在多个候选答案中只取一个答案,让大模型理解文本。所以需要将文本、问题、答案,拆分为4条数据喂给大模型,在告知它正确答案,这样处理大模型才能读懂数据,也是它读取数据逻辑。

代码部分

# 导入包
import evaluate
import numpy as np
from datasets import DatasetDict
from transformers import AutoTokenizer, AutoModelForMultipleChoice, TrainingArguments, Trainer

# 读取数据集
c3 = DatasetDict.load_from_disk("./c3/")

# 打印训练的前条数据开下接口
examples=c3["train"][:5]
 print(examples)
 
# 将数据处理为训练和测试
c3.pop("test")

# 加载分词器,模型已下载到本地
tokenizer = AutoTokenizer.from_pretrained("chinese-macbert-large")

# 对数据集进行处理,验证数据处理阶段
question_choice = []
labels = []
for idx in range(len(examples["context"])):
    ctx = "\n".join(examples["context"][idx])
    question = examples["question"][idx]
    choices = examples["choice"][idx]
    for choice in choices:
        context.append(ctx)
        question_choice.append(question + " " + choice)
    if len(choices) < 4:
        for _ in range(4 - len(choices)):
            context.append(ctx)
            question_choice.append(question + " " + "不知道")
    print("========:", choices.index(examples["answer"][idx]))
    labels.append(choices.index(examples["answer"][idx]))
tokenized_examples = tokenizer(context, question_choice, truncation="only_first", max_length=256, padding="max_length")     # input_ids: 4000 * 256,
for k, v in tokenized_examples.items():
    print("k:", k, "v:", v)
tokenized_examples = {k: [v[i: i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}     # 1000 * 4 *256
tokenized_examples["labels"] = labels

# 处理数据函数
def process_function(examples):
    # examples, dict, keys: ["context", "quesiton", "choice", "answer"]
    # examples, 1000
    context = []
    question_choice = []
    labels = []
    for idx in range(len(examples["context"])):
        ctx = "\n".join(examples["context"][idx])
        question = examples["question"][idx]
        choices = examples["choice"][idx]
        for choice in choices:
            context.append(ctx)
            question_choice.append(question + " " + choice)
        if len(choices) < 4:
            for _ in range(4 - len(choices)):
                context.append(ctx)
                question_choice.append(question + " " + "不知道")
        labels.append(choices.index(examples["answer"][idx]))
    # 使用分词器,对数据进行分词处理    
    tokenized_examples = tokenizer(context, question_choice, truncation="only_first", max_length=256, padding="max_length")     # input_ids: 4000 * 256,
    tokenized_examples = {k: [v[i: i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}     # 1000 * 4 *256
    tokenized_examples["labels"] = labels
    return tokenized_examples

# 处理数据
tokenized_c3 = c3.map(process_function, batched=True)

# 加载模型
model = AutoModelForMultipleChoice.from_pretrained("chinese-macbert-large")

# 创建评估函数
accuracy = evaluate.load("accuracy")

def compute_metric(pred):
    predictions, labels = pred
    predictions = np.argmax(predictions, axis=-1)
    return accuracy.compute(predictions=predictions, references=labels)
    
# 配置训练参数
args = TrainingArguments(
    output_dir="./muliple_choice",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    logging_steps=50,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    fp16=True
) 

# 创建训练器 
trainer = Trainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    train_dataset=tokenized_c3["train"],
    eval_dataset=tokenized_c3["validation"],
    compute_metrics=compute_metric
)  

# 模型训练
trainer.train()

# 模型预测
from typing import Any
import torch

class MultipleChoicePipeline:

    def __init__(self, model, tokenizer) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.device = model.device

    def preprocess(self, context, quesiton, choices):
        cs, qcs = [], []
        for choice in choices:
            cs.append(context)
            qcs.append(quesiton + " " + choice)
        return tokenizer(cs, qcs, truncation="only_first", max_length=256, return_tensors="pt")

    def predict(self, inputs):
        inputs = {k: v.unsqueeze(0).to(self.device) for k, v in inputs.items()}
        return self.model(**inputs).logits

    def postprocess(self, logits, choices):
        predition = torch.argmax(logits, dim=-1).cpu().item()
        return choices[predition]

    def __call__(self, context, question, choices) -> Any:
        inputs = self.preprocess(context, question, choices)
        logits = self.predict(inputs)
        result = self.postprocess(logits, choices)
        return result

pipe = MultipleChoicePipeline(model, tokenizer)
pipe("小明在北京上班", "小明在哪里上班?", ["北京", "上海", "河北", "海南", "河北", "海南"])

以上就是完整多项选择阅读理解的大模型训练代码

问题

1,这里的难点就是多项选择如何让大模型进行阅读理解训练思想,这里参考的就是语文里的阅读理解。
2,将数据处理成什么样子,大模型才能理解,才能去进行正确的训练。


http://www.kler.cn/a/414317.html

相关文章:

  • 如何为 ext2/ext3/ext4 文件系统的 /dev/centos/root 增加 800G 空间
  • Laravel8.5+微信小程序实现京东商城秒杀方案
  • Maven 依赖项配置
  • kotlin 的循环
  • 【k8s深入理解之 Scheme】全面理解 Scheme 的注册机制、内外部版本、自动转换函数、默认填充函数、Options等机制
  • Android按键点击事件三种实现方法
  • 如何优雅的用PyQt访问http(仅供参考)
  • 架构-微服务-服务治理
  • Wrapper包装类
  • 关于在大模型智能体中知识图谱构建与指令应用
  • Go语言中的sync.Pool详解:高效对象复用
  • Kafka面试题(五)--内含面试重点
  • 深度学习:代码预训练模型
  • Golang项目:实现生产者消费者模式
  • uniapp 安卓和ios震动方法,支持息屏和后台震动,ios和安卓均通过测试
  • 数据结构与算法(排序算法)
  • STM32-C语言基础知识
  • TheadLocal出现的内存泄漏具体泄漏的是什么?弱引用在里面有什么作用?什么情景什么问题?
  • 水母形状电池:助力机器人性能提升
  • 【西瓜书】决策树
  • 网络空间安全之一个WH的超前沿全栈技术深入学习之路(13-2)白帽必经之路——如何用Metasploit 渗透到她的心才不会让我释怀
  • 力扣刷题--42.接雨水【图文详解|超级详细】
  • JAVA项目-------医院挂号系统
  • 鸿蒙征文|鸿蒙技术分享:使用到的开发框架和技术概览
  • 一些k8s和docker的命令
  • javaweb-day03-前端零碎