当前位置: 首页 > article >正文

pytorch数据读入

在机器学习和深度学习中,数据读入是一个非常重要的步骤。PyTorch提供了非常灵活的数据读取方式,主要是通过DatasetDataLoader两个类来完成的。

一、PyTorch常见的数据读取方式

PyTorch中的数据读取主要依赖于两个核心类:DatasetDataLoader

  1. Dataset: 定义数据集的格式和数据变换形式。
  2. DataLoader: 将数据集按批次读入,并且可以进行多进程数据加载。
1.1 Dataset

Dataset类是一个抽象类,我们通常需要继承它并实现以下三个方法:

  • __init__: 初始化函数,用于传入外部参数,如数据集的路径、数据变换等。
  • __getitem__: 用于逐个读取样本数据,可以进行一定的数据变换,并返回训练所需的数据和标签。
  • __len__: 返回数据集的样本总数。
1.2 DataLoader

DataLoader类用于按批次读入数据。它提供了许多参数来控制数据的读取方式,如batch_sizenum_workersshuffle等。

二、构建自己的数据读取流程

为了更好地理解这些概念,我们可以通过一个具体的例子来展示如何构建自己的数据读取流程。

2.1 使用PyTorch自带的ImageFolder

PyTorch提供了一个方便的ImageFolder类,用于读取按一定结构存储的图片数据。我们可以直接使用这个类来读取数据集。

import torch
from torchvision import datasets

# 假设train_path和val_path是数据集的路径
train_data = datasets.ImageFolder(train_path, transform=data_transform)
val_data = datasets.ImageFolder(val_path, transform=data_transform)

这里的data_transform可以对图像进行一定的变换,如翻转、裁剪等操作。你可以在transform模块中定义自己的变换。

2.2例子

假设你有一个图像数据集,用于训练一个猫狗分类模型。你的数据集结构如下:

dataset/
    train/
        cats/
            cat1.jpg
            cat2.jpg
            cat3.jpg
        dogs/
            dog1.jpg
            dog2.jpg
            dog3.jpg
    val/
        cats/
            val_cat1.jpg
            val_cat2.jpg
        dogs/
            val_dog1.jpg
            val_dog2.jpg

在这个数据集中:

  • train文件夹包含训练数据,其中有两个子文件夹catsdogs,分别存放猫和狗的图片。
  • val文件夹包含验证数据,同样有两个子文件夹catsdogs

具体代码示例

接下来我们写一段Python代码来使用ImageFolder类加载这个数据集。

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 步骤1: 定义数据变换
data_transform = transforms.Compose([
    transforms.Resize((256, 256)),  # 调整图像大小为256x256
    transforms.RandomHorizontalFlip(),  # 进行随机水平翻转
    transforms.ToTensor(),  # 将图像转换为张量格式
])

# 步骤2: 定义数据集路径
train_path = 'dataset/train'  # 训练集路径
val_path = 'dataset/val'      # 验证集路径

# 步骤3: 使用ImageFolder加载数据集
train_data = datasets.ImageFolder(train_path, transform=data_transform)
val_data = datasets.ImageFolder(val_path, transform=data_transform)

# 步骤4: 创建DataLoader以便批量加载数据
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)

# 步骤5: 访问一个批次的数据
for images, labels in train_loader:
    print(f'Batch of images shape: {images.shape}')  # 打印图片的batch尺寸
    print(f'Batch of labels shape: {labels.shape}')   # 打印标签的batch尺寸
    break  # 只输出一批数据,便于查看

代码详解

  1. 导入必要的库

    import torch
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader
    
    • 导入PyTorch和必要的模块。
  2. 定义数据变换

    data_transform = transforms.Compose([
        transforms.Resize((256, 256)),  # 将图像大小调整为256x256
        transforms.RandomHorizontalFlip(),  # 随机水平翻转
        transforms.ToTensor(),  # 将图像转换为张量格式
    ])
    
    • 在这里,定义了对每张图像的处理,首先将其调整为256x256像素,然后进行随机翻转,最后将图像转换为PyTorch可以理解的张量格式。
  3. 定义数据集路径

    train_path = 'dataset/train'  # 训练集路径
    val_path = 'dataset/val'      # 验证集路径
    
    • 定义你的训练集和验证集的文件夹路径。
  4. 加载数据集

    train_data = datasets.ImageFolder(train_path, transform=data_transform)
    val_data = datasets.ImageFolder(val_path, transform=data_transform)
    
    • 使用ImageFolder加载训练集和验证集。
    • ImageFolder自动根据文件夹名称生成标签:cats会被标记为类别0,dogs会被标记为类别1。
  5. 创建DataLoader

    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=32, shuffle=False)
    
    • DataLoader用于将数据集分批次加载。这里定义了每批次包含32个图像。
    • shuffle=True表示在每个epoch开始时打乱数据顺序,通常在训练过程中使用。
  6. 访问数据

    for images, labels in train_loader:
        print(f'Batch of images shape: {images.shape}')  # 打印图片的batch尺寸
        print(f'Batch of labels shape: {labels.shape}')   # 打印标签的batch尺寸
        break  # 只输出一批数据,便于查看
    
    • 最后,遍历train_loader中的一批次数据,输出图像和标签的维度信息。
    • images.shape将在(batch_size, channels, height, width)的格式下输出图像的尺寸。
    • labels.shape输出的会是一个一维张量,表示每张图像的分类标签。


http://www.kler.cn/a/329493.html

相关文章:

  • SVG(Scalable Vector Graphics)全面解析
  • Java测试开发平台搭建(九)前端
  • 使用 Blazor 和 Elsa Workflows 作为引擎的工作流系统开发
  • Crewai + langchain 框架配置第三方(非原生/国产)大模型API
  • Jenkins-Pipeline简述
  • 登录校验Cookie、Session、JWT
  • 常用设计模式之单例模式、策略模式、工厂模式
  • TCP三次握手四次挥手详解
  • HTML5--裸体回顾
  • testRigor测试用例模板记录
  • 从AR眼镜到智能巡检:XR技术的演变与未来潜力
  • 华为仓颉语言入门(7):深入理解 do-while 循环及其应用
  • 利用Java easyExcel库实现高效Excel数据处理
  • mysql学习教程,从入门到精通,SQL GROUP BY 子句(31)
  • 一起了解计算机神经网络
  • 【Linux】第一个小程序——进度条实现
  • 数据分析-29-基于pandas的窗口操作和对JSON格式数据的处理
  • 解决Github打不开或速度慢的问题
  • 职业技能大赛-单元测试笔记(参数化)分享
  • OpenHarmony(鸿蒙南向)——平台驱动指南【DAC】
  • 【floor报错注入】
  • 《深度学习》自然语言处理 统计、神经语言模型 结构、推导解析
  • 【css】如何设计出具有权威性的“机构”网页
  • OpenAI 推理模型 O1 研发历程:团队访谈背后的故事
  • 高防服务器有用吗?租用价格一般多少
  • 【Linux进程间通信】Linux匿名管道详解:构建进程间通信的隐形桥梁