pytorch Dataset类代码学习
from torch.utils.data import Dataset
from PIL import Image
import os
class my_data(Dataset):
def __init__(self, root_dir, label_dir): # 初始化类,根据这一个类,来创建特例实例需要调用的一个函数
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path = os.listdir(self.path)
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir,self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = my_data(root_dir, ants_label_dir)
bees_dataset = my_data(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset
在控制台中将上述代码粘贴:查看数据集等操作:
...: from PIL import Image
...: import os
........................
...: def __len__(self):
...: return len(self.img_path)
创建数据集,包括路径与标签。还有蚂蚁的数据集。
root_dir = "dataset\train"
ants_label_dir = "ants"
ants_dataset = my_data(root_dir, ants_label_dir)
然而,出现如下的一些报错:
OSError: [WinError 123] 文件名、目录名或卷标语法不正确。: 'dataset\train\\ants'
原因是:
root_dir = "dataset/train"
斜画线反了,不能直接用复制粘贴里面来的。
完整读取数据集里的图片代码:
root_dir = "dataset/train"
ants_label_dir = "ants"
ants_dataset = my_data(root_dir, ants_label_dir)
img, label = ants_dataset[1]
img.show()
如果读取出来的图片反复都是一张,则是因为:读取的是上一次成功读取的图片。
错误原因是在这句代码中:
img, label = ants_dataset[1]
这句中的连接是逗号,并不是.
通过上述的语句,即可实现数据集图片的读取。
两个数据集的相加:
train_dataset = ants_dataset + bees_dataset
在控制台中,使用同样的方法读取:
len(ants_dataset)
输出:Out[23]: 124
len(bees_dataset)
输出:Out[24]: 121
img,label = train_dataset[123]
img.show()
img,label = train_dataset[124]
img.show()