pytorch中的基础数据集
- 数据是深度学习核心之一
- pytorch基础数据集介绍
- 加载/读取/显示/使用
- 代码演示与解释
常见的数据集Pascal VOC/COCO
DataLoader
DataLoader( dataset,
- 含义:指定要加载的数据集,它必须是
torch.utils.data.Dataset
类的子类实例。Dataset
类定义了如何获取单个样本以及数据集的长度,DataLoader
会基于这个数据集进行数据的批量加载。
batch_size=1,
batch_size
:- 含义:每个批次加载的样本数量,默认值为 1。在训练神经网络时,通常会将多个样本组成一个批次一起输入到模型中,以提高训练效率和稳定性。
- 示例:
dataloader = DataLoader(dataset, batch_size=10)
shuffle=False,
- 含义:布尔类型参数,若设置为
True
,则在每个 epoch 开始时打乱数据集的顺序,有助于模型更好地学习数据的分布,提高模型的泛化能力,默认值为False
。 - 示例:
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
sampler=None,
sampler
:- 含义:自定义采样器,用于指定如何从数据集中采样样本。如果指定了
sampler
,则shuffle
参数将被忽略。采样器必须是torch.utils.data.Sampler
类的子类实例。
batch_sampler=None,
batch_sampler
:
- 含义:自定义批次采样器,用于指定如何将样本组合成批次。如果指定了
batch_sampler
,则batch_size
、shuffle
、sampler
和drop_last
参数将被忽略。批次采样器必须是torch.utils.data.BatchSampler
类的子类实例。
num_workers=0,
num_workers
:- 含义:用于数据加载的子进程数量,设置为大于 0 的值可以实现并行数据加载,提高数据加载速度。默认值为 0,表示在主进程中加载数据。在使用多进程加载时,需要注意内存和 CPU 资源的使用情况。
- 示例:
dataloader = DataLoader(dataset, batch_size=10, num_workers=4)
collate_fn=None,
collate_fn
:
- 含义:用于将多个样本组合成一个批次的函数。默认情况下,
DataLoader
会使用一个简单的默认函数将样本组合成批次,但在处理一些特殊的数据类型或结构时,可能需要自定义collate_fn
函数。
pin_memory=False, drop_last=False,
pin_memory
:- 含义:布尔类型参数,若设置为
True
,则会将数据样本在返回前固定到 CUDA 页锁定内存中,这样可以加快数据从 CPU 到 GPU 的传输速度,适用于使用 GPU 进行训练的场景,默认值为False
。 - 示例:
dataloader = DataLoader(dataset, batch_size=10, pin_memory=True)
- 含义:布尔类型参数,若设置为
timeout=0, worker_init_fn=None,
timeout
:- 含义:在多进程数据加载时,等待从工作进程获取数据的超时时间(单位:秒)。如果在指定的时间内没有获取到数据,则会抛出
TimeoutError
异常。默认值为 0,表示不设置超时时间。 - 示例:
dataloader = DataLoader(dataset, batch_size=10, num_workers=4, timeout=1)
- 含义:在多进程数据加载时,等待从工作进程获取数据的超时时间(单位:秒)。如果在指定的时间内没有获取到数据,则会抛出
multiprocessing_context=None,
multiprocessing_context
:
- 含义:指定多进程的上下文,可以是
'fork'
、'spawn'
或'forkserver'
等。不同的上下文适用于不同的操作系统和场景,默认情况下会根据操作系统自动选择合适的上下文。
generator=None,
generator
:
- 含义:用于生成随机数的生成器,可用于控制数据的打乱顺序等随机操作。可以传入一个
torch.Generator
实例,方便复现实验结果。
prefetch_factor=2,
prefetch_factor
:
- 含义:每个工作进程预取的样本数量,默认值为 2。在多进程数据加载时,工作进程会预先加载一定数量的样本,以提高数据加载的连续性。
persistent_workers=False )
persistent_workers
:
- 含义:布尔类型参数,若设置为
True
,则在每个 epoch 结束后,工作进程不会被销毁,而是保持存活状态,这样可以减少进程创建和销毁的开销,提高数据加载效率,默认值为False
。 -
- torch.utils.data.Dataset的子集
- • torch.utils.data.DataLoader加载数据集