PyTorch加载自己的数据集
PyTorch加载自己的数据集
- 0.引言
- 1.定义自己的Dataset类
0.引言
- 主要内容来源。
PyTorch提供了几种方法来加载自己的数据集。下面是一些常用的方法:
- 1.使用
torch.utils.data.Dataset
类创建自定义数据集
这是一种常见的方式,用于自定义数据集。创建一个类,继承自torch.utils.data.Dataset,并重写__len__()
和__getitem__()
方法。__len__()
方法应该返回数据集的大小,__getitem__()
方法应该返回一个样本。例如,以下是一个自定义数据集类的示例:
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data, targets):
self.data = data
self.targets = targets
def __len__(self):
return len(self.data)
def __getitem__(self, index):
x = self.data[index]
y = self.targets[index]
return x, y
- 2.使用
torch.utils.data.DataLoader
类加载数据集
torch.utils.data.DataLoader
类用于加载数据集。它可以自动对数据集进行批处理、打乱和多线程加载。下面是一个使用DataLoader加载数据集的示例:
from torch.utils.data import DataLoader
dataset = MyDataset(data, targets)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
- 3.使用
torchvision.datasets
模块加载常见数据集
torchvision.datasets模块提供了一些常见的数据集,例如MNIST、CIFAR等。可以使用这些数据集来测试模型或学习如何加载数据集。以下是一个使用torchvision.datasets加载MNIST数据集的示例:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
trainset = datasets.MNIST(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
shuffle=True, num_workers=2)
上面的代码将下载MNIST数据集,并使用ToTensor()和Normalize()转换图像。然后使用DataLoader加载数据集。
1.定义自己的Dataset类
创建一个类,继承自torch.utils.data.Dataset,并重写__len__()
和__getitem__()
方法:
__init__
用于向类中传入外部参数,同时定义样本集__len__()
方法应该返回数据集的大小__getitem__()
方法应该返回一个样本
例如,以下是一个自定义数据集类的示例:
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data, targets):
self.data = data
self.targets = targets
def __len__(self):
return len(self.data)
def __getitem__(self, index):
x = self.data[index]
y = self.targets[index]
return x, y
这里另外给出一个例子,其中图片存放在一个文件夹,另外有一个csv文件给出了图片名称对应的标签。这种情况下需要自己来定义Dataset类:
class MyDataset(Dataset):
def __init__(self, data_dir, info_csv, image_list, transform=None):
"""
Args:
data_dir: path to image directory.
info_csv: path to the csv file containing image indexes
with corresponding labels.
image_list: path to the txt file contains image names to training/validation set
transform: optional transform to be applied on a sample.
"""
label_info = pd.read_csv(info_csv)
image_file = open(image_list).readlines()
self.data_dir = data_dir
self.image_file = image_file
self.label_info = label_info
self.transform = transform
def __getitem__(self, index):
"""
Args:
index: the index of item
Returns:
image and its labels
"""
image_name = self.image_file[index].strip('\n')
raw_label = self.label_info.loc[self.label_info['Image_index'] == image_name]
label = raw_label.iloc[:,0]
image_name = os.path.join(self.data_dir, image_name)
image = Image.open(image_name).convert('RGB')
if self.transform is not None:
image = self.transform(image)
return image, label
def __len__(self):
return len(self.image_file)
构建好Dataset后,就可以使用DataLoader来按批次读入数据了,实现代码如下:
from torch.utils.data import DataLoader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, num_workers=4, shuffle=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, num_workers=4, shuffle=False)
其中:
-
batch_size:样本是按“批”读入的,batch_size就是每次读入的样本数
-
num_workers:有多少个进程用于读取数据,Windows下该参数设置为0,Linux下常见的为4或者8,根据自己的电脑配置来设置
-
shuffle:是否将读入的数据打乱,一般在训练集中设置为True,验证集中设置为False
-
drop_last:对于样本最后一部分没有达到批次数的样本,使其不再参与训练
这里可以看一下加载的数据。PyTorch中的DataLoader的读取可以使用next和iter来完成
import matplotlib.pyplot as plt
images, labels = next(iter(val_loader))
print(images.shape)
plt.imshow(images[0].transpose(1,2,0))
plt.show()