9.pytorch lightning之数据模块LightningDataModule
使用LightningDataModule
LightningDataModule是一个可重用的类,内部封闭处理数据所需的步骤。
其中封装了pytorch处理数据的5个步骤:
- 准备数据:一般是数据集下载过程,如果是本地数据则不需要。
- 数据清理。
- 加载Dataset对象。
- 预处理:旋转,翻转,归一化等
- 将train_loader、val_loader、test_loader封装在一起。
使用LightningDataModule的好处是可以将数据处理整合在一起,并且很容易利用,创建后可以将其用于任何地方。
在原生pytorch代码中,针对train/val/test的不同阶段通常先创建多个dataset和dataloader,而在lightning 则将这些封装在一起,只需在fit()是传入单个ightningDataModule对象。
# regular PyTorch
test_data = MNIST(my_path, train=False, download=True)
predict_data = MNIST(my_path, train=False, download=True)
train_data = MNIST(my_path, train=True, download=True)
train_data, val_data = random_split(train_data, [55000, 5000])
train_loader = DataLoader(train_data, batch_size=32)
val_loader = DataLoader(val_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)
predict_loader = DataLoader(predict_data, batch_size=32)
# lightning 风格
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir: str = "./"):
super().__init__()
self.data_dir = data_dir
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# 下载数据
def prepare_data(self):
# download
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
'''
setup 方法创建Dataset对象,对不同数据集指定预处理方法
'''
def setup(self, stage: str):
# Assign train/val datasets for use in dataloaders
if stage == "fit":
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
# Assign test dataset for use in dataloader(s)
if stage == "test":
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
if stage == "predict":
self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
# 以下方法创建不同阶段的数据加载器
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=32)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)
def predict_dataloader(self):
return DataLoader(self.mnist_predict, batch_size=32)
#
mnist = MNISTDataModule(my_path)
model = LitClassifier()
trainer = Trainer()
# fit 训练阶段给定model 和data。
trainer.fit(model, mnist)
另外,val_dataloader,train_dataloader等方法也可以实现在LightningModule子类中,fit时只给定相应的model实例。
同时加载多个数据集
这在半监督场景尤其有用,只需先将多个loader 用list,dict对象包装在一起,再封装在CombinedLoader中,注意CombinedLoader要求版本>2.0