当前位置: 首页 > article >正文

pytorch 自定义Dataset类

torch.utils.data.Dataset 是 PyTorch 数据处理模块中的一个核心类,用于表示一个数据集。通过继承和自定义 Dataset 类,用户可以轻松管理和加载各种类型的数据,如图像、文本、时间序列等。

1. Dataset 类的作用

Dataset 提供了一种标准接口,方便用户自定义数据加载逻辑,尤其是对于大型数据集。每个自定义的数据集类需要实现两个核心方法:

  • __len__():返回数据集中样本的数量。
  • __getitem__(index):根据给定的索引返回数据集中的一个样本(通常包括特征和标签)。

2. 自定义 Dataset

Dataset 是一个抽象类,因此你需要通过继承它来定义自己的数据集,并实现其中的 __len__ 和 __getitem__ 方法。以下是如何自定义一个简单的 Dataset 的示例。

示例代码
import torch
from torch.utils.data import Dataset

# 自定义数据集类,继承自 torch.utils.data.Dataset
class MyDataset(Dataset):
    def __init__(self, data, labels):
        # 初始化数据集,传入数据和标签
        self.data = data
        self.labels = labels

    def __len__(self):
        # 返回数据集中样本的数量
        return len(self.data)

    def __getitem__(self, idx):
        # 根据索引返回一个样本和其对应的标签
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 示例数据
data = torch.randn(100, 3)  # 100 个样本,每个样本有 3 个特征
labels = torch.randint(0, 2, (100,))  # 100 个样本的标签,二分类(0 或 1)

# 创建数据集实例
dataset = MyDataset(data, labels)

# 访问数据集中的第一个样本
sample, label = dataset[0]
print("Sample:", sample)
print("Label:", label)

解释:

  • __init__(self, data, labels):构造函数中,我们将数据和标签传入并保存为类的成员变量。
  • __len__(self):返回数据集的样本数量。
  • __getitem__(self, idx):根据索引 idx,返回数据和标签。

3. 与 DataLoader 配合使用

自定义的 Dataset 类通常与 DataLoader 配合使用。DataLoader 提供了批量数据加载、打乱顺序、并行加载等功能。

from torch.utils.data import DataLoader

# 使用 DataLoader 加载数据集
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# 迭代 DataLoader
for batch_data, batch_labels in dataloader:
    print(batch_data, batch_labels)

解释:

  • batch_size=4:每次加载 4 个样本。
  • shuffle=True:在每个 epoch 之前将数据打乱。

4. 常见的 Dataset 子类

PyTorch 提供了一些常用的 Dataset 子类,如:

  • torchvision.datasets:用于加载图像数据集(如 CIFAR、MNIST 等)。
  • torchtext.datasets:用于加载文本数据集(如 IMDB、WikiText 等)。
  • torch.utils.data.TensorDataset:将一对张量(如数据和标签)封装成一个数据集。

http://www.kler.cn/a/291635.html

相关文章:

  • el-form组件中的常用属性
  • 【Leetcode 热题 100】124. 二叉树中的最大路径和
  • [JAVA备忘录] Lambda 表达式简单介绍
  • C++进阶-2-STL
  • 一区牛顿-拉夫逊算法+分解+深度学习!VMD-NRBO-Transformer-GRU多变量时间序列光伏功率预测
  • 集成自然语言理解服务,让应用 “听得懂人话”
  • CAS单点登录说明文档
  • EdgeGallery:聚焦 5 大行业场景,MEC 开源平台将 5G 能力拓展到边缘_边缘mec平台
  • Redis 讲解运行模式:单机、主从、哨兵、集群使用场景和区别
  • OpenCV 100道面试题及参考答案(7万字长文)
  • 第4章 汇编语言和汇编软件
  • 【C++】static作用总结
  • 【ORACLE】独有的函数
  • 数据结构代码集训day13(适合考研、自学、期末和专升本)
  • 华为认证是什么?HCIA/HCIP/HCIE是什么?
  • Java8对接三方流式接口,并实时输出(GPT)
  • 【数据库】Oracle和Mysql的区别
  • 多媒体应用设计师需要掌握多种软件
  • 动态化-鸿蒙跨端方案介绍
  • C++day6
  • MySQL5.7.36之主从复制增强半同步复制-centos7
  • Linux下数据库相关知识点及SQLite3相关知识,和callback回调函数
  • 【区块链 + 供应链】长虹生产物料质量信息管理系统 | FISCO BCOS应用案例
  • 初始QT!
  • [线程]单例模式 及 指令重排序
  • CodeSys中动态切换3D模型