当前位置: 首页 > article >正文

9.pytorch lightning之数据模块LightningDataModule

使用LightningDataModule

LightningDataModule是一个可重用的类,内部封闭处理数据所需的步骤。

其中封装了pytorch处理数据的5个步骤:

  1. 准备数据:一般是数据集下载过程,如果是本地数据则不需要。
  2. 数据清理。
  3. 加载Dataset对象。
  4. 预处理:旋转,翻转,归一化等
  5. 将train_loader、val_loader、test_loader封装在一起。

使用LightningDataModule的好处是可以将数据处理整合在一起,并且很容易利用,创建后可以将其用于任何地方。

在原生pytorch代码中,针对train/val/test的不同阶段通常先创建多个dataset和dataloader,而在lightning 则将这些封装在一起,只需在fit()是传入单个ightningDataModule对象。

# regular PyTorch
test_data = MNIST(my_path, train=False, download=True)
predict_data = MNIST(my_path, train=False, download=True)
train_data = MNIST(my_path, train=True, download=True)
train_data, val_data = random_split(train_data, [55000, 5000])

train_loader = DataLoader(train_data, batch_size=32)
val_loader = DataLoader(val_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)
predict_loader = DataLoader(predict_data, batch_size=32)
# lightning 风格
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    # 下载数据
    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)
    '''
    setup 方法创建Dataset对象,对不同数据集指定预处理方法
    '''
    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

        if stage == "predict":
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

    # 以下方法创建不同阶段的数据加载器
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)
    # 
    mnist = MNISTDataModule(my_path)
    model = LitClassifier()

    trainer = Trainer()
    # fit 训练阶段给定model 和data。
    trainer.fit(model, mnist)

另外,val_dataloader,train_dataloader等方法也可以实现在LightningModule子类中,fit时只给定相应的model实例。

同时加载多个数据集

这在半监督场景尤其有用,只需先将多个loader 用list,dict对象包装在一起,再封装在CombinedLoader中,注意CombinedLoader要求版本>2.0


http://www.kler.cn/a/15682.html

相关文章:

  • uniapp自动注册机制:easycom
  • 【网络安全】XSS注入
  • ASP.NET 部署到IIS,访问其它服务器的共享文件 密码设定
  • 前端三大组件之CSS,三大选择器,游戏网页仿写
  • LeetCode --- 143周赛
  • 023、ELK 从入门到实践
  • mysql笔记
  • python实战应用讲解-【numpy数组篇】常用函数(六)(附python示例代码)
  • HPC的资料
  • 【Docker】Dockerfile简介
  • 被修饰成单栋的倾斜摄影处理思路
  • 有理函数的不定积分
  • 《港联证券》半导体复苏预期“抢跑”产业现实 细分市场缓慢回温
  • ETL工具 - Kettle 转换算子介绍
  • Linux进程间通信 - 信号(signal) 与 管道(pipe) 与 消息队列
  • 【VM服务管家】VM4.0软件使用_1.1 环境配置类
  • SpringBoot 中 4 种常用的数据库访问方式
  • Microsoft Bitlocker企业级管理部署方案
  • 在京东工作8年的程序员,35岁被裁拿到30多万的赔偿,终于自由了
  • 2023天梯赛补题
  • 华为OD机试 - 模拟商场优惠打折(Python)
  • 回溯算法秒杀2
  • c++算法——vector
  • Apache Hudi初探(二)(与spark的结合)
  • 用 AudioGPT 输入自然语言,可以让 ChatGPT 唱歌了?
  • 借助尾号限行 API 实现限行规则应用的设计思路分析