当前位置: 首页 > article >正文

`torch.utils.data`模块

在PyTorch中,torch.utils.data模块提供了许多有用的工具来处理和加载数据。以下是对您提到的DataLoader, Subset, BatchSampler, SubsetRandomSampler, 和 SequentialSampler的详细解释以及使用示例。

1. DataLoader

DataLoader是PyTorch中用于加载数据的一个非常重要的类。它封装了数据集(Dataset),并提供了一个可迭代的对象,支持批量加载、打乱数据、多进程数据加载等功能。

示例代码

from torch.utils.data import DataLoader, TensorDataset
import torch

# 假设我们有一些数据
data = torch.randn(100, 3)  # 100个样本,每个样本3个特征
labels = torch.randint(0, 2, (100,))  # 100个标签,每个标签是0或1

# 创建数据集
dataset = TensorDataset(data, labels)

# 创建DataLoader
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

# 遍历DataLoader
for data_batch, label_batch in dataloader:
    print(data_batch.shape)  # 应为torch.Size([10, 3])
    print(label_batch.shape)  # 应为torch.Size([10])

2. Subset

Subset是一个用于从数据集中选择特定索引的子集的类。这对于分割数据集为训练集、验证集和测试集非常有用。

示例代码

from torch.utils.data import Subset

# 假设dataset是之前创建的TensorDataset
# 选择索引为0到49的样本作为训练集
indices = list(range(50))
train_subset = Subset(dataset, indices)

# 现在train_subset只包含前50个样本
train_dataloader = DataLoader(train_subset, batch_size=10, shuffle=True)

3. BatchSampler

BatchSampler用于从给定的样本列表中批量地采样索引。这允许用户自定义每个batch的采样方式。

示例代码

from torch.utils.data.sampler import BatchSampler, SequentialSampler

# 假设indices是包含所有样本索引的列表
indices = list(range(100))
batch_sampler = BatchSampler(sampler=SequentialSampler(indices), batch_size=10, drop_last=False)

# batch_sampler将返回索引的列表,每个列表代表一个batch
for batch_indices in batch_sampler:
    print(batch_indices)  # 输出形如[0, 1, 2, ..., 9]的列表

4. SubsetRandomSampler

SubsetRandomSampler用于从指定的索引列表中随机采样,但保证每个元素只被采样一次(除非指定了replacement=True)。

示例代码

from torch.utils.data.sampler import SubsetRandomSampler

# 假设indices是包含所有样本索引的列表
indices = list(range(100))
subset_sampler = SubsetRandomSampler(indices)

# subset_sampler可以传递给DataLoader来打乱数据
dataloader = DataLoader(dataset, batch_size=10, sampler=subset_sampler)

5. SequentialSampler

SequentialSampler简单地按照给定的索引顺序来采样。这通常用于不需要打乱数据的场景。

示例代码(已在BatchSampler示例中展示):

from torch.utils.data.sampler import SequentialSampler

# 假设indices是包含所有样本索引的列表
indices = list(range(100))
sampler = SequentialSampler(indices)

# sampler可以传递给DataLoader,但通常不需要显式创建SequentialSampler,
# 因为DataLoader的shuffle=False参数已经实现了相同的功能。

这些工具结合起来可以非常灵活地处理PyTorch中的数据加载和采样任务。


http://www.kler.cn/a/323601.html

相关文章:

  • 深入解析 MySQL 数据库:数据库时区问题
  • 供应链管理、一件代发系统功能及源码分享 PHP+Mysql
  • Easyui ComboBox 数据加载完成之后过滤数据
  • Flink Job更新和恢复
  • 创建vue插件,发布npm
  • 【WPF】Prism学习(二)
  • PostgreSQL 向量扩展插件pgvector安装和使用
  • 高等数学 第11讲 多元函数偏导数的计算与应用_复合函数求偏导_隐函数求偏导_条件极值
  • 什么是原生IP?
  • QT+ESP8266+STM32项目构建三部曲二--阿里云云端处理之云产品流转
  • 学习threejs,绘制二维线
  • 洛谷P1197.星球大战
  • 一道简单的css动态宽度问题?
  • List 循环遍历删除元素
  • 精通推荐算法31:行为序列建模之ETA — 基于SimHash实现检索索引在线化
  • rtsp 协议推流接收(tcp udp)
  • 【深度学习】(9)--调整学习率
  • 后端返回内容有换行标识,前端如何识别换行
  • Linux:LCD驱动开发
  • MySQL:进阶巩固-存储过程
  • 经典Python应用库一览
  • 智慧防灾,科技先行:EasyCVR平台助力地质灾害视频监测系统建设
  • VSCode配置C/C++开发环境
  • MMD模型及动作一键完美导入UE5-Blender方案(三)
  • c++反汇编逆向还原——for循环(笔记)
  • 全景可视化特点+可视化功能实现