pytorch数据读入
在机器学习和深度学习中,数据读入是一个非常重要的步骤。PyTorch提供了非常灵活的数据读取方式,主要是通过Dataset
和DataLoader
两个类来完成的。
一、PyTorch常见的数据读取方式
PyTorch中的数据读取主要依赖于两个核心类:Dataset
和DataLoader
。
Dataset
: 定义数据集的格式和数据变换形式。DataLoader
: 将数据集按批次读入,并且可以进行多进程数据加载。
1.1 Dataset
类
Dataset
类是一个抽象类,我们通常需要继承它并实现以下三个方法:
__init__
: 初始化函数,用于传入外部参数,如数据集的路径、数据变换等。__getitem__
: 用于逐个读取样本数据,可以进行一定的数据变换,并返回训练所需的数据和标签。__len__
: 返回数据集的样本总数。
1.2 DataLoader
类
DataLoader
类用于按批次读入数据。它提供了许多参数来控制数据的读取方式,如batch_size
、num_workers
、shuffle
等。
二、构建自己的数据读取流程
为了更好地理解这些概念,我们可以通过一个具体的例子来展示如何构建自己的数据读取流程。
2.1 使用PyTorch自带的ImageFolder
类
PyTorch提供了一个方便的ImageFolder
类,用于读取按一定结构存储的图片数据。我们可以直接使用这个类来读取数据集。
import torch
from torchvision import datasets
# 假设train_path和val_path是数据集的路径
train_data = datasets.ImageFolder(train_path, transform=data_transform)
val_data = datasets.ImageFolder(val_path, transform=data_transform)
这里的data_transform
可以对图像进行一定的变换,如翻转、裁剪等操作。你可以在transform
模块中定义自己的变换。
2.2例子
假设你有一个图像数据集,用于训练一个猫狗分类模型。你的数据集结构如下:
dataset/
train/
cats/
cat1.jpg
cat2.jpg
cat3.jpg
dogs/
dog1.jpg
dog2.jpg
dog3.jpg
val/
cats/
val_cat1.jpg
val_cat2.jpg
dogs/
val_dog1.jpg
val_dog2.jpg
在这个数据集中:
train
文件夹包含训练数据,其中有两个子文件夹cats
和dogs
,分别存放猫和狗的图片。val
文件夹包含验证数据,同样有两个子文件夹cats
和dogs
。
具体代码示例
接下来我们写一段Python代码来使用ImageFolder
类加载这个数据集。
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 步骤1: 定义数据变换
data_transform = transforms.Compose([
transforms.Resize((256, 256)), # 调整图像大小为256x256
transforms.RandomHorizontalFlip(), # 进行随机水平翻转
transforms.ToTensor(), # 将图像转换为张量格式
])
# 步骤2: 定义数据集路径
train_path = 'dataset/train' # 训练集路径
val_path = 'dataset/val' # 验证集路径
# 步骤3: 使用ImageFolder加载数据集
train_data = datasets.ImageFolder(train_path, transform=data_transform)
val_data = datasets.ImageFolder(val_path, transform=data_transform)
# 步骤4: 创建DataLoader以便批量加载数据
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)
# 步骤5: 访问一个批次的数据
for images, labels in train_loader:
print(f'Batch of images shape: {images.shape}') # 打印图片的batch尺寸
print(f'Batch of labels shape: {labels.shape}') # 打印标签的batch尺寸
break # 只输出一批数据,便于查看
代码详解
-
导入必要的库
import torch from torchvision import datasets, transforms from torch.utils.data import DataLoader
- 导入PyTorch和必要的模块。
-
定义数据变换
data_transform = transforms.Compose([ transforms.Resize((256, 256)), # 将图像大小调整为256x256 transforms.RandomHorizontalFlip(), # 随机水平翻转 transforms.ToTensor(), # 将图像转换为张量格式 ])
- 在这里,定义了对每张图像的处理,首先将其调整为256x256像素,然后进行随机翻转,最后将图像转换为PyTorch可以理解的张量格式。
-
定义数据集路径
train_path = 'dataset/train' # 训练集路径 val_path = 'dataset/val' # 验证集路径
- 定义你的训练集和验证集的文件夹路径。
-
加载数据集
train_data = datasets.ImageFolder(train_path, transform=data_transform) val_data = datasets.ImageFolder(val_path, transform=data_transform)
- 使用
ImageFolder
加载训练集和验证集。 ImageFolder
自动根据文件夹名称生成标签:cats
会被标记为类别0,dogs
会被标记为类别1。
- 使用
-
创建DataLoader
train_loader = DataLoader(train_data, batch_size=32, shuffle=True) val_loader = DataLoader(val_data, batch_size=32, shuffle=False)
DataLoader
用于将数据集分批次加载。这里定义了每批次包含32个图像。shuffle=True
表示在每个epoch开始时打乱数据顺序,通常在训练过程中使用。
-
访问数据
for images, labels in train_loader: print(f'Batch of images shape: {images.shape}') # 打印图片的batch尺寸 print(f'Batch of labels shape: {labels.shape}') # 打印标签的batch尺寸 break # 只输出一批数据,便于查看
- 最后,遍历
train_loader
中的一批次数据,输出图像和标签的维度信息。 images.shape
将在(batch_size, channels, height, width)
的格式下输出图像的尺寸。labels.shape
输出的会是一个一维张量,表示每张图像的分类标签。
- 最后,遍历