当前位置: 首页 > article >正文

PyTorch使用教程(7)-数据集处理

1、基础概念

在PyTorch中,torch.utils.data模块是处理数据集和数据加载的核心工具。以下是该模块中一些基础概念的理解:
在这里插入图片描述

1.1 Dataset

  • 定义:Dataset是一个抽象类,用于表示数据集。用户需要通过继承Dataset类并实现其__len__和__getitem__方法来创建自定义的数据集。

  • 功能:Dataset定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,并能够用索引获取数据集中的元素。

  • 类型:Dataset主要分为两种类型:map-style和iterable-style。map-style数据集需要实现__getitem__和__len__方法,而iterable-style数据集则需要实现__iter__方法。

from typing import Generic, TypeVar, List

_T_co = TypeVar('_T_co', covariant=True)

class Dataset(Generic[_T_co]):
    
    def __getitem__(self, index: int) -> _T_co:
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    def __len__(self) -> int:
        raise NotImplementedError("Subclasses of Dataset should implement __len__.")

    def __add__(self, other: "Dataset[_T_co]") -> "ConcatDataset[_T_co]":
        """
        Adds two datasets. This can be useful when you have two datasets with potentially
        overlapping elements and you want to treat the elements as distinct.
        """
        from .dataset_ops import ConcatDataset
        return ConcatDataset([self, other])

1.2 DataLoader

  • 定义:DataLoader是一个迭代器,用于封装Dataset,并提供一个可迭代对象,方便进行批量加载、数据打乱、并行加载等操作。
  • 功能:DataLoader能够控制batch的大小、batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法。
  • 参数:常用的参数包括dataset(表示要加载的数据集对象)、batch_size(表示每个batch的大小)、shuffle(表示是否在每个epoch开始时打乱数据)、num_workers(表示用于数据加载的进程数)等。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

1.3 Sampler

  • 定义:Sampler是一个抽象类,用于从数据集中生成索引。
  • 功能:Sampler的作用是在Dataset上面进行抽样,抽样的方式有多种,如按顺序抽样、随机抽样、在子集合中随机抽样、带权重的抽样等。
  • 类型:包括SequentialSampler、RandomSampler、SubsetRandomSampler、WeightedRandomSampler、BatchSampler等。

1.4 Batching

  • 定义:Batching是指将数据集分成多个小批次(batch)进行处理的过程。
  • 功能:Batching可以提高数据处理的效率,并有助于模型训练过程中的梯度更新和收敛。
  • 实现:通过DataLoader的batch_size参数来实现批量加载。

1.5 Shuffling

  • 定义:Shuffling是指在每个epoch开始时打乱数据集中的元素顺序。
  • 功能:Shuffling有助于提高模型的泛化能力,防止模型对数据的顺序产生依赖。
  • 实现:通过DataLoader的shuffle参数来启用数据打乱功能。

1.6 Multi-process Data Loading

  • 定义:Multi-process Data Loading是指使用多个进程来并行加载数据的过程。
  • 功能:Multi-process Data Loading可以显著提高数据加载的速度,尤其是在处理大规模数据集时。
  • 实现:通过DataLoader的num_workers参数来设置并行加载的进程数。

2、创建数据集

在PyTorch中,创建数据集通常涉及继承torch.utils.data.Dataset类并实现其必需的方法。以下是一个详细的步骤指南,用于创建自定义数据集:

  1. 导入必要的库
    首先,确保你已经导入了PyTorch和其他可能需要的库。
import torch
from torch.utils.data import Dataset
  1. 继承Dataset类
    创建一个新的类,继承自Dataset。
class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        # 初始化数据集,存储数据和标签
        self.data = data
        self.labels = labels
        self.transform = transform
        
        # 确保数据和标签的长度相同
        assert len(self.data) == len(self.labels), "Data and labels must have the same length"

    def __len__(self):
        # 返回数据集的大小
        return len(self.data)

    def __getitem__(self, idx):
        # 根据索引获取数据和标签
        sample = self.data[idx]
        label = self.labels[idx]
        
        # 如果定义了转换,则应用转换
        if self.transform:
            sample = self.transform(sample)
            
        return sample, label
  1. 准备数据和标签
    在创建CustomDataset实例之前,你需要准备好数据和标签。这些数据可以是图像、文本、数值等,具体取决于你的任务。
# 假设你有一些数据和标签(这里只是示例)
data = [torch.randn(3, 32, 32) for _ in range(100)]  # 100个3x32x32的随机图像
labels = [torch.tensor(i % 2) for i in range(100)]   # 100个标签,0或1
  1. 创建数据集实例
    使用你准备好的数据和标签来创建CustomDataset的实例。
dataset = CustomDataset(data, labels)
  1. (可选)应用转换

如果你需要对数据进行预处理或增强,可以定义一个转换函数,并在创建数据集实例时传递给它。

# 定义一个简单的转换函数(例如,将图像数据标准化)
def normalize(sample):
    return (sample - sample.mean()) / sample.std()

# 创建数据集实例时应用转换
dataset = CustomDataset(data, labels, transform=normalize)
  1. 使用DataLoader加载数据
    最后,使用torch.utils.data.DataLoader来加载数据集,以便进行批量处理、打乱数据等。
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 现在你可以遍历dataloader来加载数据了
for batch_data, batch_labels in dataloader:
    # 在这里进行模型训练或评估
    pass

注意事项

  • 确保你的数据和标签是可索引的,通常它们应该是列表、NumPy数组或PyTorch张量。
  • 如果你的数据是图像,并且存储在文件系统中,你可能需要在__getitem__方法中实现图像读取和预处理逻辑。
  • 对于大型数据集,考虑使用torchvision.datasets中提供的预定义数据集类,它们通常包含了常见的图像数据集(如CIFAR、MNIST等)的加载逻辑。
  • 如果数据集太大无法全部加载到内存中,你可以考虑使用torch.utils.data.IterableDataset来创建一个可迭代的数据集,这样你就可以按需加载数据了。

http://www.kler.cn/a/509965.html

相关文章:

  • nginx 配置防爬虫
  • Flink开发中的优化方案
  • 机器学习-基本术语
  • 第十一章 图论
  • 基于本地消息表实现分布式事务
  • 【17】Word:林楚楠-供应链❗
  • 2.7 实战项目: GitHub openai-quickstart
  • RocketMQ源码分析之事务消息分析
  • kubernetes v1.29.XX版本HPA、KPA、VPA并压力测试
  • Json转换类型报错问题:java.lang.Integer cannot be cast to java.math.BigDecimal
  • 记录一次关于spring映射postgresql的jsonb类型的转化器事故,并使用hutool的JSONArray完成映射
  • Leetcode - 周赛432
  • leetcode34-排序数组中查找数组的第一个和最后一个位置
  • Learning Prompt
  • Kubernetes (K8s) 权限管理指南
  • 【Linux】15.Linux进程概念(4)
  • linux 安装jdk1.8
  • 【脑机接口数据处理】bdf文件转化mat文件
  • AI Prompt 设计指南:从基础构建到高质量生成的全面解析
  • h5使用video播放时关掉vant弹窗视频声音还在后台播放
  • Centos7将/dev/mapper/centos-home磁盘空间转移到/dev/mapper/centos-root
  • 分布式CAP理论介绍
  • Dart语言
  • 计算机视觉语义分割——U-Net(Convolutional Networks for Biomedical Image Segmentation)
  • 【视觉惯性SLAM:十六、 ORB-SLAM3 中的多地图系统】
  • 深入探索Go语言中的临时对象池:sync.Pool