当前位置: 首页 > article >正文

Pytorch 自学笔记(三):利用自定义文本数据集构建Dataset和DataLoader

Pytorch 自学笔记(三)

  • 1. Dataset与DataLoader
    • 1.1 torch.utils.data.Dataset
    • 1.2 torch.utils.data.DataLoader

Pytorch 自学笔记系列的第三篇。针对Pytorch的Dataset和DataLoader进行简单的介绍,同时,介绍如何使用自定义文本数据集构建Dataset和DataLoader,以实现数据集的随机采样与batch加载。(注:文中代码使用python3.7和pytorch1.7.1编写)

1. Dataset与DataLoader

1.1 torch.utils.data.Dataset

torch.utils.data.Dataset是pytorch中定义的数据集抽象类,pytorch中任何的数据集类都必须继承并重写这个类,其源码如下:

class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

而任何继承torch.utils.data.Dataset的数据集类,必须重写__getitem__方法,可以选择性重写__len__方法(若要以该数据集类构建torch.utils.data.Sampler或者torch.utils.data.DataLoader,则必须重写__len__方法)。__getitem__方法的作用为,利用index获得数据集中该index对应的样本(这就要求该数据类中必须维持一个可以;),而__len__方法的作用为返回数据集的样本数量。一个torch.utils.data.Dataset子类的样例如下:

from torch.utils.data import Dataset
import pandas as pd


class MyDataset(Dataset):
    def __init__(self, csv_file, txt_file, root_dir, other_file):
        self.csv_data = pd.read_csv(csv_file)
        with open(txt_file, 'r') as f:
            data_list = f.readlines()

        # 可利用索引下标进行取值的成员变量,list类型
        self.txt_data = data_list
        self.root_dir = root_dir

    # 返回数据集的样本数量
    def __len__(self):
        return len(self.csv_data)

    # 返回数据集中索引为idx的样本
    def __getitem__(self, index):
        data = (self.csv_data[index], self.txt_data[index])
        return data
        

利用自定义的Dataset子类,可以将我们的数据集定义我们需要的数据类,然后通过迭代的方式利用index下标索引来获取数据集中的每一条样本数据。而数据集的batch取样和取样时的shuffle,则需要利用torch.utils.data.DataLoader来实现。

1.2 torch.utils.data.DataLoader

首先需要明确一点,Dataset和DataLoader本质上都是iterable(可迭代对象),都可以实现数据集的迭代访问。而 torch.utils.data.DataLoader相当于是Dataset(数据集)和Sampler(采样器)的组合,即可以在Dataset上进行迭代的自定义采样。同时,DataLoader还支持单进程或多进程加载,自定义加载顺序以及可选的自动批处理(整理)和memory pinning,它还支持 map风格的数据集对象,其参数具体解释如下(参数说明参考了这篇文章,并按照pytorch1.7.1的文档进行了修改):

  1. dataset(Dataset): 传入的数据集类
  2. batch_size(int, optional): 每个batch有多少个样本
  3. shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序(即随机采样)
  4. sampler(Sampler or Iterable, optional): 自定义从数据集中取样本的策略;如果指定这个参数,那么shuffle必须为False;该值可以为任何实现了__len__函数的Iterable对象
  5. batch_sampler(Sampler or Iterable, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last就不能再指定了(互斥——Mutually exclusive)
  6. num_workers (int, optional):这个参数决定了有几个进程来处理data loading;0意味着所有的数据都会被load进主进程(默认为0)
  7. collate_fn (callable, optional): 一个函数,该函数的作用是将一个由样本构成的batch_size大小的list转换成mini-batch,该函数的输出即为迭代时获得的batch
  8. pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中
  9. drop_last (bool, optional):如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被扔掉了;如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点
  10. timeout(numeric, optional):如果是正数,表明等待从worker进程中收集一个batch等待的时间,若超出设定的时间还没有收集到,那就不收集这个内容了;这个numeric应总是大于等于0;默认为0
  11. worker_init_fn (callable, optional): 每个进程的初始化函数 If not None, this will be called on eachworker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
  12. prefetch_factor (int, optional, keyword-only arg):每个进程预先加载的样本数量。该值2意味着所有的进程预先加载了2 * num_workers个样本(默认为2)
  13. persistent_workers (bool, optional) :如果为True,则迭代完一次数据集后,DataLoader将不会关闭工作进程;这样可以使Worker Dataset实例保持活动状态(默认为False)

利用上一节定义的MyDataset数据集类可以构建一个DataLoader对象:

from torch.utils.data import DataLoader

my_data_loader = DataLoader(myDataset, batch_size=32, shuffle=True)

http://www.kler.cn/a/510944.html

相关文章:

  • 【Django开发】django美多商城项目完整开发4.0第12篇:商品部分,表结构【附代码文档】
  • 《自动驾驶与机器人中的SLAM技术》ch4:基于预积分和图优化的 GINS
  • 深度学习 Pytorch 基本优化思想与最小二乘法
  • 周末总结(2024/01/18)
  • SpringBoot多级配置文件
  • 【C++】面试题整理(未完待续)
  • Qt——界面优化
  • Windows电脑安装File Browser与cpolar轻松搭建本地云盘
  • Vscode:问题解决办法 及 Tips 总结
  • Go语言简洁框架目录和高效的快发框架设计
  • Tomcat下载配置
  • AI agent 在 6G 网络应用,无人机群控场景
  • 安全策略配置实验
  • postgresql链接详解
  • window.location.href 与form method=post 一起使用时需要注意这个问题
  • 全自动化河道水位监测系统:实时传输与远程监控
  • MySQL基于gtid的主从同步配置
  • Mono里运行C#脚本29—mono_trampolines_init
  • 管理口令安全和资源(一)
  • Java锁 可重入锁(递归锁) 深入源码解析 ReentrantLock synchronized
  • Linux TFTP 使用
  • 第38天:Web开发-JS应用NodeJS原型链污染文件系统Express模块数据库通讯审计
  • C语言之文本加密程序设计
  • Three.js贴图加载与环境遮蔽贴图强度设置(五)
  • 【Java回顾】Day7 Java IO|分类(传输方式,数据操作)|零拷贝和NIO
  • Linux 创建用户