机器学习复习(9)——自定义dataset
目录
第一种dataset(文件夹名即为标签)
用于将格式(1)转换为格式(2)
第二种dataset(标签在labels文件夹下的对应的txt文件里面)
第一种dataset(文件夹名即为标签)
数据组织格式(1)
--data
----train
------class1(文件夹名字即为标签)
--------image1.jpg
------class2
dataset
from torch.utils.data import Dataset
from PIL import Image
class Mydata(Dataset):
def __init__(self,root_dir,label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir,self.label_dir)
self.img_path= os.listdir(self.path)
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.path,img_name)
img = Image.open(img_item_path)
label =self.label_dir
return img,label
def __len__(self):
return len(self.img_path)
root_dir="../PATH/TO/train"
class1_label_dir="class1"
class2_label_dir="class2"
class1_data =Mydata(root_dir,class1_label_dir)
class2_data =Mydata(root_dir,class2_label_dir)
train_dataset= class1_data+class2_data
用于将格式(1)转换为格式(2)
数据集格式转换
import os
root_dir = 'root_path'
target_dir = 'target_image'
img_path = os.listdir(os.path.join(root_dir, target_dir))
label = target_dir.split('_')[0]
out_dir = 'ants_label'
for i in img_path:
file_name = i.split('.jpg')[0]
with open(os.path.join(root_dir, out_dir,"{}.txt".format(file_name)),'w') as f:
f.write(label)
第二种dataset(标签在labels文件夹下的对应的txt文件里面)
数据组织格式(2)
--data
----train
------images
--------01.jpg
------labels
--------01.txt (txt里面的内容是label内容:目标检测,分类等)
# 导入PyTorch的数据集工具和其他必要的库
from torch.utils.data import Dataset
import os
from PIL import Image
# 自定义的数据集类,继承自torch.utils.data.Dataset
class Mydate(Dataset):
def __init__(self, dir_root, dir_image, dir_label):
# 初始化函数,设置数据集的根目录、图像目录和标签目录
self.root = dir_root # 数据集的根目录
self.image_dir = dir_image # 存放图像的子目录
self.image_path = os.path.join(self.root, self.image_dir) # 图像的完整路径
self.label_dir = dir_label # 存放标签的子目录
self.label_path = os.path.join(self.root, self.label_dir) # 标签的完整路径
# 获取图像文件列表和标签文件列表
self.image_list = os.listdir(os.path.join(dir_root, dir_image)) # 根据图像目录列出所有图像文件
self.label_list = os.listdir(os.path.join(dir_root, dir_label)) # 根据标签目录列出所有标签文件
def __getitem__(self, idx):
# 通过索引获取数据集中的单个样本,包括图像和标签
image_name = self.image_list[idx] # 根据索引获取图像文件名
image_path = os.path.join(self.image_path, image_name) # 构造图像的完整路径
label_name = self.label_list[idx] # 根据索引获取标签文件名
label_path = os.path.join(self.label_path, label_name) # 构造标签的完整路径
img = Image.open(image_path).convert('RGB') # 打开图像文件并转换为RGB格式
# 读取标签文件
with open(label_path, 'r') as f:
label = f.read().strip() # 读取标签内容,并去除可能的空白字符
return img, label
def __len__(self):
# 返回数据集中样本的总数
return len(self.image_list) # 由于图像列表的长度代表了数据集大小,直接返回其长度
##############测试代码################
if __name__ == "__main__":
# 指定数据集的根目录、图像目录和标签目录
dir_root = "root\\path\\train"
dir_image = "images"
dir_label = "labels"
test_data = Mydate(dir_root, dir_image, dir_label) # 创建数据集实例
idx = 0 # 指定要获取的样本索引
img, label = test_data[idx] # 获取指定索引的样本
print(label) # 打印样本的标签