当前位置: 首页 > article >正文

pytorch Dataset类代码学习

from torch.utils.data import  Dataset
from PIL import Image
import os


class my_data(Dataset):
    def __init__(self, root_dir, label_dir): # 初始化类,根据这一个类,来创建特例实例需要调用的一个函数
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)



    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir,self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)

root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = my_data(root_dir, ants_label_dir)
bees_dataset = my_data(root_dir, bees_label_dir)

train_dataset = ants_dataset + bees_dataset

在控制台中将上述代码粘贴:查看数据集等操作:

  ...: from PIL import Image
  ...: import os
........................
  ...:     def __len__(self):
  ...:         return len(self.img_path)

创建数据集,包括路径与标签。还有蚂蚁的数据集。

root_dir = "dataset\train"
ants_label_dir = "ants"
ants_dataset = my_data(root_dir, ants_label_dir)

然而,出现如下的一些报错

OSError: [WinError 123] 文件名、目录名或卷标语法不正确。: 'dataset\train\\ants'

原因是:

root_dir = "dataset/train"

斜画线反了,不能直接用复制粘贴里面来的。

完整读取数据集里的图片代码:

root_dir = "dataset/train"
ants_label_dir = "ants"
ants_dataset = my_data(root_dir, ants_label_dir)
img, label = ants_dataset[1]
img.show()

如果读取出来的图片反复都是一张,则是因为:读取的是上一次成功读取的图片。

错误原因是在这句代码中:

img, label = ants_dataset[1]

这句中的连接是逗号,并不是.

通过上述的语句,即可实现数据集图片的读取。

两个数据集的相加:

train_dataset = ants_dataset + bees_dataset

在控制台中,使用同样的方法读取:

len(ants_dataset)
输出:Out[23]: 124
len(bees_dataset)
输出:Out[24]: 121
img,label = train_dataset[123]
img.show()
img,label = train_dataset[124]
img.show()


http://www.kler.cn/a/282336.html

相关文章:

  • HDMI之SBTM
  • [241115] Debian 12.8 发布 | Mistral AI 推出批量 API,成本降低 50%
  • 从零开始学习 sg200x 多核开发之 eth0 MAC 地址修改
  • 人工智能训练师 综合测试题库一
  • Notepad++的完美替代
  • 魔方和群论
  • 在PyCharm终端使用where命令不返回路径问题
  • 顶级域名服务器 - TLD服务器
  • RK方案有时一开机要设置GPIO口点平
  • Sentinel-1 Level 1数据处理的详细算法定义(九)
  • QT多线程遍历注册表
  • vray材质forC4D测试
  • Java相关工具/插件的安装教程汇总
  • SFF1604-ASEMI无人机专用SFF1604
  • HarmonyOS NEXT 实战开发:实现日常提醒应用
  • vue报错解决
  • python进阶篇-day01-面向对象基础
  • (154)时序收敛--->(04)时序收敛四
  • C语言关键字
  • 最大子数组(有限制)
  • 无人机和老鹰,谁飞得更快?
  • 多模态论文学习8.29
  • Postman注册使用
  • 离职赔偿一览表
  • 接口(interface)使用方法:
  • 57.给定一组不重叠的区间,实现一个算法在这些区间中插入一个新的区间(如果有必要的话进行合并)