当前位置: 首页 > article >正文

TensorDataser和DataLoader的解释与使用

TensorDataset 是 PyTorch 提供的一个工具类,用于将多个张量(Tensors)打包成一个数据集(Dataset),便于配合 DataLoader 进行批量加载和数据管理。

一、概念

作用:将多个张量(如特征张量、标签张量)按样本对齐,合并为一个数据集对象。

适用场景:监督学习任务、多输入、多输出模型、简化数据加载流程,兼容DataLoader的批处理、打乱数据等操作

二、使用步骤

1.导入库

import torch
from torch.utils.data import TensorDataset, DataLoader

 2.准备数据

创建一些特征张量和标签,数据类型可以是numpy或list列表

特征张量:(样本数,特征维度)

标签张量:(样本数,)

# 示例数据(100个样本,每个样本5个特征)
features = torch.randn(100, 5)  # 随机生成特征
labels = torch.randint(0, 2, (100,))  # 二分类标签(0或1)

3。创建TensorDataser,将特征张量和标签合并成一个大张量

dataset = TensorDataset(features, labels)

 4.使用DataLoader加载数据,将合并好的张量按照指定大小来切分批次

dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,  # 训练时打乱数据
    num_workers=2  # 多进程加载数据(可选)
)

# 遍历批次
for batch_features, batch_labels in dataloader:
    # 在此处执行模型的前向传播、损失计算等操作
    print("Batch features shape:", batch_features.shape)
    print("Batch labels shape:", batch_labels.shape)

三、注意事项

  1. 张量对齐:所有传入 TensorDataset 的张量的第一个维度(样本数)必须相同。
  2. 数据类型:确保张量的数据类型(dtype)与模型输入要求一致(如 float32 或 int64)。
  3. 设备位置:若使用 GPU,需将张量放在 GPU 上(通过 tensor.to(device)),或让 DataLoader 自动处理。
  4. 内存限制:数据量过大时,优先使用 Dataset 的子类(如 IterableDataset)动态加载数据,避免内存溢出。

 四、完整代码

import torch
from torch.utils.data import TensorDataset, DataLoader

# 1. 生成模拟数据
num_samples = 1000
features = torch.randn(num_samples, 10)  # 10维特征
labels = torch.randn(num_samples)       # 连续值标签

# 2. 创建 TensorDataset
dataset = TensorDataset(features, labels)

# 3. 定义 DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4
)

# 4. 遍历数据
for batch_idx, (batch_features, batch_labels) in enumerate(dataloader):
    print(f"Batch {batch_idx}:")
    print("  Features shape:", batch_features.shape)  # [64, 10]
    print("  Labels shape:  ", batch_labels.shape)    # [64]

http://www.kler.cn/a/600751.html

相关文章:

  • 最长公共子序列LCS -- 全面分析版
  • 爱普生SG-3031CMA有源晶振在汽车雷达中的应用
  • vue2相关 基础命令
  • [NO-WX179]基于springboot+微信小程序的在线选课系统
  • W、M、C练题笔记(持续更新中)
  • 适合各个层次的 7 个计算机视觉项目【1】:植物病害检测
  • 内核编程十二:打印内核态进程的属性
  • 传统 embedding vs. P-Tuning 里的 embedding
  • (二)手眼标定——概述+原理+常用方法汇总+代码实战(C++)
  • 稳定运行的以Microsoft Azure Cosmos DB数据库为数据源和目标的ETL性能变差时提高性能方法和步骤
  • 深入解析 C++20 中的 std::bind_front:高效函数绑定与参数前置
  • 【蓝桥杯每日一题】3.25
  • MySQL数据库中常用的命令
  • 竞品已占据市场先机,如何找到差异化突破口
  • Next.js 中间件鉴权绕过漏洞 (CVE-2025-29927)
  • showdoc在服务器docker部署后如何关闭注册功能
  • XSS复现漏洞简单前八关靶场
  • 余弦退火算法与学习率预热
  • 依肤婗:以科研实力引领 问题性肌肤护理新标准
  • Apache HBase平衡器架构