PyTorch 基础数据集:从理论到实践的深度学习基石
一、引言
深度学习作为当今人工智能领域的核心技术,在图像识别、自然语言处理、语音识别等众多领域取得了令人瞩目的成果。而在深度学习的体系中,数据扮演着举足轻重的角色,它是模型训练的基础,如同建筑的基石,决定了模型的性能和泛化能力。PyTorch 作为当下最流行的深度学习框架之一,为开发者提供了丰富且强大的工具来处理数据集。本文将深入探讨 PyTorch 中的基础数据集,从深度学习中数据的重要性出发,详细介绍常见的数据集,深入解析数据集的读取与加载机制,数据集是完成完成训练和测试的前提,一个良好的数据集选择对一个实验的结果有着举足轻重的作用。
二、数据:深度学习的核心要素
2.1 深度学习的关键组成要素
深度学习模型主要由三个关键部分组成:数据、模型架构和训练算法。数据为模型的训练提供了原始信息,模型架构定义了如何对数据进行处理和特征提取,训练算法则负责调整模型的参数,使得模型能够更好地拟合数据,从而实现对未知数据的准确预测。这三个要素相互依存,缺一不可,但数据作为起点,其质量和规模直接影响着模型的最终表现。
2.2 数据对深度学习模型训练的作用
高质量、大规模的数据能够让模型学习到更丰富、更准确的特征表示。在图像识别任务中,大量的图像数据可以让模型学习到不同物体的各种特征,包括形状、颜色、纹理等,从而提高模型对不同场景下物体识别的准确性。数据还可以帮助模型避免过拟合。如果训练数据量过少,模型可能会过度学习训练数据中的噪声和细节,在训练数据集上表现优异,但在测试数据上表现不佳。通过增加数据量或采用数据增强技术,可以增加数据的多样性,使模型学习到更具泛化性的特征。
2.3 数据来源与众包
数据的来源多种多样。常见的来源包括公开数据集、自行采集的数据以及众包数据。公开数据集如 Pascal VOC、COCO 等,由研究机构或企业整理并公开,这些数据集通常具有良好的标注和规范的格式,方便研究者使用。自行采集数据可以根据特定的任务需求,获取更符合实际应用场景的数据,但需要耗费大量的时间和人力成本。众包数据则是通过众包平台,让大量的用户参与数据的标注和采集工作,这种方式可以快速获取大规模的数据,但需要注意数据的质量控制。
三、常见的深度学习数据集
3.1 Pascal VOC
Pascal VOC(Visual Object Classes)是一个在计算机视觉领域广泛使用的数据集,主要用于目标检测、图像分割等任务。该数据集包含了多个类别(如人、车、动物等)的图像,并提供了精确的标注信息,包括物体的类别、位置框等。Pascal VOC 数据集每年都会举办挑战赛,吸引了众多研究者参与,推动了目标检测等领域的技术发展。它的标注格式规范,使得研究者可以方便地将其应用于自己的模型训练中。在目标检测算法的评估中,Pascal VOC 数据集是一个重要的基准数据集,许多新算法都会在该数据集上进行测试和比较。
3.2 COCO
COCO(Common Objects in Context)数据集是另一个极具影响力的数据集,相比于 Pascal VOC,它具有更大的规模和更丰富的标注信息。COCO 数据集包含了日常生活场景中的各种物体,图像数量众多,标注不仅包括物体的边界框,还包括实例分割、关键点检测等信息。这使得 COCO 数据集在多个计算机视觉任务中都有广泛的应用,如目标检测、实例分割、全景分割等。许多先进的模型都是在 COCO 数据集上进行训练和优化的,其复杂的场景和多样化的标注为模型的训练提供了更具挑战性的环境,有助于提升模型的泛化能力。
3.3 Mnist
Mnist(Mixed National Institute of Standards and Technology database)数据集是一个经典的手写数字图像数据集,由美国国家标准与技术研究所整理。该数据集包含了 6 万张训练图像和 1 万张测试图像,图像中的数字范围为 0 - 9。Mnist 数据集的图像尺寸较小,且格式统一,非常适合作为深度学习入门的数据集。许多初学者在学习神经网络时,都会使用 Mnist 数据集进行模型的训练和测试,通过在该数据集上的实践,掌握神经网络的基本原理和训练方法。由于 Mnist 数据集相对简单,模型在该数据集上通常能够取得较高的准确率,这也增强了研究者对深度学习模型的信心。
3.4 Fashion - Mnist
Fashion - Mnist 是一个与 Mnist 类似的数据集,但它的图像内容为时尚物品,如衣服、鞋子等。该数据集同样包含 6 万张训练图像和 1 万张测试图像,其目的是为了替代 Mnist 数据集,提供一个更具挑战性的图像分类任务。相比于 Mnist 数据集,Fashion - Mnist 数据集中的物品类别之间的区分度相对较小,这对模型的特征提取能力提出了更高的要求。许多研究者在研究图像分类算法时,会同时在 Mnist 和 Fashion - Mnist 数据集上进行测试,以评估模型的性能和泛化能力。
3.5 CIFAR
CIFAR 数据集分为 CIFAR - 10 和 CIFAR - 100 两个版本,CIFAR - 10 包含 10 个类别,每个类别有 6000 张图像,共计 6 万张图像;CIFAR - 100 包含 100 个类别,每个类别有 600 张图像。这些图像涵盖了不同的物体和场景,如飞机、汽车、鸟类等。CIFAR 数据集的图像尺寸较小,但相比于 Mnist 和 Fashion - Mnist,其图像内容更加复杂,背景干扰更多,这使得在 CIFAR 数据集上训练模型更具挑战性。许多先进的卷积神经网络模型都会在 CIFAR 数据集上进行训练和评估,以验证模型在复杂图像分类任务上的性能。
3.6 ImageNet
ImageNet 是目前最大的图像数据库之一,它包含了超过 1400 万张图像,涵盖了 2 万多个类别。ImageNet 每年都会举办大规模的图像识别挑战赛(ILSVRC),吸引了全球众多研究机构和企业参与。许多著名的深度学习模型,如 AlexNet、VGGNet、ResNet 等,都是在 ImageNet 数据集上进行训练并取得了突破性的成果。ImageNet 数据集的规模和多样性为深度学习模型的训练提供了丰富的信息,推动了计算机视觉领域的快速发展。在图像识别领域,ImageNet 数据集已成为评估模型性能的重要基准。
3.7 Cityscapes
Cityscapes 数据集专注于城市街景图像,主要用于语义分割任务。该数据集包含了 50 个不同城市的街景图像,提供了精细的像素级标注,包括道路、建筑物、行人、车辆等多个类别。Cityscapes 数据集的图像质量高,标注准确,为城市场景理解和自动驾驶等领域的研究提供了重要的数据支持。许多语义分割算法都是在 Cityscapes 数据集上进行训练和优化的,通过在该数据集上的实验,研究者可以评估模型对复杂城市场景中不同物体的分割能力。
3.8 FakeData
FakeData 并不是真实世界中的数据集,它主要用于测试和调试深度学习模型。在开发模型的过程中,使用 FakeData 可以快速验证模型的结构和训练流程是否正确,避免在真实数据集上进行长时间的训练。FakeData 可以根据需要生成不同形状、大小和数据类型的数据,方便开发者对模型进行各种测试。例如,在测试模型的并行计算性能时,可以使用 FakeData 生成大规模的数据,模拟真实场景下的数据处理情况。
四、PyTorch 中的数据集读取与加载
4.1 torchvision.datasets 包
PyTorch 提供了 torchvision.datasets 包,torchvision含有内部的数据集,仅仅引用即可,无需手动下载,它包含了许多常见的数据集类,如前面介绍的 Mnist、Fashion - Mnist、CIFAR 等。通过这个包,我们可以方便地下载和加载这些数据集。以 Mnist 数据集为例,以下是使用 torchvision.datasets 包加载 Mnist 数据集的代码:
import torchvision
from torchvision import datasets, transforms
# 定义数据预处理转换
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 加载训练数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# 加载测试数据集
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
在上述代码中,首先定义了数据预处理转换,将图像转换为张量并进行归一化处理。然后,通过datasets.MNIST
类分别加载训练数据集和测试数据集,root
参数指定了数据集的存储路径,train=True
表示加载训练集,train=False
表示加载测试集,download=True
表示如果数据集不存在,则自动下载。transform
参数则应用了之前定义的数据预处理转换。
4.2 torch.utils.data.Dataset 的子集
torch.utils.data.Dataset
是 PyTorch 中所有数据集的基类,它定义了数据集的基本操作。当我们需要加载自定义数据集时,通常需要继承torch.utils.data.Dataset
类,并实现__len__
和__getitem__
方法。__len__
方法返回数据集的大小,__getitem__
方法根据索引返回数据集中的一个样本。以下是一个简单的自定义数据集类的示例:
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
label = self.labels[idx]
return sample, label
在上述代码中,CustomDataset
类接受数据和标签作为输入,并在__init__
方法中进行初始化。__len__
方法返回数据的长度,__getitem__
方法根据索引返回对应的样本和标签。在实际应用中,我们可以根据自定义数据集的特点,对__init__
、__len__
和__getitem__
方法进行更复杂的实现,例如读取图像文件、进行数据增强等。
4.3 torch.utils.data.DataLoader 加载数据集
torch.utils.data.DataLoader
用于将数据集包装成可迭代的对象,以便在训练模型时按批次加载数据。它提供了许多参数来控制数据的加载方式,如batch_size
指定每个批次的样本数量,shuffle
表示是否在每个 epoch 对数据进行打乱,num_workers
指定用于加载数据的子进程数量等。以下是使用DataLoader
加载 Mnist 数据集的代码:
from torch.utils.data import DataLoader
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)
在上述代码中,通过DataLoader
分别创建了训练数据加载器和测试数据加载器。batch_size=64
表示每个批次包含 64 个样本,shuffle=True
表示在训练时对数据进行打乱,以增加数据的随机性,提高模型的泛化能力;shuffle=False
表示在测试时不打乱数据,以便进行准确的评估。num_workers=4
表示使用 4 个子进程来加载数据,加快数据加载速度。在训练模型时,我们可以通过遍历数据加载器来获取每个批次的数据:
for epoch in range(num_epochs):
for batch_idx, (data, target) in enumerate(train_loader):
# 进行模型训练
output = model(data)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
在上述代码中,外层循环表示训练的 epoch 数,内层循环通过遍历train_loader
获取每个批次的数据。在每个批次中,将数据输入模型进行前向传播,计算损失,然后进行反向传播和参数更新。
4.4 DataLoader 参数详解
DataLoader
的参数众多,下面对一些重要参数进行详细解释:
batch_size
:指定每个批次的样本数量。较大的batch_size
可以利用更多的计算资源,加快训练速度,但可能会导致内存占用过大;较小的batch_size
可以使模型在训练时更加关注每个样本,有助于提高模型的泛化能力,但会增加训练时间。shuffle
:表示是否在每个 epoch 对数据进行打乱。在训练模型时,打乱数据可以避免模型学习到数据的顺序信息,增加数据的随机性,提高模型的泛化能力。在测试时,通常不需要打乱数据,以便进行准确的评估。sampler
:用于指定样本的采样方式。默认情况下,使用随机采样(RandomSampler
),即每个 epoch 对数据进行打乱。如果需要自定义采样方式,可以实现Sampler
类,并将其作为sampler
参数传入DataLoader
。worker_init_fn
:是一个可选的函数,用于在每个数据加载子进程启动时进行初始化。可以在这个函数中设置随机种子等操作,以确保每个子进程的行为一致。prefetch_factor
:指定每个数据加载子进程预先获取的样本数量。较大的值可以提高数据加载的效率,但会占用更多的内存。batch_sampler
:用于指定批次的采样方式。如果设置了batch_sampler
,则batch_size
、shuffle
、sampler
和drop_last
参数将被忽略。num_workers
:指定用于加载数据的子进程数量。增加num_workers
可以加快数据加载速度,但会占用更多的系统资源。在实际应用中,需要根据系统的硬件资源和数据加载的复杂度来合理设置num_workers
的值。collate_fn
:是一个用于将样本整理成批次的函数。默认情况下,使用default_collate
函数将样本整理成张量形式的批次。如果样本的结构比较复杂,例如包含不同长度的序列等,可以自定义collate_fn
函数来实现合适的整理方式。pin_memory
:表示是否将数据加载到 pinned memory 中。将数据加载到 pinned memory 中可以加快数据从 CPU 到 GPU 的传输速度,但会占用更多的内存。在使用 GPU 进行训练时,通常可以将pin_memory
设置为True
。drop_last
:表示当数据集的样本数量不能被batch_size
整除时,是否丢弃最后一个不完整的批次。如果设置为True
,则丢弃最后一个不完整的批次;如果设置为False
,则最后一个批次的样本数量可能小于batch_size
。timeout
:指定数据加载的超时时间。如果在指定时间内无法获取数据,则会抛出异常。persistent_workers
:表示是否在数据加载完成后保持子进程的存活。如果设置为True
,则在训练过程中,数据加载子进程将一直存活,避免了每次 epoch 重新启动子进程的开销,但会占用更多的系统资源。
五、代码演示与解释
5.1 数据集加载演示与显示
下面通过一个完整的代码示例,演示如何加载 CIFAR - 10 数据集,并显示其中的一些图像:
import torch
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# 定义数据预处理转换
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 加载训练数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# 加载测试数据集
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)
# 显示一些训练图像
def show_images(images, labels, classes):
fig, axes = plt.subplots(1, 5, figsize=(15, 3))
for i in range(5):
axes[i].imshow(images[i].permute(1, 2, 0))
axes[i].set_title(classes[labels[i]])
axes[i].axis('off')
plt.show()
# 获取一个批次的训练数据
dataiter = iter(train_loader)
images, labels = dataiter.next()
# 定义类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse','ship', 'truck')
# 显示图像
show_images(images, labels, classes)
在上述代码中,首先定义了数据预处理转换,将图像转换为张量并进行归一化处理。然后,使用datasets.CIFAR10
类分别加载训练数据集和测试数据集,并创建了数据加载器。接着,定义了show_images
函数用于显示图像,通过获取一个批次的训练数据,并调用show_images
函数,显示了 5 张训练图像及其对应的类别标签。运行上述代码,可以看到 CIFAR - 10 数据集中的一些图像及其类别标签,直观地了解数据集的内容。
5.2 自定义数据集加载演示
下面演示如何加载自定义的图像数据集。假设我们有一个自定义的图像数据集,其目录结构基本如下:
custom_dataset/
train/
class1/
image1.jpg
image2.jpg
...
class2/
image1.jpg
image2.jpg
...
test/
class1/
image1.jpg
image2.jp
六、小结
我总结了一些常见的数据集,并给出了介绍,包括数据集的类型、规模等,不同的数据集的使用,对实验产生了不同的结果影响,比如我之前复现的一篇论文《CameraCtrl》,是基于Realestate数据集,虽然是视频数据集,但也是与其他数据集相对比之后才采用的,此数据集包含复杂的相机轨迹,对于训练复杂模型更加有益,模型经过训练之后,学习到的结果更好,因此数据集的选择对做实验十分重要,需要对比。感谢大家的观看O(∩_∩)O。