pytorch中的transform用法
在 PyTorch 中,transform
主要用于数据预处理和数据增强,尤其在计算机视觉任务中,通过 torchvision.transforms
模块进行图像的变换。transforms
可以对图像进行一系列操作,如裁剪、旋转、缩放、归一化等,以增强数据集的多样性,并提高模型的泛化能力。
1. torchvision.transforms
模块概述
torchvision.transforms
是 PyTorch 提供的一个图像转换工具,它包含一系列的变换操作。常见的转换操作包括:
- 图像大小调整(Resize)
- 裁剪(Crop)
- 图像翻转(Flip)
- 颜色调整(Color Jitter)
- 图像归一化(Normalization)
- 转换为张量(ToTensor)
2. 常用的 transforms
操作
from torchvision import transforms
1) transforms.ToTensor()
将图像转换为 PyTorch 张量(Tensor),并且自动将图像的像素值缩放到 [0, 1] 的范围内。
transform = transforms.ToTensor()
image_tensor = transform(image)
2) transforms.Resize()
调整图像的大小,可以指定一个单一的大小或宽度/高度。
transform = transforms.Resize((224, 224)) # 调整为 224x224 的尺寸
image_resized = transform(image)
3) transforms.CenterCrop()
和 transforms.RandomCrop()
CenterCrop
会从图像的中心裁剪出指定大小的区域;RandomCrop
会随机裁剪出一个指定大小的区域。
transform = transforms.CenterCrop(224) # 从中心裁剪出 224x224 的区域
image_cropped = transform(image)
# 或者使用随机裁剪
transform = transforms.RandomCrop(224)
image_random_cropped = transform(image)
4) transforms.RandomHorizontalFlip()
和 transforms.RandomVerticalFlip()
进行水平或垂直的随机翻转。
transform = transforms.RandomHorizontalFlip(p=0.5) # 50% 的概率进行水平翻转
image_flipped = transform(image)
5) transforms.Normalize()
对图像的每个通道进行归一化。通常用来调整图像的颜色通道,使其符合模型训练时的要求。
transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
image_normalized = transform(image_tensor) # 对每个通道进行归一化
6) transforms.ColorJitter()
随机调整图像的亮度、对比度、饱和度和色相。适用于增强数据集的多样性。
transform = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
image_jittered = transform(image)
7) transforms.RandomRotation()
对图像进行随机旋转。
transform = transforms.RandomRotation(30) # 随机旋转 -30 到 30 度之间
image_rotated = transform(image)
3. 多种 transforms
组合使用
通常,我们会将多个变换操作组合成一个 Compose
,使得一个图像依次经过多个变换步骤。
transform = transforms.Compose([
transforms.Resize(256),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image_transformed = transform(image)
上面的代码会将图像:
- 调整为 256x256
- 随机裁剪为 224x224
- 进行水平翻转
- 转换为张量
- 归一化图像
4. 结合 Dataset
使用 transforms
通常,我们会将 transforms
与 torch.utils.data.Dataset
和 torch.utils.data.DataLoader
结合使用,用于训练过程中的数据预处理。
from torchvision import datasets
from torch.utils.data import DataLoader
transform = transforms.Compose([
transforms.Resize(256),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_dataset = datasets.ImageFolder(root='path_to_train_data', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
在上面的代码中,ImageFolder
是一个 PyTorch 提供的通用图像数据集类,用于加载目录结构为类标签的图像数据。transform
用于对数据集中的每个图像进行预处理。
5. 自定义 transform
类
如果 torchvision.transforms
中的预定义操作不能满足需求,我们还可以自定义一个转换类。例如,如果你想为每张图片添加噪声:
from PIL import Image
import numpy as np
class AddGaussianNoise(object):
def __init__(self, mean=0., std=1.):
self.mean = mean
self.std = std
def __call__(self, image):
image = np.array(image)
noise = np.random.normal(self.mean, self.std, image.shape)
noisy_image = image + noise
noisy_image = np.clip(noisy_image, 0, 255) # 保证像素值在 [0, 255] 范围内
return Image.fromarray(noisy_image.astype(np.uint8))
# 使用自定义转换
transform = transforms.Compose([
transforms.Resize(256),
transforms.RandomCrop(224),
AddGaussianNoise(mean=0, std=0.1), # 添加高斯噪声
transforms.ToTensor(),
])
image = Image.open('path_to_image.jpg')
transformed_image = transform(image)
总结
transforms
是 PyTorch 中处理图像数据的一组强大工具,适用于图像预处理和数据增强。- 通过
transforms.Compose()
可以组合多个转换操作。 ToTensor()
、Resize()
、RandomCrop()
、Normalize()
等是常用的转换。- 通过
DataLoader
可以高效地加载批量数据,并在训练过程中对每个样本应用转换。