当前位置: 首页 > article >正文

pytorch中的TensorDataset和DataLoader

TensorDataset 详解

TensorDataset 主要用于将多个 Tensor 组合在一起,方便对数据进行统一处理。它可以用于简单地将特征和标签配对,也可以将多个特征张量组合在一起。

1. 将特征和标签组合

假设我们有一组图像数据(特征)和对应的标签,我们可以将它们组合成一个 TensorDataset

import torch
from torch.utils.data import TensorDataset

# 创建输入数据(图像)和标签
images = torch.randn(100, 3, 28, 28)  # 100张图像,每张图像3通道,28x28像素
labels = torch.randint(0, 10, (100,))  # 100个标签,范围在0到9之间

# 创建 TensorDataset
dataset = TensorDataset(images, labels)

# 访问数据集中的特定样本
sample_image, sample_label = dataset[0]
print(f"Sample Image Shape: {sample_image.shape}")  # 输出: Sample Image Shape: torch.Size([3, 28, 28])
print(f"Sample Label: {sample_label}")  # 输出: Sample Label: 3

在这个例子中,我们创建了一个包含100张图像和对应标签的 TensorDataset。通过 dataset[0],我们可以访问第一个样本的图像和标签。

2. 组合多个特征张量

除了将特征和标签组合,TensorDataset 还可以将多个特征张量组合在一起。例如,假设我们有两个不同的特征张量,我们可以将它们组合成一个 TensorDataset

# 创建两个特征张量
feature1 = torch.randn(100, 50)  # 100个样本,每个样本50维
feature2 = torch.randn(100, 30)  # 100个样本,每个样本30维

# 创建 TensorDataset
dataset = TensorDataset(feature1, feature2)

# 访问数据集中的特定样本
sample_feature1, sample_feature2 = dataset[0]
print(f"Sample Feature1 Shape: {sample_feature1.shape}")  # 输出: Sample Feature1 Shape: torch.Size([50])
print(f"Sample Feature2 Shape: {sample_feature2.shape}")  # 输出: Sample Feature2 Shape: torch.Size([30])

在这个例子中,我们创建了一个包含两个特征张量的 TensorDataset,并通过 dataset[0] 访问第一个样本的两个特征。

DataLoader 详解

DataLoader 主要用于批量加载数据,并支持多种数据处理功能,如随机打乱、多线程加载等。

1. 批量处理数据

DataLoader 可以将数据集划分为多个批次(batch),便于模型训练。

from torch.utils.data import DataLoader

# 创建 DataLoader
train_loader = DataLoader(dataset, batch_size=32, shuffle=False)

# 遍历 DataLoader
for batch_features, batch_labels in train_loader:
    print(f"Batch Features Shape: {batch_features.shape}")  # 输出: Batch Features Shape: torch.Size([32, 3, 28, 28])
    print(f"Batch Labels Shape: {batch_labels.shape}")  # 输出: Batch Labels Shape: torch.Size([32])
    # 这里可以进行训练操作,如前向传播、反向传播等

在这个例子中,train_loader 将数据集划分为大小为32的批次。通过遍历 train_loader,我们可以轻松地获取每个批次的特征和标签。

2. 数据打乱

DataLoader 可以通过设置 shuffle=True 来在每个 epoch 开始时随机打乱数据,避免模型学习到数据的顺序。

# 创建 DataLoader,并设置 shuffle=True
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# 遍历 DataLoader
for epoch in range(2):  # 假设我们要训练两个 epoch
    for batch_features, batch_labels in train_loader:
        print(f"Epoch {epoch}, Batch Features Shape: {batch_features.shape}")
        # 这里可以进行训练操作

在这个例子中,每次 epoch 开始时,数据都会被随机打乱,确保模型不会受到数据顺序的影响。

3. 多线程加载

DataLoader 支持通过设置 num_workers 参数来使用多线程并行加载数据,加快数据读取速度。

# 创建 DataLoader,并设置 num_workers=4
train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

# 遍历 DataLoader
for batch_features, batch_labels in train_loader:
    print(f"Batch Features Shape: {batch_features.shape}")
    # 这里可以进行训练操作

在这个例子中,我们设置了 num_workers=4,表示使用4个线程来并行加载数据,从而加快数据读取速度。

结合使用 TensorDataset 和 DataLoader

以下是一个完整的示例,展示了如何结合使用 TensorDataset 和 DataLoader 进行数据加载和训练。

import torch
from torch.utils.data import TensorDataset, DataLoader

# 创建输入数据和标签
images = torch.randn(1000, 3, 28, 28)  # 1000张图像,每张图像3通道,28x28像素
labels = torch.randint(0, 10, (1000,))  # 1000个标签,范围在0到9之间

# 创建 TensorDataset
dataset = TensorDataset(images, labels)

# 创建 DataLoader
train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

# 遍历 DataLoader 进行训练
for epoch in range(2):
    for batch_images, batch_labels in train_loader:
        print(f"Epoch {epoch}, Batch Images Shape: {batch_images.shape}")
        print(f"Epoch {epoch}, Batch Labels Shape: {batch_labels.shape}")
        # 这里可以进行训练操作,如前向传播、反向传播等

在这个例子中,我们首先使用 TensorDataset 将图像和标签组合在一起,然后通过 DataLoader 进行批量加载和训练。通过设置 shuffle=True 和 num_workers=4,我们实现了数据的随机打乱和多线程加载。

总结

  • TensorDataset 用于将多个 Tensor 组合在一起,方便对数据进行统一处理。
    • 可以组合特征和标签。
    • 可以组合多个特征张量。
  • DataLoader 用于批量加载数据,支持多种数据处理功能。
    • 支持批量处理数据。
    • 支持数据打乱。
    • 支持多线程加载。


http://www.kler.cn/news/336241.html

相关文章:

  • 红外画面空中目标检测系统源码分享
  • LeetCode讲解篇之139. 单词拆分
  • JS模块化工具requirejs详解
  • webpack/vite的区别
  • Oracle架构之物理存储之日志文件
  • 计算机毕业设计 基于Python的智能文献管理系统的设计与实现 Python+Django+Vue 前后端分离 附源码 讲解 文档
  • 【图像处理】多幅不同焦距的同一个物体的平面图象,合成一幅具有立体效果的单幅图像原理(一)
  • MFC工控项目实例二十二主界面计数背景颜色改变
  • 股市突然暴涨,需要保持理性
  • 突触可塑性与STDP:神经网络中的自我调整机制
  • 探索MinimalModbus:Python中强大的Modbus通信库
  • 【WSL】wsl中ubuntu无法通过useradd添加用户
  • 论文速读:基于渐进式转移的无监督域自适应舰船检测
  • CMU 10423 Generative AI:lec14(Vision Language Model:CLIP、VQ-VAE)
  • WPF 设计属性 设计页面时实时显示 页面涉及集合时不显示处理 设计页面时显示集合样式 显示ItemSource TabControl等集合样式
  • Java如何判断堆区中的对象可以被回收了?
  • 【含开题报告+文档+PPT+源码】基于SSM + Vue的养老院管理系统【包运行成功】
  • 树莓派 mysql (兼容mariadb)登陆问题
  • 【c++】知识精讲:c++数组排序的方法归纳
  • 设置服务器走本地代理