当前位置: 首页 > article >正文

pytorch dataloader学习

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np 

torch.manual_seed(1)
# 自定义数据集
class CustomDataset(Dataset):
    def __init__(self):
        # 创建一些示例数据(100个样本,每个样本包含10个特征)
        self.data = torch.randn(100, 10)
        self.labels =torch.from_numpy(np.arange(100))  # 二分类标签

    def __len__(self):
        # 返回数据集的大小
        return len(self.data)

    def __getitem__(self, idx):
        # 根据索引 idx 返回对应的样本和标签
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 创建数据集的实例
dataset = CustomDataset()

# 使用DataLoader加载数据
# 设置batch_size=16,shuffle=True表示打乱数据顺序
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

# 迭代DataLoader
for i in range(2):
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        print(f"Batch {batch_idx+1}")
        print(f"Inputs: {inputs.size()}")  # 显示当前batch中输入数据的维度
        print(f"Labels: {labels.size()}")  # 显示当前batch中标签的维度
        print(labels)
        # 在这里你可以对数据进行训练
        # 例如:outputs = model(inputs)

只要是shuffle=True,每次epoch结果的顺序是不一样的,如果想每一次的结果是一样的
在这里插入图片描述

如果shuffle=False

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np 

torch.manual_seed(1)
# 自定义数据集
class CustomDataset(Dataset):
    def __init__(self):
        # 创建一些示例数据(100个样本,每个样本包含10个特征)
        self.data = torch.randn(100, 10)
        self.labels =torch.from_numpy(np.arange(100))  # 二分类标签

    def __len__(self):
        # 返回数据集的大小
        return len(self.data)

    def __getitem__(self, idx):
        # 根据索引 idx 返回对应的样本和标签
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# 创建数据集的实例
dataset = CustomDataset()

# 使用DataLoader加载数据
# 设置batch_size=16,shuffle=True表示打乱数据顺序
dataloader = DataLoader(dataset, batch_size=100, shuffle=True)

# 迭代DataLoader
for i in range(2):
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        print(f"Batch {batch_idx+1}")
        print(f"Inputs: {inputs.size()}")  # 显示当前batch中输入数据的维度
        print(f"Labels: {labels.size()}")  # 显示当前batch中标签的维度
        print(labels)
        # 在这里你可以对数据进行训练
        # 例如:outputs = model(inputs)

结果如下
在这里插入图片描述


http://www.kler.cn/news/367399.html

相关文章:

  • R语言中常用功能汇总
  • <Tauri>tauri2.0框架下,基于qwik(前端)和rust(后端)结合的桌面程序体验
  • 基于SSM轻型卡车零部件销售系统的设计
  • ️ Vulnhuntr:利用大型语言模型(LLM)进行零样本漏洞发现的工具
  • Spring Boot驱动的厨艺社交平台设计与实现
  • 第11天理解指针
  • ros机器人导航以及物体、动作识别
  • SpringBoot多线程
  • MAC终端SSH连接成功但VSCODE连接失败解决方案
  • java实现的音视频格式转化器
  • Java进阶篇设计模式之一 ----- 单例模式
  • 前端学习---(6)js基础--4
  • RPA技术重塑企业自动化的未来
  • Java-梦幻图书店(图书管理系统)
  • LDR6328:助力小家电实现TYPE-C接口快充输入
  • 无人机喊话器详解!
  • 乐维网管平台(一):如何精准掌控 IP 管理
  • PHPOK 4.8.338 后台任意文件上传漏洞(CVE-2018-12941)复现
  • Spring MVC 知识点全解析
  • Ubuntu 上安装 Redmine 5.1 指南
  • vue实现语音合成功能,Android和wap端
  • word中的内容旋转90度
  • 深度学习中的迁移学习:优化训练流程与提高模型性能的策略,预训练模型、微调 (Fine-tuning)、特征提取
  • springboot056教学资源库(论文+源码)_kaic
  • unity中的组件(Component)
  • 基于卷积神经网络的花卉分类系统,resnet50,mobilenet模型【pytorch框架+python源码】