当前位置: 首页 > article >正文

PyTorch Lightning Callback介绍

PyTorch Lightning Callback 介绍

在 PyTorch 中,callbacks(回调函数)不是原生支持的核心功能,但在深度学习中非常常见,尤其是用来监控训练过程、调整超参数或执行特定的任务。许多高级深度学习框架(如 PyTorch Lightning 和 FastAI)都基于 PyTorch,并内置了 callback 支持。

PyTorch Lightning 提供了一个易于扩展的回调机制,允许用户在训练过程中插入自定义逻辑。回调类继承自 pytorch_lightning.callbacks.Callback,可以覆盖以下方法:

常用方法
  • on_fit_start: 在训练(fit)开始时调用。
  • on_fit_end: 在训练(fit)结束时调用。
  • on_train_epoch_start: 在每个训练 epoch 开始时调用。
  • on_train_epoch_end: 在每个训练 epoch 结束时调用。
  • on_validation_epoch_start: 在每个验证 epoch 开始时调用。
  • on_validation_epoch_end: 在每个验证 epoch 结束时调用。
  • on_test_epoch_start: 在测试 epoch 开始时调用。
  • on_test_epoch_end: 在测试 epoch 结束时调用。
  • on_train_batch_end: 在每个训练 batch 结束时调用。
  • on_validation_batch_end: 在每个验证 batch 结束时调用。
  • on_test_batch_end: 在每个测试 batch 结束时调用。

示例: 自定义 Callback

以下示例实现了一个打印日志的回调:

from pytorch_lightning.callbacks import Callback

class PrintCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch}: Training ended!")

    def on_validation_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch}: Validation ended!")

使用时将回调传递给 Trainer

from pytorch_lightning import Trainer

trainer = Trainer(callbacks=[PrintCallback()])

基于 Hydra 配置实例化 Callback

Hydra 是一个灵活的配置管理工具,常用于深度学习项目中动态管理超参数。通过结合 Hydra 和 PyTorch Lightning,可以动态配置并实例化 Callback。

步骤:

1. 安装 Hydra

pip install hydra-core --upgrade

2. 定义 Hydra 配置文件: 创建一个 YAML 配置文件(如 config.yaml)来管理 Callback 的配置:

callbacks:
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: "val_loss"
    save_top_k: 1
    mode: "min"

  early_stopping:
    _target_: pytorch_lightning.callbacks.EarlyStopping
    monitor: "val_loss"
    patience: 5
    mode: "min"

3. 在代码中动态实例化: 使用 hydra.utils.instantiate 方法实例化回调对象:

import hydra
from hydra.utils import instantiate
from pytorch_lightning import Trainer
from omegaconf import OmegaConf

@hydra.main(config_path=".", config_name="config")
def main(cfg):
    # Instantiate callbacks from config
    callbacks = [instantiate(cfg.callbacks[key]) for key in cfg.callbacks]

    # Example: Define a simple PyTorch Lightning model
    from pytorch_lightning import LightningModule
    import torch.nn.functional as F

    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(10, 1)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self(x)
            loss = F.mse_loss(y_hat, y)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.001)

    # Instantiate trainer
    trainer = Trainer(callbacks=callbacks, max_epochs=10)

    # Simulated data loader
    from torch.utils.data import DataLoader, TensorDataset
    import torch

    x = torch.rand(100, 10)
    y = torch.rand(100, 1)
    train_loader = DataLoader(TensorDataset(x, y), batch_size=32)

    model = SimpleModel()
    trainer.fit(model, train_loader)

if __name__ == "__main__":
    main()
解释:如何通过配置文件动态管理 Callback
  1. 配置文件中,_target_ 指定回调类的完整路径。
  2. 使用 hydra.utils.instantiate 根据配置动态实例化对象。
  3. 将实例化后的回调传递给 Trainer
优势
  1. 动态配置:通过 YAML 文件可以快速更改回调逻辑而无需修改代码。
  2. 模块化管理:方便管理多个回调类,清晰直观。
  3. 灵活性:支持自定义 Callback 和 Lightning 内置回调的结合使用。

此方法适用于多种场景,比如动态调整模型保存路径、早停策略等。


http://www.kler.cn/a/454983.html

相关文章:

  • 如何设置爬虫的User-Agent?
  • java实现网络IO高并发编程java NIO
  • 在uniapp中如何自定义一个图标
  • 【软件工程】十万字知识点梳理 | 期末复习专用
  • docker mysql5.7安装
  • .net core 的软件开发模式
  • 欲海航舟:探寻天性驱动下的欲望演变与人生驾驭
  • ArcGIS土地利用数据制备、分析及基于FLUS模型土地利用预测(数据采集、处理、分析、制图)
  • Python数据可视化小项目
  • 【Redis】Redis 安装与启动
  • Go 计算Utf8字符串的长度 不要超过mysql字段的最大长度
  • springboot502基于WEB的牙科诊所管理系统(论文+源码)_kaic
  • Linux知识点回顾(期末提分篇)
  • 文档大师:打造一站式 Word 报告解决方案1
  • Java实现观察者模式
  • 同步与异步日志系统的深入探讨与应用
  • 箭头函数与普通函数的区别
  • 使用 .NET 6 或 .NET 8 上传大文件
  • 【远程桌面】被窥屏
  • selenium浏览器下载汇总