PyTorch Lightning Trainer介绍
PyTorch Lightning 的 Trainer
是框架的核心类,负责自动化训练流程、分布式训练、日志记录、模型保存等复杂操作。通过配置参数即可快速实现高效训练,无需手动编写循环代码。以下是详细介绍和使用示例:
Trainer
的核心功能
-
自动化训练循环
自动处理training_step
、validation_step
、test_step
的调用,无需手动编写for epoch in epochs
循环。 -
硬件加速支持
支持 CPU/GPU/TPU、多卡训练(DDP、DeepSpeed)、混合精度训练等。 -
训练控制
控制训练轮数 (max_epochs
)、批次大小 (batch_size
)、梯度裁剪 (gradient_clip_val
) 等。 -
日志与监控
集成 TensorBoard、W&B、MLFlow 等日志工具,监控损失、准确率等指标。 -
回调机制
通过回调函数(如EarlyStopping
,ModelCheckpoint
)实现早停、模型保存等扩展功能。
Trainer
的常用参数
from pytorch_lightning import Trainer
trainer = Trainer(
# 基础配置
max_epochs=10, # 最大训练轮数
accelerator="auto", # 自动选择设备 (CPU/GPU/TPU)
devices="auto", # 使用所有可用设备(如多 GPU)
precision="16-mixed", # 混合精度训练(FP16)
# 日志与调试
logger=True, # 默认使用 TensorBoard
log_every_n_steps=10, # 每 10 个批次记录一次日志
fast_dev_run=False, # 快速运行一个批次(调试模式)
# 回调函数
callbacks=[
pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),
pl.callbacks.ModelCheckpoint(monitor="val_loss", save_top_k=2)
],
# 分布式训练
strategy="ddp", # 分布式数据并行策略(多 GPU)
num_nodes=1, # 节点数量(多机器训练)
)
使用示例代码
步骤 1:定义 LightningModule
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
class LitModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.layer1 = nn.Linear(28*28, 128)
self.layer2 = nn.Linear(128, 10)
def forward(self, x):
x = x.view(x.size(0), -1) # 展平输入
x = F.relu(self.layer1(x))
x = self.layer2(x)
return x
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log("train_loss", loss) # 自动记录日志
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log("val_loss", loss) # 自动记录验证损失
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.001)
步骤 2:定义 DataModule
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, batch_size=32):
super().__init__()
self.batch_size = batch_size
def prepare_data(self):
MNIST(root="data", download=True)
def setup(self, stage=None):
full_dataset = MNIST(root="data", train=True, transform=ToTensor())
self.train_data, self.val_data = random_split(full_dataset, [55000, 5000])
def train_dataloader(self):
return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True)
def val_dataloader(self):
return DataLoader(self.val_data, batch_size=self.batch_size)
dm = MNISTDataModule(batch_size=32)
步骤 3:启动训练
model = LitModel()
trainer = Trainer(
max_epochs=10,
accelerator="auto",
devices="auto",
logger=True,
callbacks=[
pl.callbacks.ModelCheckpoint(monitor="val_loss")
]
)
# 开始训练与验证
trainer.fit(model, datamodule=dm)
# 测试(可选)
trainer.test(model, datamodule=dm)
关键功能演示
1. 多 GPU 训练
# 使用 4 个 GPU 训练
trainer = Trainer(devices=4, strategy="ddp")
2. 混合精度训练
# 使用 FP16 混合精度
trainer = Trainer(precision="16-mixed")
3. 早停与模型保存
callbacks = [
pl.callbacks.EarlyStopping(monitor="val_loss", patience=3),
pl.callbacks.ModelCheckpoint(
dirpath="checkpoints/",
filename="best-model-{epoch:02d}-{val_loss:.2f}",
save_top_k=2,
monitor="val_loss"
)
]
trainer = Trainer(callbacks=callbacks)
4. 调试模式
# 快速验证代码正确性(仅运行一个批次)
trainer = Trainer(fast_dev_run=True)
常见问题
如何恢复训练?
使用 resume_from_checkpoint
参数:
trainer = Trainer(resume_from_checkpoint="path/to/checkpoint.ckpt")
如何限制训练时间?
trainer = Trainer(max_time="00:02:00") # 最多训练 2 分钟
如何自定义学习率调度器?
在 自定义的 LightningDataModule继承类的 configure_optimizers
方法中返回优化器和调度器:
def configure_optimizers(self):
optimizer = Adam(self.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [scheduler]
总结
通过 Trainer
,PyTorch Lightning 将训练流程的复杂性封装在几行配置中,开发者只需关注模型逻辑和数据加载。其灵活的参数和回调机制能够覆盖从实验到生产的全流程需求。
参考:
https://lightning.ai/docs/pytorch/stable/common/trainer.html