当前位置: 首页 > article >正文

心法利器[121] | 读源码:用mT5训练一个自动摘要模型(含代码)

心法利器

本栏目主要和大家一起讨论近期自己学习的心得和体会。具体介绍:仓颉专项:飞机大炮我都会,利器心法我还有。

2023年新的文章合集已经发布,获取方式看这里:又添十万字-CS的陋室2023年文章合集来袭,更有历史文章合集,欢迎下载。

往期回顾

T5目前被广泛应用于大量标榜使用“小模型”的文章中,因此最近我也是自己寻找并尝试了有关代码,把这项技术get了起来,现在我尝试通过我的方式来讲一遍,并和大家分享里面里面发现的细节。

代码基本是从这个网站上搬来的:https://xiaosheng.blog/2022/03/29/transformers-note-8,完整项目代码在:https://github.com/jsksxs360/How-to-use-Transformers/tree/main/src/seq2seq_summarization,这里面有很多介绍,我这里按照我的理解展开聊一下。

代码结构

在这里:

|-- arg_config.py
|-- data
|   |-- lcsts_tsv
|   |   |-- data1.tsv
|   |   |-- data2.tsv
|   |   `-- data3.tsv
|   `-- output
|-- data.py
|-- mt5_summary_main.py
|-- run.sh
`-- tools.py

可以看到,这个项目下的代码结构还是比较简单,主要是因为这个摘要项目本身也是比较简单,是一个非常标准的训练模型的项目,那基本就是模型、训练模型、测试到最后的结果的流程。

  • arg_config.py:通过命令行控制的配置文件。

  • data.py:pytorch所需要的数据类,本文用的是LCSTS(http://icrc.hitsz.edu.cn/Article/show/139.html)

  • mt5_summary_main.py:整体训练的流程类。

  • tools.py:工具类,此处就放了个随机数的设置函数。

  • run.sh:执行用的脚本。

然后就开始逐一阅读吧。

基础代码准备

本章先讲训练之外的准备工作。

命令行配置和执行

run.sh是执行用的shell脚本,首先先看这个入口的脚本吧。

export OUTPUT_DIR=./summ_mt5_results/

python3 run_summarization_mt5.py \
    --output_dir=$OUTPUT_DIR \
    --model_type=mT5 \
    --model_checkpoint=csebuetnlp/mT5_multilingual_XLSum \
    --train_file=./data/lcsts_tsv/data1.tsv \
    --dev_file=./data/lcsts_tsv/data2.tsv \
    --test_file=./data/lcsts_tsv/data3.tsv \
    --max_input_length=512 \
    --max_target_length=32 \
    --learning_rate=1e-5 \
    --num_train_epochs=3 \
    --batch_size=32 \
    --beam_search_size=4 \
    --no_repeat_ngram_size=2 \
    --do_train \
    --warmup_proportion=0. \
    --seed=42

这里其实就两行命令,第一句是定义好输出的路径,这里的输出一般是训练后的模型和输出结果,第二句则是执行训练的脚本,可以看到这里面有很多配置项,这些配置项都是通过arg_config.py来定义的。

具体我们来看arg_config.py内部的定义,这里基本把关键配置都弄好了,3种数据集的路径、模型类型、最大输入和输出长度、训练测试预测模式的选择,还有一些必要的训练参数,都是比较完善的,大家甚至可以把这个当做标准的模板。

import argparse

def parse_args():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--output_dir", default=None, type=str, required=True,
        help="The output directory where the model checkpoints and predictions will be written.",
    )
    parser.add_argument("--train_file", default=None, type=str, required=True, help="The input training file.")
    parser.add_argument("--dev_file", default=None, type=str, required=True, help="The input evaluation file.")
    parser.add_argument("--test_file", default=None, type=str, required=True, help="The input testing file.")
    
    parser.add_argument("--model_type",
        default="bert", type=str, required=True
    )
    parser.add_argument("--model_checkpoint",
        default="bert-large-cased/", type=str, required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument("--max_input_length", default=256, type=int, required=True)
    parser.add_argument("--max_target_length", default=256, type=int, required=True)
    
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_test", action="store_true", help="Whether to run eval on the test set.")
    parser.add_argument("--do_predict", action="store_true", help="Whether to save predicted labels.")
    
    # Other parameters
    parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.")
    parser.add_argument("--batch_size", default=4, type=int)
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--beam_search_size", default=4, type=int)
    parser.add_argument("--no_repeat_ngram_size", default=2, type=int)
    
    parser.add_argument("--adam_beta1", default=0.9, type=float,
        help="Epsilon for Adam optimizer."
    )
    parser.add_argument("--adam_beta2", default=0.98, type=float,
        help="Epsilon for Adam optimizer."
    )
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, 
        help="Epsilon for Adam optimizer."
    )
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
        help="Proportion of training to perform linear learning rate warmup for,E.g., 0.1 = 10% of training."
    )
    parser.add_argument("--weight_decay", default=0.01, type=float,
        help="Weight decay if we apply some."
    )
    args = parser.parse_args()
    return args

数据集

此处使用的数据集是LCSTS(http://icrc.hitsz.edu.cn/Article/show/139.html)。

from torch.utils.data import Dataset, DataLoader
import torch

MAX_DATASET_SIZE = 200000

class LCSTS(Dataset):
    # 数据参考:http://icrc.hitsz.edu.cn/Article/show/139.html
    def __init__(self, data_file):
        self.data = self.load_data(data_file)
    
    def load_data(self, data_file):
        Data = {}
        with open(data_file, 'rt', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if idx >= MAX_DATASET_SIZE:
                    break
                items = line.strip().split('!=!')
                assert len(items) == 2
                Data[idx] = {
                    'title': items[0],
                    'content': items[1]
                }
        return Data
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def get_dataLoader(args, dataset, model, tokenizer, batch_size=None, shuffle=False):
    
    def collote_fn(batch_samples):
        batch_inputs, batch_targets = [], []
        for sample in batch_samples:
            batch_inputs.append(sample['content'])
            batch_targets.append(sample['title'])
        batch_data = tokenizer(
            batch_inputs, 
            padding=True, 
            max_length=args.max_input_length,
            truncation=True, 
            return_tensors="pt"
        )
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                batch_targets, 
                padding=True, 
                max_length=args.max_target_length,
                truncation=True, 
                return_tensors="pt"
            )["input_ids"]
            batch_data['decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(labels)
            end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]
            for idx, end_idx in enumerate(end_token_index):
                labels[idx][end_idx+1:] = -100
            batch_data['labels'] = labels
        return batch_data
    
    return DataLoader(dataset, batch_size=(batch_size if batch_size else args.batch_size), shuffle=shuffle, 
                      collate_fn=collote_fn)

记录:

  • 这里是比较常规的Dataset的定义,即基础的加载数据。

  • 后面还有一个get_dataLoader用于构造DataLoader

  • 由于此处的任务是摘要任务,因此label也是一段文本,也需要进行转化,这里使用了tokenizer转化为ids,最终训练的目标应该也是这串内容。

with tokenizer.as_target_tokenizer():
    labels = tokenizer(
        batch_targets, 
        padding=True, 
        max_length=args.max_target_length,
        truncation=True, 
        return_tensors="pt"
    )["input_ids"]
    batch_data['decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(labels)
    end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]
    for idx, end_idx in enumerate(end_token_index):
        labels[idx][end_idx+1:] = -100
    batch_data['labels'] = labels

工具函数

tools.py内是一个设置随机种子的函数,非常适合收藏起来,这点有利于我们做效果的复现。

import random
import os
import numpy as np
import torch

def seed_everything(seed=1029):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # some cudnn methods can be random even after fixing the seed
    # unless you tell it to be deterministic
    torch.backends.cudnn.deterministic = True

核心训练

接下来就是重头戏,模型的训练和推理,此处作者把他们都写在一块了,内容上是比较规范的,这里我们从主流程开始看。下面是主流程的代码。

if __name__ == '__main__':
    args = parse_args() 
    if args.do_train and os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(f'Output directory ({args.output_dir}) already exists and is not empty.')
    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    args.n_gpu = torch.cuda.device_count()
    logger.warning(f'Using {args.device} device, n_gpu: {args.n_gpu}')
    # Set seed
    seed_everything(args.seed)
    # Load pretrained model and tokenizer
    logger.info(f'loading pretrained model and tokenizer of {args.model_type} ...')
    tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_checkpoint).to(args.device)
    # Training
    if args.do_train:
        # Set seed
        seed_everything(args.seed)
        train_dataset = LCSTS(args.train_file)
        dev_dataset = LCSTS(args.dev_file)
        train(args, train_dataset, dev_dataset, model, tokenizer)
    # Testing
    save_weights = [file for file in os.listdir(args.output_dir) if file.endswith('.bin')]
    if args.do_test:
        test_dataset = LCSTS(args.test_file)
        test(args, test_dataset, model, tokenizer, save_weights)
    # Predicting
    if args.do_predict:
        test_dataset = LCSTS(args.test_file)
        for save_weight in save_weights:
            logger.info(f'loading weights from {save_weight}...')
            model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight)))
            logger.info(f'predicting labels of {save_weight}...')

            results = []
            model.eval()
            for s_idx in tqdm(range(len(test_dataset))):
                sample = test_dataset[s_idx]
                pred_summ = predict(args, sample['content'], model, tokenizer)
                results.append({
                    "sentence": sample['content'], 
                    "prediction": pred_summ, 
                    "summarization": sample['title']
                })
            with open(os.path.join(args.output_dir, save_weight + '_test_data_pred.json'), 'wt', encoding='utf-8') as f:
                for exapmle_result in results:
                    f.write(json.dumps(exapmle_result, ensure_ascii=False) + '\n')

代码看着很长,但逐步看下来就会很好理解,接下来是分解动作。

基础参数准备

前面几步是比较基础的基础配置的加载和一些必要参数的初始化。

  • 脚本配置加载。

  • 输出路径的初始化。

  • GPU配置。

  • 随机数配置。

args = parse_args() 
if args.do_train and os.path.exists(args.output_dir) and os.listdir(args.output_dir):
    raise ValueError(f'Output directory ({args.output_dir}) already exists and is not empty.')
if not os.path.exists(args.output_dir):
    os.mkdir(args.output_dir)
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.n_gpu = torch.cuda.device_count()
logger.warning(f'Using {args.device} device, n_gpu: {args.n_gpu}')
# Set seed
seed_everything(args.seed)

紧跟着的是模型和tokenizer的加载。注意,此处没有再单独自定义模型了,而是使用的AutoModelForSeq2SeqLM便可直接加载。

# Load pretrained model and tokenizer
logger.info(f'loading pretrained model and tokenizer of {args.model_type} ...')
tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_checkpoint).to(args.device)

然后就是分成3个模式各自的工作了,训练、测试和预测。

训练

首先是训练,简单地,训练就是加载数据然后再训练。

# Training
if args.do_train:
    # Set seed
    seed_everything(args.seed)
    train_dataset = LCSTS(args.train_file)
    dev_dataset = LCSTS(args.dev_file)
    train(args, train_dataset, dev_dataset, model, tokenizer)

这里核心就是这个train函数了。

def train(args, train_dataset, dev_dataset, model, tokenizer):
    """ Train the model """
    train_dataloader = get_dataLoader(args, train_dataset, model, tokenizer, shuffle=True)
    dev_dataloader = get_dataLoader(args, dev_dataset, model, tokenizer, shuffle=False)
    t_total = len(train_dataloader) * args.num_train_epochs
    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
    ]
    args.warmup_steps = int(t_total * args.warmup_proportion)
    optimizer = AdamW(
        optimizer_grouped_parameters, 
        lr=args.learning_rate, 
        betas=(args.adam_beta1, args.adam_beta2), 
        eps=args.adam_epsilon
    )
    lr_scheduler = get_scheduler(
        'linear',
        optimizer, 
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total
    )
    # Train!
    logger.info("***** Running training *****")
    logger.info(f"Num examples - {len(train_dataset)}")
    logger.info(f"Num Epochs - {args.num_train_epochs}")
    logger.info(f"Total optimization steps - {t_total}")
    with open(os.path.join(args.output_dir, 'args.txt'), 'wt') as f:
        f.write(str(args))
    
    total_loss = 0.
    best_avg_rouge = 0.
    for epoch in range(args.num_train_epochs):
        print(f"Epoch {epoch+1}/{args.num_train_epochs}\n" + 30 * "-")
        total_loss = train_loop(args, train_dataloader, model, optimizer, lr_scheduler, epoch, total_loss)
        dev_rouges = test_loop(args, dev_dataloader, model, tokenizer)
        logger.info(f"Dev Rouge1: {dev_rouges['rouge-1']:>0.2f} Rouge2: {dev_rouges['rouge-2']:>0.2f} RougeL: {dev_rouges['rouge-l']:>0.2f}")
        rouge_avg = dev_rouges['avg']
        if rouge_avg > best_avg_rouge:
            best_avg_rouge = rouge_avg
            logger.info(f'saving new weights to {args.output_dir}...\n')
            save_weight = f'epoch_{epoch+1}_dev_rouge_avg_{rouge_avg:0.4f}_weights.bin'
            torch.save(model.state_dict(), os.path.join(args.output_dir, save_weight))
    logger.info("Done!")

我依次列举一下这里的操作。

  • dataloader初始化。

  • 训练参数初始化,包括学习率参数、warmup和衰减策略、优化方法。

  • 开始训练,依照epoch数量开始循环,这里的train_loop是step级的训练,然后是跑验证集的rouge(摘要指标),并记录最优结果。

train_loop的代码如下:

def train_loop(args, dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
    progress_bar = tqdm(range(len(dataloader)))
    progress_bar.set_description(f'loss: {0:>7f}')
    finish_batch_num = epoch * len(dataloader)
    
    model.train()
    for batch, batch_data in enumerate(dataloader, start=1):
        batch_data = batch_data.to(args.device)
        outputs = model(**batch_data)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        total_loss += loss.item()
        progress_bar.set_description(f'loss: {total_loss/(finish_batch_num + batch):>7f}')
        progress_bar.update(1)
    return total_loss

非常常规的模型反向传播的流程,经典的4段:

optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_scheduler.step()

看了train_loop,顺带看看那test_loop吧,这一步在train内也有用到。

def test_loop(args, dataloader, model, tokenizer):
    preds, labels = [], []
    rouge = Rouge()

    model.eval()
    with torch.no_grad():
        for batch_data in tqdm(dataloader):
            batch_data = batch_data.to(args.device)
            generated_tokens = model.generate(
                batch_data["input_ids"],
                attention_mask=batch_data["attention_mask"],
                max_length=args.max_target_length,
                num_beams=args.beam_search_size,
                no_repeat_ngram_size=args.no_repeat_ngram_size,
            ).cpu().numpy()
            if isinstance(generated_tokens, tuple):
                generated_tokens = generated_tokens[0]
            label_tokens = batch_data["labels"].cpu().numpy()

            decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
            decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)

            preds += [' '.join(pred.strip()) for pred in decoded_preds]
            labels += [' '.join(label.strip()) for label in decoded_labels]
    scores = rouge.get_scores(hyps=preds, refs=labels, avg=True)
    result = {key: value['f'] * 100 for key, value in scores.items()}
    result['avg'] = np.mean(list(result.values()))
    return result

test_loop主要就是推理,并且比对预测结果和实际结果的差距。有两个细节:

  • 推理是用的model.generate,而不是训练中的model(**batch_data),这个和大模型的推理是类似的。

  • 然后是需要转化为rouge所需的格式,Rouge这个包对输出结果是有对比要求的。

测试

测试这块也是类似的逻辑,定义好数据集后,就可以开始训练了。这里的权重加载用的是一段很优雅的单行读取。

# Testing
save_weights = [file for file in os.listdir(args.output_dir) if file.endswith('.bin')]
if args.do_test:
    test_dataset = LCSTS(args.test_file)
    test(args, test_dataset, model, tokenizer, save_weights)

测试内部的逻辑就简单多了,基本就是加载后,直接跑前面提到的test_loop就好了。

def test(args, test_dataset, model, tokenizer, save_weights:list):
    test_dataloader = get_dataLoader(args, test_dataset, model, tokenizer, shuffle=False)
    logger.info('***** Running testing *****')
    for save_weight in save_weights:
        logger.info(f'loading weights from {save_weight}...')
        model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight)))
        test_rouges = test_loop(args, test_dataloader, model, tokenizer)
        logger.info(f"Test Rouge1: {test_rouges['rouge-1']:>0.2f} Rouge2: {test_rouges['rouge-2']:>0.2f} RougeL: {test_rouges['rouge-l']:>0.2f}")

推理

推理相比测试会有些不同,测试重在最终指标的展示,而推理则是要把结果跑出来,然后逐个记录下来。

# Predicting
if args.do_predict:
    test_dataset = LCSTS(args.test_file)
    for save_weight in save_weights:
        logger.info(f'loading weights from {save_weight}...')
        model.load_state_dict(torch.load(os.path.join(args.output_dir, save_weight)))
        logger.info(f'predicting labels of {save_weight}...')

        results = []
        model.eval()
        for s_idx in tqdm(range(len(test_dataset))):
            sample = test_dataset[s_idx]
            pred_summ = predict(args, sample['content'], model, tokenizer)
            results.append({
                "sentence": sample['content'], 
                "prediction": pred_summ, 
                "summarization": sample['title']
            })
        with open(os.path.join(args.output_dir, save_weight + '_test_data_pred.json'), 'wt', encoding='utf-8') as f:
            for exapmle_result in results:
                f.write(json.dumps(exapmle_result, ensure_ascii=False) + '\n')

首先,作者这里是把所有保存好的模型文件都加载出来用来推理(仔细看前面的训练代码会知道模型每个epoch作者都会有检验和保存)

if rouge_avg > best_avg_rouge:
    best_avg_rouge = rouge_avg
    logger.info(f'saving new weights to {args.output_dir}...\n')
    save_weight = f'epoch_{epoch+1}_dev_rouge_avg_{rouge_avg:0.4f}_weights.bin'
    torch.save(model.state_dict(), os.path.join(args.output_dir, save_weight))

加载后走的是predict做预测:

def predict(args, document:str, model, tokenizer):
    inputs = tokenizer(
        document, 
        max_length=args.max_input_length, 
        truncation=True, 
        return_tensors="pt"
    )
    inputs = inputs.to(args.device)
    with torch.no_grad():
        generated_tokens = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=args.max_target_length,
            num_beams=args.beam_search_size,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
        ).cpu().numpy()
    if isinstance(generated_tokens, tuple):
        generated_tokens = generated_tokens[0]
    decoded_preds = tokenizer.decode(
        generated_tokens[0], 
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )
    return decoded_preds

这里,model.generate生成的是token序列,然后再decode解码。

至此,整个流程就结束。

小结

本文给大家展示的是一个比较完整的摘要任务训练代码,也是为了完善自己对摘要任务训练的理解,可以感受到整个项目的流程还是比较标准的pytorch模型训练流程,比较标志性的dataset/dataloader模块,transformers的模型和tokenizer加载,训练过程的经典4段等,这里的train/test/predict三函数模式也是如此。

标准的格式也给了我们很大的改动空间,后续我会做一个小改动,敬请期待。

f44d165aaf2e8e4898cdaaa6680c6e0c.png


http://www.kler.cn/a/378317.html

相关文章:

  • HarmonyOS 移动应用开发
  • 信息学科平台设计与实现:Spring Boot技术详解
  • 算法实现 - 快速排序(Quick Sort) - 理解版
  • QT——记事本项目
  • 哈尔滨二级等保Oracle数据库默认用户名与密码解析
  • 【双目视觉标定】——3面结构光相机标定实践(获取相机内参)~未完待续
  • 计算机毕业设计Python+大模型新闻自动分类 新闻舆情预测 新闻语料情感分析 新闻推荐系统 朴素贝叶斯分类算法 机器学习 深度学习
  • 【多模态读论文系列】LLaVA论文笔记
  • list与iterator的之间的区别,如何用斐波那契数列探索yield
  • Java后端面试内容总结
  • fetch 与 xmlHttpRequest 请求总结
  • IT运维的365天--018 如何在内网布置一个和外网同域名的网站,并开启SSL(https访问),即外网证书如何在内网使用
  • 【机器学习】回归树
  • 【大语言模型】ACL2024论文-06 探索思维链COT在多模态隐喻检测中的应用
  • Logback 常用配置详解
  • 第十九章 Vue组件之data函数
  • Python Matplotlib 如何处理大数据集的绘制,提高绘图效率
  • lc 73 矩阵置0 ACM模式
  • webpack5
  • 【RK3588 Linux 5.x 内核编程】-设备驱动中的sysfs
  • 【架构艺术】服务架构稳定性的基础保障
  • 嵌入式开发之刷新流
  • SAO-LSSVM分类预测 | SAO-LSSVM雪消融算法优化最小二乘支持向量机多特征分类预测
  • JavaScript 进阶 - 第4天 (黑马笔记)
  • [JAVAEE] 面试题(二) - CAS 和 原子类
  • Java项目实战II基于Spring Boot的秒杀系统设计与实现(开发文档+数据库+源码)