当前位置: 首页 > article >正文

PyTorch Lightning Trainer介绍

PyTorch Lightning 的 Trainer 是框架的核心类,负责自动化训练流程、分布式训练、日志记录、模型保存等复杂操作。通过配置参数即可快速实现高效训练,无需手动编写循环代码。以下是详细介绍和使用示例:

Trainer 的核心功能

  1. 自动化训练循环
    自动处理 training_stepvalidation_steptest_step 的调用,无需手动编写 for epoch in epochs 循环。

  2. 硬件加速支持
    支持 CPU/GPU/TPU、多卡训练(DDP、DeepSpeed)、混合精度训练等。

  3. 训练控制
    控制训练轮数 (max_epochs)、批次大小 (batch_size)、梯度裁剪 (gradient_clip_val) 等。

  4. 日志与监控
    集成 TensorBoard、W&B、MLFlow 等日志工具,监控损失、准确率等指标。

  5. 回调机制
    通过回调函数(如 EarlyStoppingModelCheckpoint)实现早停、模型保存等扩展功能。

Trainer 的常用参数

from pytorch_lightning import Trainer

trainer = Trainer(
    # 基础配置
    max_epochs=10,            # 最大训练轮数
    accelerator="auto",       # 自动选择设备 (CPU/GPU/TPU)
    devices="auto",           # 使用所有可用设备(如多 GPU)
    precision="16-mixed",     # 混合精度训练(FP16)

    # 日志与调试
    logger=True,              # 默认使用 TensorBoard
    log_every_n_steps=10,     # 每 10 个批次记录一次日志
    fast_dev_run=False,       # 快速运行一个批次(调试模式)

    # 回调函数
    callbacks=[
        pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),
        pl.callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=2)
    ],

    # 分布式训练
    strategy="ddp",           # 分布式数据并行策略(多 GPU)
    num_nodes=1,              # 节点数量(多机器训练)
)

使用示例代码

步骤 1:定义 LightningModule
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(28*28, 128)
        self.layer2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # 展平输入
        x = F.relu(self.layer1(x))
        x = self.layer2(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss)  # 自动记录日志
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", loss)     # 自动记录验证损失

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)
步骤 2:定义 DataModule
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        MNIST(root="data", download=True)

    def setup(self, stage=None):
        full_dataset = MNIST(root="data", train=True, transform=ToTensor())
        self.train_data, self.val_data = random_split(full_dataset, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_data, batch_size=self.batch_size)

dm = MNISTDataModule(batch_size=32)

步骤 3:启动训练

model = LitModel()
trainer = Trainer(
    max_epochs=10,
    accelerator="auto",
    devices="auto",
    logger=True,
    callbacks=[
        pl.callbacks.ModelCheckpoint(monitor="val_loss")
    ]
)

# 开始训练与验证
trainer.fit(model, datamodule=dm)

# 测试(可选)
trainer.test(model, datamodule=dm)

关键功能演示

1. 多 GPU 训练
# 使用 4 个 GPU 训练
trainer = Trainer(devices=4, strategy="ddp")
2. 混合精度训练

# 使用 FP16 混合精度
trainer = Trainer(precision="16-mixed")
3. 早停与模型保存
callbacks = [
    pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),
    pl.callbacks.ModelCheckpoint(
        dirpath="checkpoints/",
        filename="best-model-{epoch:02d}-{val_loss:.2f}",
        save_top_k=2,
        monitor="val_loss"
    )
]
trainer = Trainer(callbacks=callbacks)
4. 调试模式
# 快速验证代码正确性(仅运行一个批次)
trainer = Trainer(fast_dev_run=True)

常见问题

如何恢复训练?
使用 resume_from_checkpoint 参数:

trainer = Trainer(resume_from_checkpoint="path/to/checkpoint.ckpt")

如何限制训练时间?

trainer = Trainer(max_time="00:02:00")  # 最多训练 2 分钟

如何自定义学习率调度器?
在 自定义的 LightningDataModule继承类的 configure_optimizers 方法中返回优化器和调度器:

def configure_optimizers(self):
    optimizer = Adam(self.parameters())
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
    return [optimizer], [scheduler]

总结

通过 Trainer,PyTorch Lightning 将训练流程的复杂性封装在几行配置中,开发者只需关注模型逻辑和数据加载。其灵活的参数和回调机制能够覆盖从实验到生产的全流程需求。

参考:

https://lightning.ai/docs/pytorch/stable/common/trainer.html


http://www.kler.cn/a/543220.html

相关文章:

  • 使用 Go 语言调用 SiliconFlow 语音生成 API 的脚本,用于将文本转换为 MP3 格式的语音文件。
  • 使用mermaid画流程图
  • (一)Axure制作移动端登录页面
  • Flutter_学习记录_基本组件的使用记录_2
  • Unity3D实现显示模型线框(shader)
  • STM32-知识
  • Spring 核心技术解析【纯干货版】- XII:Spring 数据访问模块 Spring-R2dbc 模块精讲
  • 如何在WinForms应用程序中读取和写入App.config文件
  • 记忆模块概述
  • 用AI做算法题1
  • 深度学习-111-大语言模型LLM之基于langchain的结构化输出功能实现文本分类
  • 网络工程师 (33)VLAN注册协议——GVRP协议
  • linux 内核结构基础
  • MFC程序设计(十二)绘图
  • 建筑兔零基础自学python记录18|实战人脸识别项目——视频检测07
  • EPL 4.01 Preview
  • 【Elasticsearch】文本分析Text analysis概述
  • 【鸿蒙开发】第二十九章 Stage模型-应用上下文Context、进程、线程
  • Unity 代码优化记录
  • 【c++】shared_ptr是线程安全的吗?
  • fun-transformer学习笔记-Task1——Transformer、Seq2Seq、Encoder-Decoder、Attention之间的关系
  • vivo手机和Windows电脑连接同一个WiFi即可投屏!
  • Spring Cloud 完整引解:优化你的微服务架构
  • GEE批量打开asset权限(anyone can read)
  • YOLOv11融合[AAAI2025]的Mesorch 模型中的高、低频特征提取模块
  • kafka在初始化集群配置当中有哪些重要参数?