当前位置: 首页 > article >正文

【Pytorch】torch.utils.data模块

        torch.utils.data模块主要用于进行数据集处理,是常用的一个包。在构建数据集的过程中经常会用到。要使用data函数必须先导入:

from torch.utils import data

       下面介绍几个经常使用到的类。   

torch.utils.data.DataLoader

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

        DataLoader构造函数最重要的参数是 dataset,它指示要从中加载数据的数据集对象。PyTorch 支持两种不同类型的数据集——映射式数据集和可迭代式数据集。

        映射式数据集是Dataset 子类的实例,它实现了 __getitem__() 和 __len__() 协议,它表示从索引/键值到数据样本的映射。例如,当使用 dataset[idx] 访问此类数据集时,它可以从磁盘上的文件夹中读取第 idx 幅图像及其对应的标签。

        可迭代式数据集是IterableDataset 子类的实例,它实现了 __iter__() 协议,并表示数据样本上的可迭代对象。这种类型的数据集特别适合随机读取代价高昂甚至不可能的情况,以及批大小取决于获取的数据的情况。例如,当调用 iter(dataset) 时,此类数据集可以返回从数据库、远程服务器甚至实时生成的日志中读取的数据流。

torch.utils.data.Dataset

        表示一个Dataset的抽象类。所有表示键到数据样本映射的数据集都应该继承它。所有子类都应该重写__getitem__(),支持为给定键获取数据样本。子类还可以选择性地重写__len__(),许多Sampler实现和DataLoader的默认选项都期望它返回数据集的大小。子类还可以选择性地实现__getitems__(),以加速批量样本加载。此方法接受批量样本索引列表并返回样本列表。

代码运用示例:

import torch
from torch.utils.data import Dataset, DataLoader

# 自定义数据集
class SimpleDataset(Dataset):
    def __init__(self, data, labels):
        """
        Args:
            data (list or tensor): 输入数据
            labels (list or tensor): 数据对应的标签
        """
        self.data = torch.tensor(data, dtype=torch.float32)  # 转为张量
        self.labels = torch.tensor(labels, dtype=torch.long)  # 转为张量

    def __len__(self):
        """返回数据集的大小"""
        return len(self.data)

    def __getitem__(self, idx):
        """根据索引返回一个样本"""
        return self.data[idx], self.labels[idx]

# 创建数据和标签
data = [1, 2, 3, 4, 5]
labels = [0, 1, 0, 1, 0]

# 实例化数据集
dataset = SimpleDataset(data, labels)

# 用 DataLoader 加载数据
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 遍历 DataLoader
for batch_data, batch_labels in dataloader:
    print("Data:", batch_data)
    print("Labels:", batch_labels)

运行结果:(顺序会随着Shuffle=True发生变化)

torch.utils.data.IterableDataset

        一个可迭代的数据集。所有表示数据样本可迭代的数据集都应该继承它。当数据来自流时,这种形式的数据集特别有用。所有子类都应该重写__iter__(),它将返回此数据集中样本的迭代器。当子类与DataLoader一起使用时,数据集中的每个项目都将从DataLoader迭代器中产生。当num_workers > 0时,每个工作进程将拥有数据集对象的副本,因此通常希望独立配置每个副本以避免工作进程返回重复的数据。get_worker_info()在工作进程中调用时,返回有关工作进程的信息。它可以在数据集的__iter__()方法或DataLoader的worker_init_fn选项中使用来修改每个副本的行为。

代码运用示例:

import torch
from torch.utils.data import IterableDataset, DataLoader

# 自定义 IterableDataset
class NumberStreamDataset(IterableDataset):
    def __init__(self, start, end):
        """
        Args:
            start (int): 起始值
            end (int): 结束值
        """
        self.start = start
        self.end = end

    def __iter__(self):
        """
        定义数据生成逻辑,返回一个迭代器
        """
        for num in range(self.start, self.end):
            yield num

# 创建一个数据集实例
dataset = NumberStreamDataset(start=0, end=10)

# 用 DataLoader 加载数据
dataloader = DataLoader(dataset, batch_size=3)

# 遍历 DataLoader
for batch in dataloader:
    print(batch)

运行结果:

torch.utils.data.TensorDataset(*tensors)

        包装张量的数据集。每个样本将通过沿第一个维度索引张量来检索。参数*tensors (张量)表示第一个维度大小相同的张量。

代码运用示例:

import torch
from torch.utils.data import TensorDataset, DataLoader

# 创建输入张量和标签张量
data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
labels = torch.tensor([0, 1, 0, 1])

# 使用 TensorDataset 封装数据
dataset = TensorDataset(data, labels)

# 使用 DataLoader 加载数据
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 遍历 DataLoader
for batch_data, batch_labels in dataloader:
    print("Batch data:", batch_data)
    print("Batch labels:", batch_labels)

运行结果:(顺序会随着shuffle=True而发生变化)

torch.utils.data.ConcatDataset(datasets)

        将多个数据集连接起来的数据集。此类用于组装不同的现有数据集。参数datasets (序列) 表示要连接的数据集列表

代码运用示例:

import torch
from torch.utils.data import TensorDataset, ConcatDataset, DataLoader

# 创建两个数据集
data1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
labels1 = torch.tensor([0, 1])
dataset1 = TensorDataset(data1, labels1)

data2 = torch.tensor([[5.0, 6.0], [7.0, 8.0]])
labels2 = torch.tensor([1, 0])
dataset2 = TensorDataset(data2, labels2)

# 使用 ConcatDataset 拼接两个数据集
concat_dataset = ConcatDataset([dataset1, dataset2])

# 用 DataLoader 加载数据
dataloader = DataLoader(concat_dataset, batch_size=2, shuffle=True)

# 遍历 DataLoader
for batch_data, batch_labels in dataloader:
    print("Batch data:", batch_data)
    print("Batch labels:", batch_labels)

运行结果:

torch.utils.data.Subset(datasetindices)

        指定索引处数据集的子集。参数dataset (Dataset)表示整个数据集,indices (序列) – 为子集选择的整个集合中的索引。

代码运用示例:

import torch
from torch.utils.data import TensorDataset, Subset, DataLoader

# 创建一个原始数据集
data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
labels = torch.tensor([0, 1, 0, 1])
dataset = TensorDataset(data, labels)

# 使用 Subset 提取索引为 [1, 3] 的样本
indices = [1, 3]
subset = Subset(dataset, indices)

# 用 DataLoader 加载子集
dataloader = DataLoader(subset, batch_size=1)

# 遍历 DataLoader
for batch_data, batch_labels in dataloader:
    print("Batch data:", batch_data)
    print("Batch labels:", batch_labels)

运行结果:


http://www.kler.cn/a/403211.html

相关文章:

  • Java LinkedList 详解
  • Spark RDD(弹性分布式数据集)的深度理解
  • 一键部署 200+ 开源软件的 Websoft9 面板,Github 2k+ 星星
  • python操作selenium的简单封装
  • pytorch torch.sign() 方法介绍
  • 第十九天 决策树与随机森林
  • .NET 9与C# 13革新:新数据类型与语法糖深度解析
  • 【课堂笔记】隐私计算实训营第四期:匿踪查询PIR
  • 【软件测试】自动化常用函数
  • 拼多多式社交裂变在欧美市场的困境与突破:Web3 增长的新思考
  • Spring Boot核心概念:应用配置
  • 企事业单位的敏感数据怎么保护比较安全?
  • 嵌入式学习-C嘎嘎-Day03
  • 单片机学习笔记 1. 点亮一个LED灯
  • 创建型设计模式(模版方法、观察者模式、策略模式)
  • 网络安全实施方案
  • 关联度分析、灰色预测GM(1,1)、GM(1,1)残差模型——基于Python实现
  • 类和对象——static 成员,匿名对象(C++)
  • OAI-5G开源通信平台实践(三)
  • linux 软连接的使用
  • tensorflow有哪些具体影响,和chatgpt有什么关系
  • [Unity]【游戏相关】 游戏设计基础:如何创建有效的游戏设计文档
  • C++常用库
  • Git错误:gnutls_handshake() failed: The TLS connection was non-properly terminated
  • mybatis的动态sql用法之排序
  • 同三维T80003JEHS 4K/60帧HDMI/SDI超高清H.265解码器