lightning.pytorch.callbacks内置的Callbacks介绍
PyTorch Lightning 提供了一些 内置回调 (Callback
),可以在训练过程中自动执行 检查点保存、学习率调度、早停 等功能。通过使用 Trainer(callbacks=[...])
来传入这些回调。
PyTorch Lightning 的 Callback 是一种强大的工具,允许用户在训练过程中插入自定义逻辑,而无需修改核心的训练代码。Callback 的设计基于钩子(hooks),在训练流程的特定点执行自定义代码。
1. 内置回调列表
PyTorch Lightning 内置了以下 Callbacks
(可以直接使用,无需自定义):
回调名称 | 功能 |
---|---|
ModelCheckpoint |
自动保存最佳模型(基于验证指标) |
EarlyStopping |
自动停止训练(当验证指标不再改善时) |
LearningRateMonitor |
记录学习率变化(支持 TensorBoard & WandB ) |
RichProgressBar |
使用 rich 库美化训练进度条 |
TQDMProgressBar |
默认的 tqdm 进度条 |
DeviceStatsMonitor |
监控 GPU/CPU 使用情况 |
BatchSizeFinder |
自动寻找最优 batch size |
GradientAccumulationScheduler |
自动调整梯度累积步数 |
ModelSummary |
打印模型结构和参数量 |
StochasticWeightAveraging (SWA) |
使用 SWA 进行权重平均,提高泛化能力 |
2. 详细介绍 & 代码示例
(1)ModelCheckpoint
- 自动保存最佳模型
用于 自动保存模型检查点(ckpt),可以: ✅ 保存最佳模型(基于某个指标,如 val_loss
)。
✅ 定期保存(如每 n
个 epoch 保存一次)。
✅ 限制最大检查点数量(避免磁盘占用过大)。
📜 示例:
from lightning.pytorch.callbacks import ModelCheckpoint
# 仅保存最优模型 (基于 val_loss)
check