pytorch 自定义Dataset类
torch.utils.data.Dataset
是 PyTorch 数据处理模块中的一个核心类,用于表示一个数据集。通过继承和自定义 Dataset
类,用户可以轻松管理和加载各种类型的数据,如图像、文本、时间序列等。
1. Dataset
类的作用
Dataset
提供了一种标准接口,方便用户自定义数据加载逻辑,尤其是对于大型数据集。每个自定义的数据集类需要实现两个核心方法:
__len__()
:返回数据集中样本的数量。__getitem__(index)
:根据给定的索引返回数据集中的一个样本(通常包括特征和标签)。
2. 自定义 Dataset
Dataset
是一个抽象类,因此你需要通过继承它来定义自己的数据集,并实现其中的 __len__
和 __getitem__
方法。以下是如何自定义一个简单的 Dataset
的示例。
示例代码
import torch
from torch.utils.data import Dataset
# 自定义数据集类,继承自 torch.utils.data.Dataset
class MyDataset(Dataset):
def __init__(self, data, labels):
# 初始化数据集,传入数据和标签
self.data = data
self.labels = labels
def __len__(self):
# 返回数据集中样本的数量
return len(self.data)
def __getitem__(self, idx):
# 根据索引返回一个样本和其对应的标签
sample = self.data[idx]
label = self.labels[idx]
return sample, label
# 示例数据
data = torch.randn(100, 3) # 100 个样本,每个样本有 3 个特征
labels = torch.randint(0, 2, (100,)) # 100 个样本的标签,二分类(0 或 1)
# 创建数据集实例
dataset = MyDataset(data, labels)
# 访问数据集中的第一个样本
sample, label = dataset[0]
print("Sample:", sample)
print("Label:", label)
解释:
__init__(self, data, labels)
:构造函数中,我们将数据和标签传入并保存为类的成员变量。__len__(self)
:返回数据集的样本数量。__getitem__(self, idx)
:根据索引idx
,返回数据和标签。
3. 与 DataLoader
配合使用
自定义的 Dataset
类通常与 DataLoader
配合使用。DataLoader
提供了批量数据加载、打乱顺序、并行加载等功能。
from torch.utils.data import DataLoader
# 使用 DataLoader 加载数据集
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
# 迭代 DataLoader
for batch_data, batch_labels in dataloader:
print(batch_data, batch_labels)
解释:
batch_size=4
:每次加载 4 个样本。shuffle=True
:在每个 epoch 之前将数据打乱。
4. 常见的 Dataset
子类
PyTorch 提供了一些常用的 Dataset
子类,如:
torchvision.datasets
:用于加载图像数据集(如 CIFAR、MNIST 等)。torchtext.datasets
:用于加载文本数据集(如 IMDB、WikiText 等)。torch.utils.data.TensorDataset
:将一对张量(如数据和标签)封装成一个数据集。