当前位置: 首页 > article >正文

Pytorch:torch.utils.data.DataLoader()

如果读者正在从事深度学习的项目,通常大部分时间都花在了处理数据上,而不是神经网络上。因为数据就像是网络的燃料:它越合适,结果就越快、越准确!神经网络表现不佳的主要原因之一可能是由于数据不佳或理解不足。因此,以更直观的方式理解、预处理数据并将其加载到网络中非常重要。
参考:https://zhuanlan.zhihu.com/p/596730297

DataLoader加载和迭代数据集

Dataloader本质是一个迭代器对象,也就是可以通过for batch_idx,batch_dict in dataloader 来提取数据集,提取的数量由batch_size 参数决定,得到这一batch的数据后,就可以喂入网络开始训练或者推理了。
在迭代的过程中,dataloader会自动调用dataset中的__getitem__ 函数,以获取一帧数据(item)

from torch.utils.data import DataLoader

DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
 )

以U-Net中的代码为例:
具体详见:U-Net代码复现

loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

1. 数据集

**dataset (Dataset) ** – dataset from which to load the data.
即自定义的数据集,非常重要,因为dataloader会调用dataset的一些重载函数(e.g. getitem && len )

2. 对数据进行批处理

batch_size (int, optional)how many samples per batch to load(default: 1).

3. 在 CUDA 张量上加载数据

pin_memory(bool, optional)If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elementsare a custom type, or your collate_fn returns a batch that is a custom type,see the example below.

pin_memory参数直接将数据集加载为 CUDA 张量。它是一个可选参数,接受一个布尔值;如果设置为True,会在返回张量之前张量复制到 CUDA 固定内存中。这样在GPU训练过程中,数据从内存到GPU的复制可以使用异步的方式进行,从而提高数据读取的效率。

通常情况下,当使用GPU训练模型时,数据读取会成为整个训练过程的瓶颈之一。使用pin_memory可以将数据在CPU和GPU之间进行传输时的复制时间减少,从而提高数据加载的速度,加速训练过程。

需要注意的是,使用pin_memory会占用更多的内存空间,因此在内存资源紧张的情况下,需要谨慎使用。同时,在某些情况下(例如数据集比较小的情况下),使用pin_memory并不会带来明显的加速效果。

4.允许多进程

num_workers (int, optional)how many subprocesses to use for dataloading. 0 means that the data will be loaded in the main process.(default: 0)
这也是一个很有意思的参数,按照官方的说法, num_workers 用于设置数据加载过程中使用的子进程数。其默认值为0,即在主进程中进行数据加载,而不使用额外的子进程。

以下是我看到的一个解释,原文链接:https://blog.csdn.net/vonct/article/details/130263743
下面说一下个人的理解,在初始化 dataloader对象时,会根据num_workers创建子线程用于加载数据(主线程数+子线程=num_workers)。每个worker或者说线程都有自己负责的dataset范围(下面统称worker)

每当迭代 dataloader 对象时,工人们(workers)就开始干活了:将数据从数据源(如硬盘)加载到内存(数据加载),当一个worker读取(调用__getitem__)到足够的数据(看你在dataset中怎么定义一个item了)后,会将这些数据封装成一个(即一帧),并将其放到该worker独有的内存队列中。 要注意的是,每次迭代时,worker会尽可能地读数据,直到自己的队列被填满。

当所有workers的队列都被填满时,一个名为sampler的线程将会被创建,它的作用就是收集各workers队列中队首的 ,把他们放到一个各线程共享内存的缓冲队列中,并调用 collate_fn 函数来将 batch_size 个 整合,最后返回给迭代的输出。

这时候大家肯定会有点疑惑,那当迭代到后期时,需要读取的样本都已经在队列中了,是不是意味着这时候工人们已经在休息了?根据chatgpt的回答:是的!下面以一张图来帮助大家理解

在这里插入图片描述

5.合并数据集

collate_fn (Callable, optional)merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.

整合多个样本到一个batch时需要调用的函数,当 getitem 返回的不是tensor而是字典之类时,需要进行 collate_fn的重载,同时可以进行数据的进一步处理以满足pytorch的输入要求。
以U-Net为例:

def __getitem__(self, idx):
        name = self.ids[idx]
        mask_file = list(self.mask_dir.glob(name + self.mask_suffix + '.*'))
        img_file = list(self.images_dir.glob(name + '.*'))

        assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'
        assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
        mask = load_image(mask_file[0])
        img = load_image(img_file[0])

        assert img.size == mask.size, \
            f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'

        img = self.preprocess(self.mask_values, img, self.scale, is_mask=False)
        mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True)

        return {
            'image': torch.as_tensor(img.copy()).float().contiguous(),
            'mask': torch.as_tensor(mask.copy()).long().contiguous()
        }

getitem 返回的是一个包含image和mask的 data_dict 字典,这时候就需要调用自定义的collate_fn来进行打包(待补充。。。)

6.数据采样

sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with len implemented. If specified, shufflemust not be specified.

sampler的主要作用是控制样本的采样顺序,并提供样本的索引。在默认情况下,dataloader使用的是SequentialSampler,它按照数据集的顺序依次提取样本,但在某些情况下,我们可能需要自定义采样顺序。比如说想从队尾提取数据。

比如,当我们处理非常大的数据集时,为了提高训练效率,可能需要对数据进行分布式采样,这时候就需要使用DistributedSampler。DistributedSampler会将数据集划分成多个子集,每个子集分配给不同的进程进行采样。在这种情况下,如果使用默认的SequentialSampler,可能会导致各个进程采样到相同的数据,从而降低训练效率。

此外,还有一些自定义的sampler,比如随机采样器(RandomSampler)和加权采样器(WeightedRandomSampler),它们可以按照不同的采样策略对数据集进行采样,从而满足不同的训练需求。

因此,根据不同的训练需求,我们可能需要自定义sampler来控制数据的采样顺序。

原文链接:https://blog.csdn.net/vonct/article/details/130263743


http://www.kler.cn/a/149360.html

相关文章:

  • influxDB 时序数据库安装 flux语法 restful接口 nodjsAPI
  • 【JVM】关于JVM的内部原理你到底了解多少(八股文面经知识点)
  • 【QT】QSS
  • 事件循环 -- 资源总结(浏览器进程模型、事件循环机制、练习题)
  • 重构代码之内联临时变量
  • 「数据要素」行业简报|2024.11.上刊
  • 系列五、Spring整合MyBatis不忽略mapper接口同目录的xxxMapper.xml
  • 搜索引擎语法
  • Alibaba Java诊断工具Arthas查看Dubbo动态代理类
  • 【古月居《ros入门21讲》学习笔记】14_参数的使用与编程方法
  • 你知道显卡型号上的数字是什么意思吗?数字越大就越好吗?
  • 34.基于webpack搭建开发环境
  • ground truth 在深度学习任务中代表的是什么意思?
  • 第二证券:机构密集调研消费电子、半导体产业链
  • 三大录屏软件推荐,让你轻松录制屏幕
  • Vue实现纯前端导入excel数据
  • FFmpeg介绍
  • PHPStudy开发环境解决:启动报错 class websocket/server not found
  • 【问题系列】消费者与MQ连接断开问题解决方案(二)
  • Python的哈希映射:字典
  • API网关
  • Java中的mysql——面试题+答案——第24期
  • 苹果提醒事项怎么用?几个简单步骤就能学会!
  • Hadoop集群升级(3.1.3 -> 3.2.4)
  • 图表控件LightningChart .NET中文教程 - 如何创建WPF 2D热图?(二)
  • C#中的async/await异步编程模型