Pytorch 自学笔记(三):利用自定义文本数据集构建Dataset和DataLoader
Pytorch 自学笔记(三)
- 1. Dataset与DataLoader
- 1.1 torch.utils.data.Dataset
- 1.2 torch.utils.data.DataLoader
Pytorch 自学笔记系列的第三篇。针对Pytorch的Dataset和DataLoader进行简单的介绍,同时,介绍如何使用自定义文本数据集构建Dataset和DataLoader,以实现数据集的随机采样与batch加载。(注:文中代码使用python3.7和pytorch1.7.1编写)
1. Dataset与DataLoader
1.1 torch.utils.data.Dataset
torch.utils.data.Dataset
是pytorch中定义的数据集抽象类,pytorch中任何的数据集类都必须继承并重写这个类,其源码如下:
class Dataset(Generic[T_co]):
r"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
.. note::
:class:`~torch.utils.data.DataLoader` by default constructs a index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
def __getitem__(self, index) -> T_co:
raise NotImplementedError
def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
return ConcatDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py
而任何继承torch.utils.data.Dataset
的数据集类,必须重写__getitem__
方法,可以选择性重写__len__
方法(若要以该数据集类构建torch.utils.data.Sampler
或者torch.utils.data.DataLoader
,则必须重写__len__
方法)。__getitem__
方法的作用为,利用index获得数据集中该index对应的样本(这就要求该数据类中必须维持一个可以;),而__len__
方法的作用为返回数据集的样本数量。一个torch.utils.data.Dataset
子类的样例如下:
from torch.utils.data import Dataset
import pandas as pd
class MyDataset(Dataset):
def __init__(self, csv_file, txt_file, root_dir, other_file):
self.csv_data = pd.read_csv(csv_file)
with open(txt_file, 'r') as f:
data_list = f.readlines()
# 可利用索引下标进行取值的成员变量,list类型
self.txt_data = data_list
self.root_dir = root_dir
# 返回数据集的样本数量
def __len__(self):
return len(self.csv_data)
# 返回数据集中索引为idx的样本
def __getitem__(self, index):
data = (self.csv_data[index], self.txt_data[index])
return data
利用自定义的Dataset子类,可以将我们的数据集定义我们需要的数据类,然后通过迭代的方式利用index下标索引来获取数据集中的每一条样本数据。而数据集的batch取样和取样时的shuffle,则需要利用torch.utils.data.DataLoader
来实现。
1.2 torch.utils.data.DataLoader
首先需要明确一点,Dataset和DataLoader本质上都是iterable(可迭代对象),都可以实现数据集的迭代访问。而 torch.utils.data.DataLoader
相当于是Dataset(数据集)和Sampler(采样器)的组合,即可以在Dataset上进行迭代的自定义采样。同时,DataLoader还支持单进程或多进程加载,自定义加载顺序以及可选的自动批处理(整理)和memory pinning,它还支持 map风格的数据集对象,其参数具体解释如下(参数说明参考了这篇文章,并按照pytorch1.7.1的文档进行了修改):
- dataset(Dataset): 传入的数据集类
- batch_size(int, optional): 每个batch有多少个样本
- shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序(即随机采样)
- sampler(Sampler or Iterable, optional): 自定义从数据集中取样本的策略;如果指定这个参数,那么shuffle必须为False;该值可以为任何实现了
__len__
函数的Iterable对象 - batch_sampler(Sampler or Iterable, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再指定了(互斥——Mutually exclusive)
- num_workers (int, optional):这个参数决定了有几个进程来处理data loading;0意味着所有的数据都会被load进主进程(默认为0)
- collate_fn (callable, optional): 一个函数,该函数的作用是将一个由样本构成的batch_size大小的list转换成mini-batch,该函数的输出即为迭代时获得的batch
- pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中
- drop_last (bool, optional):如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了;如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点
- timeout(numeric, optional):如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了;这个numeric应总是大于等于0;默认为0
- worker_init_fn (callable, optional): 每个进程的初始化函数 If not None, this will be called on eachworker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
- prefetch_factor (int, optional, keyword-only arg):每个进程预先加载的样本数量。该值2意味着所有的进程预先加载了2 * num_workers个样本(默认为2)
- persistent_workers (bool, optional) :如果为True,则迭代完一次数据集后,DataLoader将不会关闭工作进程;这样可以使Worker Dataset实例保持活动状态(默认为False)
利用上一节定义的MyDataset
数据集类可以构建一个DataLoader对象:
from torch.utils.data import DataLoader
my_data_loader = DataLoader(myDataset, batch_size=32, shuffle=True)