当前位置: 首页 > article >正文

机器学习复习(9)——自定义dataset

目录

第一种dataset(文件夹名即为标签)

用于将格式(1)转换为格式(2)

第二种dataset(标签在labels文件夹下的对应的txt文件里面)


第一种dataset(文件夹名即为标签)

数据组织格式(1)

--data

----train

------class1(文件夹名字即为标签)

--------image1.jpg

------class2

dataset

from torch.utils.data import Dataset
from  PIL import Image
class Mydata(Dataset):
    def __init__(self,root_dir,label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir,self.label_dir)
        
        self.img_path= os.listdir(self.path)
        
        
    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        
        img_item_path = os.path.join(self.path,img_name)
        
        img = Image.open(img_item_path)
        
        label =self.label_dir
        return img,label
        
    def __len__(self):
        return len(self.img_path)

root_dir="../PATH/TO/train"
class1_label_dir="class1"
class2_label_dir="class2"

class1_data =Mydata(root_dir,class1_label_dir)
class2_data =Mydata(root_dir,class2_label_dir)    

train_dataset= class1_data+class2_data

用于将格式(1)转换为格式(2)

数据集格式转换

import os
root_dir = 'root_path'
target_dir = 'target_image'
img_path = os.listdir(os.path.join(root_dir, target_dir))
label = target_dir.split('_')[0]
out_dir = 'ants_label'
for i in img_path:
    file_name = i.split('.jpg')[0]
    with open(os.path.join(root_dir, out_dir,"{}.txt".format(file_name)),'w') as f:
        f.write(label)

第二种dataset(标签在labels文件夹下的对应的txt文件里面)

数据组织格式(2)

--data

----train

------images

--------01.jpg

------labels

--------01.txt        (txt里面的内容是label内容:目标检测,分类等)

# 导入PyTorch的数据集工具和其他必要的库
from torch.utils.data import Dataset
import os
from PIL import Image

# 自定义的数据集类,继承自torch.utils.data.Dataset
class Mydate(Dataset):
    def __init__(self, dir_root, dir_image, dir_label):
        # 初始化函数,设置数据集的根目录、图像目录和标签目录
        self.root = dir_root  # 数据集的根目录

        self.image_dir = dir_image  # 存放图像的子目录
        self.image_path = os.path.join(self.root, self.image_dir)  # 图像的完整路径

        self.label_dir = dir_label  # 存放标签的子目录
        self.label_path = os.path.join(self.root, self.label_dir)  # 标签的完整路径

        # 获取图像文件列表和标签文件列表
        self.image_list = os.listdir(os.path.join(dir_root, dir_image))  # 根据图像目录列出所有图像文件
        self.label_list = os.listdir(os.path.join(dir_root, dir_label))  # 根据标签目录列出所有标签文件

    def __getitem__(self, idx):
        # 通过索引获取数据集中的单个样本,包括图像和标签
        image_name = self.image_list[idx]  # 根据索引获取图像文件名
        image_path = os.path.join(self.image_path, image_name)  # 构造图像的完整路径

        label_name = self.label_list[idx]  # 根据索引获取标签文件名
        label_path = os.path.join(self.label_path, label_name)  # 构造标签的完整路径

        img = Image.open(image_path).convert('RGB')  # 打开图像文件并转换为RGB格式
        # 读取标签文件
        with open(label_path, 'r') as f:
            label = f.read().strip()  # 读取标签内容,并去除可能的空白字符
        return img, label

    def __len__(self):
        # 返回数据集中样本的总数
        return len(self.image_list)  # 由于图像列表的长度代表了数据集大小,直接返回其长度

##############测试代码################
if __name__ == "__main__":
    # 指定数据集的根目录、图像目录和标签目录
    dir_root = "root\\path\\train"
    dir_image = "images"
    dir_label = "labels"
    test_data = Mydate(dir_root, dir_image, dir_label)  # 创建数据集实例
    idx = 0  # 指定要获取的样本索引
    img, label = test_data[idx]  # 获取指定索引的样本
    print(label)  # 打印样本的标签


http://www.kler.cn/a/273633.html

相关文章:

  • H2数据库在单元测试中的应用
  • 左值引用(Lvalue Reference)和右值引用(Rvalue Reference)详解
  • Dart语言的语法糖
  • Personal APP
  • React Native 项目 Error: EMFILE: too many open files, watch
  • STM32-笔记37-吸烟室管控系统项目
  • Linux 文件系统:文件描述符、管理文件
  • vue3.x 使用jsplumb进行多列拖拽连线
  • C++ cin标准输入流,及获取多个输入的方法
  • Springboot整合支付宝沙箱支付
  • 移动云COCA架构实现算力跃升,探索人工智能新未来
  • 【C语言】空心正方形图案
  • 【开发】SpringBoot 整合 Redis
  • 自然辩证法
  • bootstrap表格API文档
  • 【Linux】用三种广义进程状态 来理解Linux的进程状态(12)
  • GPT-SoVITS语音合成服务器部署,可远程访问(全部代码和详细部署步骤)
  • 海康、新华三、银江股份、大华等知名企业集结亮相“杭州安防展”
  • 杂记8---多线激光雷达与相机外参标定
  • java项目打包(maven+原生)
  • LeetCode108 将有序数组转换为二叉搜索树
  • 云原生(四)、Docker-Compose
  • js复制内容到剪贴板实现复制粘贴功能
  • git tag标签使用
  • 从底层结构开始学习FPGA(0)----FPGA的硬件架构层次(BEL Site Tile FSR SLR Device)
  • MySQL 锁机制