当前位置: 首页 > article >正文

基于Python的自然语言处理系列(36):使用PyTorch微调(无需Trainer)

        本篇文章将展示如何通过PyTorch手动实现预训练模型的微调,而不依赖Trainer类。我们将以微软研究院的MRPC(Microsoft Research Paraphrase Corpus)数据集为例。MRPC包含句对和标签,用于判断两句是否为语义相同的重述,非常适合用于探索文本分类任务的微调流程。我们会从数据预处理到训练与评估,逐步演示整个过程,最后介绍如何通过Accelerate库提升多设备上的训练效率。

1. 加载数据集和预处理

        首先加载MRPC数据集,并通过BERT的Tokenizer进行分词。我们使用glue任务中的MRPC数据集,并配置BERT的预训练检查点。

from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding

# 加载数据集和BERT tokenizer
raw_datasets = load_dataset("glue", "mrpc")
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# 定义分词函数
def tokenize_function(example):
    return tokenizer(example["sentence1"], example["sentence2"], truncation=True)

# 应用分词并生成数据集
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

数据集后处理

        我们需要移除不必要的列,并将标签列重命名为模型期望的labels

tokenized_datasets = tokenized_datasets.remove_columns(["sentence1", "sentence2", "idx"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")

数据加载器

        我们创建训练和验证数据的加载器:

from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    tokenized_datasets["train"], shuffle=True, batch_size=8, collate_fn=data_collator
)
eval_dataloader = DataLoader(
    tokenized_datasets["validation"], batch_size=8, collate_fn=data_collator
)

2. 定义模型和优化器

        我们使用BERT模型,并通过AdamW优化器和线性学习率调度器配置训练过程:

from transformers import AutoModelForSequenceClassification, get_scheduler
from torch.optim import AdamW

model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
optimizer = AdamW(model.parameters(), lr=5e-5)

num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

使用GPU(如果可用)

import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

3. 训练循环

        在训练过程中,我们遍历所有训练数据,并在每次反向传播后更新模型参数。

from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))
model.train()

for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

4. 评估循环

        我们使用Evaluate库对验证数据进行评估。

import evaluate

metric = evaluate.load("glue", "mrpc")
model.eval()

for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

print(metric.compute())

5. 使用Accelerate加速分布式训练

        我们可以通过Accelerate库轻松启用多GPU或TPU上的训练。

from accelerate import Accelerator

accelerator = Accelerator()
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
optimizer = AdamW(model.parameters(), lr=3e-5)

train_dl, eval_dl, model, optimizer = accelerator.prepare(
    train_dataloader, eval_dataloader, model, optimizer
)

progress_bar = tqdm(range(num_training_steps))
model.train()

for epoch in range(num_epochs):
    for batch in train_dl:
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

结语

        在这篇文章中,我们深入探讨了如何通过PyTorch手动实现模型的微调,而不依赖Trainer类。我们从数据的加载和预处理开始,一步步实现了训练和评估循环,并展示了如何使用Accelerate库加速多设备上的分布式训练。这种细粒度的控制方式使我们能够更灵活地进行模型调优,适应不同的项目需求。

        在接下来的博文中,我们将聚焦于Datasets、Preprocessing 和 Streaming,探索如何有效管理和处理大规模数据集,并介绍流数据(Streaming)处理的最佳实践。这些技巧对于优化训练效率和数据加载速度至关重要,敬请期待!

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!


http://www.kler.cn/news/354268.html

相关文章:

  • 设计模式详解(命令模式)
  • GPT提示词
  • 限制游客在wordpress某分类下阅读文章的数量
  • 【Redis_Day1】分布式系统和Redis
  • LeetCode刷题日记之贪心算法(一)
  • Unity3D URP画面品质的上限如何详解
  • 【HarmonyOS NEXT】服务端向终端推送消息——获取Push Token
  • 详细指南:如何使用WildCard升级到ChatGPT 4.0
  • 【React】使用脚手架或Vite包两种方式创建react项目
  • 基于NXP LS1023+FPGA的嵌入式解决方案
  • 计算机视觉算法的演进与应用:从基础理论到前沿技术
  • 服务器和中转机协同工作以提高网络安全
  • 一站式讲解Wireshark网络抓包分析的若干场景、过滤条件及分析方法
  • Vue.js 组件开发全攻略:从基础到高级特性详解
  • 性能测试工具JMeter(二)
  • 《工业领域缺陷检测方案:创新与应用》
  • C/C++ 内存分布与管理:简单易懂的入门指南
  • hive 误删表恢复
  • 前端一键复制解决方案分享
  • Qt中的连接类型