深度学习基础(2024-11-02更新到图像尺寸变换 与 裁剪)
1. 名词解释
FFN
- FFN : Feedforward Neural Network,前馈神经网络
- 馈神经网络是一种基本的神经网络架构,也称为多层感知器(Multilayer Perceptron,MLP)
- FFN 一般主要是包括多个全连接层(FC)的网络,其中,全连接层间可以包含 : 激活层、BN层、Dropout 层。
MLP 与 FFN 的区别
在机器学习和深度学习中,MLP(多层感知机)和 FFN(前馈神经网络)在很大程度上可以视为同义词,都指代了一个具有多个层的前馈神经网络结构。
- MLP(多层感知机)更偏向于表达网络结构(多个全连接层)
- FFN(前馈神经网络)更偏向于表达数据以前馈的方式流动
MLP 和 FFN 通常指的是只包含全连接层 和激活函数的神经网络结构。这两者都是基本的前馈神经网络类型,没有包含卷积层或其他复杂的结构。
Logit
“Logit” 通常指的是神经网络中最后一个隐藏层的输出,经过激活函数之前的值。比如:
- 对于二分类问题,logit 是指网络输出的未经过 sigmoid 函数处理的值
- 对于多分类问题,logit 是指网络输出的未经过 softmax 函数处理的值
NLL
NLL 是 Negative Log-Likelihood(负对数似然)的缩写。
在深度学习中,特别是在分类问题中,NLL 经常与交叉熵损失(Cross-Entropy Loss)等价使用。
Anchor Box 与 Anchor Point
- Anchor box 通常表示 一个包含位置和大小信息的四元组 ( x , y , w , h ) (x, y, w, h) (x,y,w,h),而 Anchor point 通常表示 一个二元组 ( x , y ) (x, y) (x,y)。 其中, x x x和 y y y表示框的中心坐标, w w w 和 h h h表示框的宽度和高度。
- Anchor box 是目标检测中用于定义目标位置和大小的一种方式。而 Anchor point 主要用于在图像上生成 anchor box 的位置,生成的 anchor box 会在 anchor point 的周围不同尺寸和宽高比的情况下进行缩放,形成一系列不同形状的框。
parameter efficient
参数效率高,指的是网络在达到良好性能的同时所使用的参数数量较少。
Deep Supervision
Deep Supervision 是一种训练策略,旨在提高网络的梯度流动,并促使网络更快地收敛,并且有助于缓解梯度消失问题。Deep Supervision 的核心思想是在网络的不同层中引入额外的监督信号,而不仅仅在最后一层输出进行监督训练。具体来说:Deep Supervision 会使用网络的中间层输出,计算出一部分损失函数,然后和网络最后一层的损失函数一起,对网络的参数进行优化。
DP 与 DDP
DP : DataParallel,数据并行
DDP :Distributed Data Parallel,分布式数据并行
感受野(Receptive Field)
1、介绍
感受野(receptive field)是卷积神经网络输出特征图上的像素点在原始图像上所能看到的(映射的)区域的大小,它决定了该像素对输入图像的感知范围(获取信息的范围)。较小的感受野可以捕捉到更细节的特征,而较大的感受野可以捕捉到更全局的特征。
如果连续进行 2次卷积操作,卷积核大小都为 3x3,stride=1, padding=0, 如下图,layer3上的每一个像素点在 layer1上的感受野 为 5x5
2、感受野计算公式
感受野计公式 :
F
(
i
)
=
(
F
(
i
+
1
)
−
1
)
×
S
t
r
i
d
e
+
K
s
i
z
e
F(i)=(F(i+1)-1)\times Stride + Ksize
F(i)=(F(i+1)−1)×Stride+Ksize 或
F
i
n
=
(
F
o
u
t
−
1
)
×
S
t
r
i
d
e
+
K
s
i
z
e
F_{in}=(F_{out}-1)\times Stride + Ksize
Fin=(Fout−1)×Stride+Ksize
其中:
- F ( i ) F(i) F(i) :在第 i i i层的感受野
- S t r i d e Stride Stride:第 i i i层步距
- K s i z e Ksize Ksize:第 i i i层卷积或池化的 kernel size
3、计算举例
求 :layer3 上的每个像素在 layer1 上的感受野。
1)先来计算 layer3 上的一个像素(
F
(
3
)
=
1
F(3)=1
F(3)=1)在 layer2 上的感受野 :
F
(
2
)
=
(
F
(
3
)
−
1
)
×
S
t
r
i
d
e
+
K
s
i
z
e
=
(
1
−
1
)
×
2
+
2
=
2
F(2) = (F(3)-1) \times Stride + Ksize = (1 -1) \times 2 + 2 = 2
F(2)=(F(3)−1)×Stride+Ksize=(1−1)×2+2=2
2)计算 layer3 上的一个像素(
F
(
3
)
=
1
,
F
(
2
)
=
2
F(3)=1, \; F(2)=2
F(3)=1,F(2)=2 )在 layer1 上的感受野 :
F
(
1
)
=
(
F
(
2
)
−
1
)
×
S
t
r
i
d
e
+
K
s
i
z
e
=
(
2
−
1
)
×
2
+
3
=
5
F(1)=(F(2)-1)\times Stride + Ksize =(2 -1)\times 2 + 3 = 5
F(1)=(F(2)−1)×Stride+Ksize=(2−1)×2+3=5
如果仅计算 layer2 上的一个像素( F(2)=1 )在 layer1 上的感受野 :
F
(
1
)
=
(
F
(
2
)
−
1
)
×
S
t
r
i
d
e
+
K
s
i
z
e
=
(
1
−
1
)
×
2
+
3
=
3
F(1)=(F(2)-1)\times Stride + Ksize = (1 -1)\times 2 + 3 = 3
F(1)=(F(2)−1)×Stride+Ksize=(1−1)×2+3=3
2. tensor 相关
tensor 内部存储结构
1、数据区域和元数据
PyTorch 中的 tensor 内部结构通常包含了 数据区域(Storage) 和 元数据(Metadata) :
- 数据区域 : 存储了 tensor 的实际数据,且数据被保存为连续的数组。比如: a = torch.tensor([[1, 2, 3], [4, 5, 6]]),它的数据在存储区的保存形式为 [1, 2, 3, 4, 5, 6]
- 元数据 :包含了 tensor 的一些描述性信息,比如 : 尺寸(Size)、步长(Stride)、数据类型(Data Type) 等信息
占用内存的主要是 数据区域,且取决于 tensor 中元素的个数, 而元数据占用内存较少。
采用这种 【数据区域 + 元数据】 的数据存储方式,主要是因为深度学习的数据动辄成千上万,数据量巨大,所以采取这样的存储方式以节省内存
2、查看 tensor 的存储区数据: storage()
虽然 .storage()
方法即将被弃用,而是改用 .untyped_storage()
,但为了笔记中展示方便,我们仍然使用 .storage()
方法。.untyped_storage()
方法的输出太长了,不方便截图放在笔记中。
a = torch.tensor([[1, 2, 3],
[4, 5, 6]])
print(a.storage())
3、查看 tensor 的步长: stride()
stride()
: 在指定维度 (dim) 上,存储区中的数据元素,从一个元素跳到下一个元素所必须的步长
a = torch.randn(3, 2)
print(a.stride()) # (2, 1)
解读:
在第 0 维,想要从一个元素跳到下一个元素,比如从 a[0][0] 到 a[1][0] ,需要经过 2个元素,步长是 2
在第 1 维,想要从一个元素跳到下一个元素,比如从 a[0][0] 到 a[0][1], 需要经过 1个元素,步长是 1
4、查看 tensor 的偏移量:storage_offset()
表示 tensor 的第 0 个元素与真实存储区的第 0 个元素的偏移量
a = torch.tensor([1, 2, 3, 4, 5])
b = a[1:] # tensor([2, 3, 4, 5])
c = a[3:] # tensor([4, 5])
print(b.storage_offset()) # 1
print(c.storage_offset()) # 3
- b 的第 0 个元素与 a 的第 0 个元素之间的偏移量是 1
- c 的第 0 个元素与 a 的第 0 个元素之间的偏移量是 3
5、代码举例
-
一般来说,一个 tensor 有着与之对应的 storage, storage 是在 data 之上封装的接口。
-
不同 tensor 的元数据一般不同,但却可能使用相同的 storage。
-
data_ptr()
:- 返回的是张量数据 (storage 数据)存储的实际内存地址,确切来说是张量数据的起始内存地址。
- data_ptr 中的 ptr 是 pointer(指针)的缩写,对应于 C 语言中的指针,因为 Python 的底层就是由 C 实现的
-
id(a)
:- 返回的是 a 在 Python 内存管理系统中的唯一标识符。虽然这个标识符通常与对象的内存地址有关,但它并不直接表示内存地址。
1)观察一
import torch
a = torch.arange(0, 6)
print('a = {}\n'.format(a))
print('tensor a 存储区的数据内容 :{}\n'.format(a.storage()))
print('tensor a 相对于存储区数据的偏移量 :{}\n'.format(a.storage_offset()))
print('*'*20, '\n')
b = a.view(2,3)
print('b = {}\n'.format(b))
print('tensor b 存储区的数据内容 :{}\n'.format(b.storage()))
print('tensor b 相对于存储区数据的偏移量 :{}\n'.format(b.storage_offset()))
2)观察二
import torch
a = torch.tensor([1, 2, 3, 4, 5, 6])
b = a.view(2, 3)
print(a.data_ptr()) # 140623757700864
print(b.data_ptr()) # 140623757700864
print(id(a)) # 4523755392
print(id(b)) # 4602540464
a.data_ptr()
与b.data_ptr()
一样,说明 tensor a 和 tensor b 共享相同的存储区,即,它们指向相同的底层数据存储对象。id(a)
和id(b)
不一样,是因为虽然 a 和b 共享storage 数据,但是 它们 有不同的 size 或者 strides 、 storage_offset 等其他属性
3)观察三
import torch
a = torch.tensor([1, 2, 3, 4, 5, 6])
c = a[2:]
print(c.storage())
print('\n', '*'*20, '\n')
print('tensor a 首元素的内存地址 : {}'.format(a.data_ptr()))
print('tensor c 首元素的内存地址 : {}'.format(c.data_ptr()))
print(c.data_ptr() - a.data_ptr())
print('\n', '*'*20, '\n')
c[0] = -100
print(a)
data_ptr()
返回 tensor 首元素的内存地址- c 和 a 的首元素内存地址相差 16,每个元素占用 8 个字节(LongStorage), 也就是首元素相差两个元素
- 改变 c 的首元素, a 对应位置的元素值也被改变
6、总结
- 由上可知,绝大多数操作并不修改 tensor 的数据,只是修改了 tensor 的元数据,比如修改 tensor 的 offset 、stride 和 size ,这种做法更节省内存,同时提升了处理速度。
- 有些操作会导致 tensor 不连续,这时需要调用 torch.contiguous 方法将其变成连续的数据,该方法会复制数据到新的内存,不再与原来的数据共享 storage。
3.Dataset 与 DataLoader
- Dataset 作用 :
- 定义和管理如何获取单个数据样本及其标签
- 包含 加载数据/读取数据、预处理数据、图像增强 等一系列操作
- 返回 单个数据样本 的处理结果
- DataLoader 作用 :
- 指定数据读取规则,一般通过 sampler 指定
- 指定 batch数据的打包规则,通过 collate_fn 指定
- 每次迭代,返回的是 一个batch 的数据
生成 Dataset 方式一 :自定义Dataset
所谓的 自定义 dataset ,即:我们自己去写一个 Dataset 类 :
- 一般需要继承 torch.utils.data.Dataset
- 继承 torch.utils.data.Dataset 主要是为了与 DataLoader 保持兼容,确保数据集遵循 DataLoader 的接口标准,方便后续使用 PyTorch 提供的工具,比如 :批量加载、打乱数据、并行处理等功能
- 并且满足和 DataLoader 进行交互的规范 :
- DataLoader 会调用 Dataset 的 len() 和 getitem() 方法,所以自定义 Dataset 类必须实现这两个方法,如此才能保证 DataLoader 可以正确地加载和操作你的数据集
1、自定义 Dataset 的三个重要方法
创建自定义 Dataset 时,必须实现的3个方法 :init()、len()、 getitem()。
这些方法定义了数据集的基本结构和行为,也是 DataLoader 可以正确的从 Dataset 中读取数据的基础。
1)init 方法
- 参数: 根据需要传递一些参数,例如文件路径、数据转换等。
- 作用: 可以在这里进行一些初始化工作,例如:设置文件路径、定义数据转换transforms 等。
def __init__(self, data_folder, train, transform=None):
self.data_folder = data_folder
self.transform = transform
self.file_list = os.listdir(data_folder)
self.train = train
2)len 方法
def __len__(self):
return len(self.file_list)
- 返回值: 需返回数据集中的样本的总数。
- 作用:
- 方便通过调用 len(dataset) 来获取数据量,其中 dataset 为 Dataset 对象
- Dataloader 会用它 和 batch_size 一起来计算一个epoch 要迭代多少个 steps: s t e p s = l e n ( d a t a s e t ) b a t c h s i z e steps = \frac{len(dataset)}{batchsize} steps=batchsizelen(dataset)
3)getitem 方法
def __getitem__(self, idx):
img_name = os.path.join(self.data_folder, self.file_list[idx])
original_image = Image.open(img_name)
label = img_name.split('_')[-1].split('.')[0]
if self.train:
image = self.transform(original_image)
else:
image = self.transform(original_image)
return image, label
- 参数: index 是样本的索引。
- 返回值: 返回数据集中索引指定的样本。通常是一个包含输入数据和对应标签的元组。这里可以根据自己的需求,进行自定义。
- 作用: 根据给定的索引返回数据集中的一个样本。这是用于获取数据集中单个样本的方法。
比如,可以通过 dataset[0] 来获取 dataset 中的索引为 0 的样本
以上这三个方法一起定义了 PyTorch 中的 dataset 类,并支持使用 torch.utils.data.DataLoader 来加载数据并进行训练。
2、使用举例
用 CIRFAR-100 数据集生成 Dataset
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
class CustomDataset(Dataset):
def __init__(self, data_folder, train, transform=None):
self.data_folder = data_folder
self.transform = transform
self.file_list = os.listdir(data_folder)
self.train = train
def __getitem__(self, idx):
img_name = os.path.join(self.data_folder, self.file_list[idx])
original_image = Image.open(img_name)
label = img_name.split('_')[-1].split('.')[0]
if self.train:
image = self.transform(original_image)
else:
image = self.transform(original_image)
return image, label
def __len__(self):
return len(self.file_list)
images_dir = "/Users/enzo/Documents/GitHub/dataset/CIFAR/cifar-100-images/train"
dataset = CustomDataset(images_dir, train=True, transform=transforms.ToTensor())
print(len(dataset))
# 通过 dataset 对象,获取索引为 0 的样本
sample_image, sample_label = dataset[0]
print("Sample Image Shape:", sample_image.shape)
print("Sample Label:", sample_label)
输出:
生成 Dataset 方式二 :torchvision.datasets 模块
1、pytorch 官方支持下载的数据集
官网地址 : 点击查看
注 :
- 对于一部分数据集,提供下载功能
- 对于一部分数据集,不提供下载功能 (具体情况取决于数据集的来源和许可协议)
2、torchvision.datasets 模块
以获取 MNIST 数据集为例 (pytorch官方文档地址 : 点击查看)
MNIST 全称:mixed national institute of standards and technology database
train_dataset = torchvision.datasets.MNIST(root,
train=True,
transform=None,
target_transform= None
download=True)
参数 :
- root :数据集存放的路径
- train:如果是True, 下载训练集 trainin.pt; 如果是False,下载测试集 test.pt。 默认是True
- transform:一系列作用在PIL图片上的转换操作
- download:是否下载数据集,默认为 False
- 若设置 download=True
- root 目录下没有该数据集,数据集将会被下载到 root 指定的位置。
- root 目录下已经存在该数据集,则不会重新下载,而是会直接使用已存在的数据,以节省时间
- 若设置 download=False,程序将会在 root 指定的位置查找数据集,如果数据集不存在,则会抛出错误。
- 若设置 download=True
3、举例 1:torchvision.datasets.MNIST
- 因为是单通道,所以 transforms.Normalize 的均值和标准差 仅指定了一个值
- 记得把数据集的下载地址换掉,换成你想要它下载到的位置
import torchvision
from torchvision.transforms import transforms
import torch.utils.data as data
import matplotlib.pyplot as plt
batch_size = 5
my_transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize(mean=[0.5], # mean=[0.485, 0.456, 0.406]
std=[0.5])]) # std=[0.229, 0.224, 0.225]
train_dataset = torchvision.datasets.MNIST(root="./",
train=True,
transform=my_transform,
download=True)
val_dataset = torchvision.datasets.MNIST(root="./",
train=False,
transform=my_transform,
download=True)
train_loader = data.DataLoader(train_dataset,
batch_size=batch_size,
shuffle=True)
val_loader = data.DataLoader(val_dataset,
batch_size=batch_size,
shuffle=True)
print(len(train_dataset))
print(len(val_dataset))
image, label = next(iter(train_loader))
print(image.shape)
print(label)
for i in range(batch_size):
plt.subplot(1, batch_size, i + 1)
plt.title(label[i].item())
plt.axis("off")
plt.imshow(image[i].permute(1, 2, 0))
plt.show()
输出:
4、举例 2:torchvision.datasets.CocoDetection
官方文档 : 点击查看
torchvision.datasets.CocoDetection 不支持 COCO 数据集下载
在使用 torchvision.datasets.CocoDetection
之前,需要确保已经下载并淮备好COCO数
据集的图像和标注文件。然后使用 torchvision.datasets.CocoDetection 类来加载 COCO数据集。
torchvision.datasets.CocoDetection(root,
annFile,
transform=None,
target_transform=None,
transforms=None)
参数 :
- root : 指定图片地址 (本地已经下载下来的图像地址)
- annFile : 指定标注文件地址( 本地已经下载下来的标注文件地址)
- transform : 图像处理 (用于PIL)
- target_transform : 标注处理
- transforms : 图像和标注的处理
使用举例:
- 记得把数据集的下载地址换掉,换成你的 COCO数据集地址
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.transforms import functional as F
import random
def collate_fn_coco(batch):
return tuple(zip(*batch))
coco_det = datasets.CocoDetection(root="./COCO2017/train2017",
annFile="./COCO2017/annotations/instances_train 2017.json")
sampler = torch.utils.data.SequentialSampler(coco_det) # RandomSampler
batch_sampler = torch.utils.data.BatchSampler(sampler, 1, drop_last=True)
data_loader = torch.utils.data.DataLoader(coco_det,
batch_sampler=batch_sampler,
collate_fn=collate_fn_coco)
# 可视化
iterator = iter(data_loader)
imgs, gts = next(iterator)
img, gts_one_img = imgs[0], gts[0]
bboxes = []
ids = []
for gt in gts_one_img:
bboxes.append([gt['bbox'][0],
gt['bbox'][1],
gt['bbox'][2],
gt['bbox'][3]
])
ids.append(gt['category_id'])
fig, ax = plt.subplots()
for box, id in zip(bboxes, ids):
x = int(box[0])
y = int(box[1])
w = int(box[2])
h = int(box[3])
rect = plt.Rectangle((x, y), w, h, edgecolor='r', linewidth=2, facecolor='none')
ax.add_patch(rect)
ax.text(x, y, id, backgroundcolor="r")
plt.axis("off")
plt.imshow(img)
plt.show()
输出效果:
DataLoader
1、torch.utils.data.DataLoader
官方文档 :点击查看
from torch.utils.data import DataLoader
data_loader = DataLoader(dataset,
batch_size=1,
shuffle=None,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0
)
参数:
- dataset : 加载数据的数据集
- batch size : 每批返回的数据量,默认值是 1
- shuffle:是否在每个 epoch 内将数据打乱顺序。默认值为False
- sampler :从数据集中提取的样本序列。可以用来自定义样本的采样策路。默认值为None
- batch_sampler :与sampler类似,但是一次返回一个 batch的索引,用于自定义 batch。它与 与 batch size、shuffle、sampler 和 drop last 互斥
- num workers : 用于数据加载的子进程数。0表示主进程加载。默认值为0
- collate_fn: 用于指定如何组合样本数据。如果为None,那么将默认使用默认的组合方法
- drop_last : 如果数据集的大小不能被 batch _size 整除,那么是否丢弃最后一个数据批次。默认值为 False
- pin_memory : 将数据固定在内存的锁页内存中,加速数据读取的速度。默认值为False.
- timeout : workers :等待 collect 一个 batch 的数据的超时时间。默认为 0,表示一直等待
2、常用参数图示
dataset 对 Dataloader 有 2个作用 :
- 通过 dataset 的 length 方法,dataloader 可以知道数据量,从而根据数据量生成相应的索引列表
- dataloader 会将索引,传给 dataset 的 getitem 方法,通过 getitem 方法对数据进行处理,并返回处理好的数据
3、Dataset 与 Dataloader 的内部交互细节 举例
num_workers 与 pin_memory
1、参数 num_workers
参数 num_workers 参数用于指定 加载数据的子进程的数量
- num_workers=0 :(默认值) 表示只有主进程去加载 batch数据,这个可能会是一个瓶颈。
- num_workers=1 :表示只有一个子进程加载数据,主进程不参与,这仍可能导致速度慢。
- num_workers>0 :表示指定数量的子进程并行加载数据,且主进程不参与。
增加num_workers可以提高加载速度,但也会增加 CPU 和 内存的使用。
通常建议将 num_workers 参数设置为等于或小于 CPU 核心数,以有效平衡数据加载效率和系统资源占用率。
进程之间是动态调度的,谁先做完一个样本:
batch_size = 16
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
train_dataloader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size,
num_workers=nw,
shuffle=True,
pin_memory=True,
collate_fn=collate_fn)
2、数据加载过程是如何并行的
一个进程处理一个 batch 的数据,假设设置
num_workers=2
,则 进程1 处理一个batch 的数据,进程2 处理另一个batch 的数据
并行工作流程 :
- 初始化:创建 DataLoader 实例时,通过参数 num_workers 指定并行加载的子进程数量
- 子进程加载数据:子进程独立于主进程运行,每个子进程的拿着一个batch 的索引,并行的到 dataset 的 getitem 中预处理数据
- 数据准备:处理好的数据,放入缓冲区以备主进程请求
- 数据请求:主进程在 for 循环中请求下一个 batch
- 数据传输:主进程请求数据时,从缓冲区获取已经准备好的 batch
- 循环迭代:主进程不断请求数据,子进程并行的处理后续的 batch 数据
3、pin_memory
- 若设置
pin_memory=True
,数据会被加载到CPU的锁定内存中,从而提高数据从 CPU 到 GPU 的传输效率
这是因为锁定的内存(pinned memory
)可以更快地被复制到GPU,因为它是连续的,并且已经准备好被传输。
- 若设置
pin_memory=False
,则数据是被存放在分页内存(pageable memory)中,当我们想要把数据从 cpu 移动到 gpu 上 (执行.to('cuda')
的时候), 需要先将数据从分页内存中 移动到锁页内存中,然后再传输到 GPU 上
所以,设置 pin_memory=True ,节省的是 将数据从 分页内存移动到锁页内存中 的这段时间
如果你的训练完全在CPU上进行,不涉及GPU,那就没有必要设置 pin_memory=True
。因为在这种情况下,数据不需要被传输到GPU,因此不需要使用锁定内存来加速这一过程。可以将 pin_memory
设置为 False
,以简化内存管理。
sampler 与 batch_sampler
1、sampler
torch.utils.data.DataLoader 的参数 sampler 接收的通常是一个实现了 Sampler 接口的对象,比如 :
sampler = SequentialSampler(dataset) # 使用 SequentialSampler
dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)
通过 sampler 对象来控制数据集的索引顺序,从而影响数据从数据集中的抽取方式
1)pytorch 提供的,可以直接使用的几种 sampler
# 顺序抽样,按照数据集的顺序逐个抽取样本
torch.utils.data.sampler.SequentialSampler()
# 随机抽样,数据集中的样本以随机顺序被抽取
torch.utils.data.sampler.RandomSampler()
# 从指定的样本索引子集内进行随机抽样
torch.utils.data.sampler.SubsetRandomSampler()
# 根据样本的权重随机抽样,不同样本有不同的抽样概率
torch.utils.data.sampler.WeightedRandomSampler()
2)可以自定义 sampler,比如以下是 yolov5 中自定义的 sampler :
参数 sampler 有一部分功能,是和 参数 shuffle 是重叠的:
- SequentialSampler 效果等价于 shuffle=False
- RandomSampler 效果等价于 shuffle=Ture
Pytorch 提供 sampler 参数,主要是为提升灵活性,支持用户更灵活地设计数据加载的方式
下面我们主要介绍 SequentialSampler 和 RandomSampler, 只要大家通过 SequentialSampler 、RandomSampler 掌握了 sampler 的工作原理,便可以愉快的自定义的去设计 sampler 了。
1)顺序采样 SequentialSampler
作用 :接收一个 Dataset 对象,输出数据包中样本量的顺序索引
举例 1
import torch.utils.data.sampler as sampler
data = list([17, 22, 3, 41, 8])
seq_sampler = sampler.SequentialSampler(data_source=data)
for index in seq_sampler:
print("index: {}".format(index))
相关源码
class SequentialSampler(Sampler):
data_source: Sized
def __init__(self, data_source: Sized) -> None:
self.data_source = data_source
def __iter__(self) -> Iterator[int]:
return iter(range(len(self.data_source)))
def __len__(self) -> int:
return len(self.data_source)
- init 接收参数:Dataset 对象
- iter 返回一个可迭代对象(返回的是索引值),因为 SequentialSampler 是顺序采样,所以返回的索引是顺序数值序列
- len 返回 dataset 中数据个数
举例 2
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SequentialSampler
class myDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
# 示例数据 :0 到 19 的整数
data = [i for i in range(20)]
dataset = myDataset(data)
# 使用 SequentialSampler
sampler = SequentialSampler(dataset)
# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)
# 使用 DataLoader 迭代数据
for data in dataloader:
print(data)
2)随机采样 RandomSampler
作用 :接收一个 Dataset 对象,输出数据包中样本量的随机索引 (可指定是否可重复)。
举例 1
import torch.utils.data.sampler as sampler
data = list([17, 22, 3, 41, 8])
seq_sampler = sampler.RandomSampler(data_source=data)
for index in seq_sampler:
print("index: {}".format(index))
相关源码 (删减版本)
class RandomSampler(Sampler):
def __init__(self, data_source, replacement=False, num_samples=None):
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
def num_samples(self):
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
def __len__(self):
return self.num_samples
def __iter__(self):
n = len(self.data_source)
if self.replacement:
# 生成的随机数是可能重复的
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
# 生成的随机数是不重复的
return iter(torch.randperm(n).tolist())
查看 torch.randperm() 的使用 :
- init 参数 :
- data_source (Dataset): 采样的 Dataset 对象
- replacement (bool): 如果为 True,则抽取的样本是有放回的。默认为 False
- num_samples (int): 抽取样本的数量,默认是len(dataset)。当 replacement 是 True 时,应被实例化
- iter 返回一个可迭代对象(返回的是索引),因为 RandomSampler 是随机采样,所以返回的索引是随机的数值序列 (当 replacement=False 时,生成的排列是无重复的)
- len 返回 dataset 中样本量
举例 2
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import RandomSampler
class myDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
# 示例数据 :0 到 19 的整数
data = [i for i in range(20)]
dataset = myDataset(data)
# 使用 SequentialSampler
sampler = RandomSampler(dataset)
# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)
# 使用 DataLoader 迭代数据
for data in dataloader:
print(data)
2、sampler 与 shuffle 的互斥
- 参数 sampler 与 参数 shuffle 是互斥的,不要同时使用 sampler 和 shuffle
- 因为 shuffle 的默认值为 False,所以代码会兼容 shuffle 等于默认值 False 的情况,即 :
- 当同时设置了 shuffle 与 sampler,且 shuffle=True,会报错
- 当同时设置了 shuffle 与 sampler,且 shuffle=False,具体逻辑按照 sampler
3、批采样 BatchSampler
官方文档 :
https://pytorch.org/docs/stable/data.html#torch.utils.data.BatchSampler
torch.utils.data.DataLoaderde 的参数 batch_sample, 接收的一般是 torch.utils.data.BatchSampler 对象,
torch.utils.data.BatchSampler 的作用 : 包装另一个采样器,生成一个小批量索引采样器
torch.utils.data.BatchSampler(sampler, batch_size, drop_last)
举例 1
import torch.utils.data.sampler as sampler
data = list([17, 22, 3, 41, 8])
seq_sampler = sampler.SequentialSampler(data_source=data)
batch_sampler = sampler.BatchSampler(seq_sampler, 2, False )
for index in batch_sampler:
print(index)
相关源码 (删减版本)
class BatchSampler(Sampler):
def __init__(self, sampler, batch_size, drop_last):、
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
# 如果采样个数和batch_size相等则本次采样完成
if len(batch) == self.batch_size:
yield batch
batch = []
# for 结束后在不需要剔除不足batch_size的采样个数时返回当前batch
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self):
# 在不进行剔除时,数据的长度就是采样器索引的长度
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
- 参数 :
- sampler : 其他采样器实例
- batch_size :批量大小
- drop_last :为 “True”时,如果最后一个batch 采样得到的数据个数小于batch_size,则抛弃最后一个batch的数据
举例 2
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SequentialSampler, BatchSampler
class myDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
# 示例数据 :# 生成 0 到 19 的整数
data = [i for i in range(20)]
dataset = myDataset(data)
# 使用 SequentialSampler 顺序采样
sequential_sampler = SequentialSampler(dataset)
# 使用 BatchSampler 将 SequentialSampler 和 batch_size 结合
batch_sampler = BatchSampler(sequential_sampler, batch_size=8, drop_last=False)
# 创建 DataLoader,使用 BatchSampler
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
# 使用 DataLoader 迭代数据
for data in dataloader:
print(data)
4、BatchSampler 与 其他参数的互斥
如果你在 DataLoader(dataset, batch_sampler=batch_sampler) 中指定了参数 batch_sampler, 那么就不能再指定参数 batch_size、shuffle、sampler、和 drop_last 了,他们互斥。
因为:
- 你在生成torch.utils.data.sampler.BatchSampler() 的时候,就已经制定过 batch_size、sampler、和 drop_last 这些参数了,
- batch_sampler 与 shuffle 作用一致,所以也互斥
比如,如下代码就会报错,因为在 DataLoader 中重复指定了 batch_size
random_sampler = sampler.RandomSampler(data_source=dataset)
batch_sampler = sampler.BatchSampler(random_sampler, batch_size=2, drop_last=False)
dataloader = DataLoader(dataset, batch_size=2, batch_sampler=batch_sampler)
重写 collate_fn 实例
1、collate_fn 函数作用
在使用 torch.utils.data.dataset 时,参数 collate_fn
接受一个函数,该函数的函数名通常就定义为: collate_fn
collate_fn
函数的作用 :将多个 经过 dataset.getitem()
处理好的 样本数据,组合成一个 batch 的数据。
相关代码见最后【4、附】部分
2、默认 collate_fn 函数
简易实现版本 :
def default_collate(batch):
# 检查样本类型并处理
if isinstance(batch[0], torch.Tensor):
return torch.stack(batch, dim=0)
elif isinstance(batch[0], (list, tuple)):
return [default_collate(samples) for samples in zip(*batch)]
elif isinstance(batch[0], dict):
return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
elif isinstance(batch[0], int):
return torch.tensor(batch) # 将 int 转换为 Tensor
raise TypeError(f"Unsupported type: {type(batch[0])}")
3、自定义 collate_fn 函数
1)常见场景
举例 :一个 batch 中的 多张图片,经过 dataset.getitem() 方法,得到的图像输出尺寸不一样 (比如,可能因为 图像增强 使用 的 transforms ,设计的 最后一步处理方式是范围内的随机裁剪)
因为 网络要求输入数据的尺寸形式为 (batch_size, channel, high,width), 为了将多张图像数据打包成一个batch 的数据形式,需要将图像加上padding,保证所有图像尺寸一致,进而组成 batch 的数据形式
collate_fn 函数中需要处理的内容为 :
- 对比 batch 中,所有图像的宽和高,找到最长的宽度 和 最长的高度
- 将所有的图像都 padding 到最长的宽度 和 最长的高度
- 处理的得到 mask 数据,用于标注 : 哪些位置是 有效像素,哪些位置是 padding
- 将所有数据处理成 batch 的格式,进行返回
2)相关代码实现
相关代码 : 点击跳转
data_loader_train = DataLoader(dataset_train,
batch_sampler=batch_sampler_train,
collate_fn=utils.collate_fn,
num_workers=args.num_workers,
pin_memory=True)
data_loader_val = DataLoader(dataset_val,
args.batch_size,
sampler=sampler_val,
drop_last=False,
collate_fn=utils.collate_fn,
num_workers=args.num_workers,
pin_memory=True)
相关代码 :点击跳转
def collate_fn(batch):
batch = list(zip(*batch))
batch[0] = nested_tensor_from_tensor_list(batch[0])
return tuple(batch)
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
# TODO make this more general
if tensor_list[0].ndim == 3:
# TODO make it support different-sized images
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
batch_shape = [len(tensor_list)] + max_size
b, c, h, w = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], :img.shape[2]] = False
else:
raise ValueError('not supported')
return NestedTensor(tensor, mask)
4、附
注 :更换 cifar-100 在你本地的路径
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
torch.manual_seed(121)
torch.cuda.manual_seed(121)
label_dict = {
'apple': 0,
'aquarium_fish': 1,
'baby': 2,
'bear': 3,
'beaver': 4,
'bed': 5,
'bee': 6,
'beetle': 7,
'bicycle': 8,
'bottle': 9,
'bowl': 10,
'boy': 11,
'bridge': 12,
'bus': 13,
'butterfly': 14,
'camel': 15,
'can': 16,
'castle': 17,
'caterpillar': 18,
'cattle': 19,
'chair': 20,
'chimpanzee': 21,
'clock': 22,
'cloud': 23,
'cockroach': 24,
'couch': 25,
'crab': 26,
'crocodile': 27,
'cup': 28,
'dinosaur': 29,
'dolphin': 30,
'elephant': 31,
'flatfish': 32,
'forest': 33,
'fox': 34,
'girl': 35,
'hamster': 36,
'house': 37,
'kangaroo': 38,
'keyboard': 39,
'lamp': 40,
'lawn_mower': 41,
'leopard': 42,
'lion': 43,
'lizard': 44,
'lobster': 45,
'man': 46,
'maple_tree': 47,
'motorcycle': 48,
'mountain': 49,
'mouse': 50,
'mushroom': 51,
'oak_tree': 52,
'orange': 53,
'orchid': 54,
'otter': 55,
'palm_tree': 56,
'pear': 57,
'pickup_truck': 58,
'pine_tree': 59,
'plain': 60,
'plate': 61,
'poppy': 62,
'porcupine': 63,
'possum': 64,
'rabbit': 65,
'raccoon': 66,
'ray': 67,
'road': 68,
'rocket': 69,
'rose': 70,
'sea': 71,
'seal': 72,
'shark': 73,
'shrew': 74,
'skunk': 75,
'skyscraper': 76,
'snail': 77,
'snake': 78,
'spider': 79,
'squirrel': 80,
'streetcar': 81,
'sunflower': 82,
'sweet_pepper': 83,
'table': 84,
'tank': 85,
'telephone': 86,
'television': 87,
'tiger': 88,
'tractor': 89,
'train': 90,
'trout': 91,
'tulip': 92,
'turtle': 93,
'wardrobe': 94,
'whale': 95,
'willow_tree': 96,
'wolf': 97,
'woman': 98,
'worm': 99
}
def default_collate(batch):
# 检查样本类型并处理
if isinstance(batch[0], torch.Tensor):
return torch.stack(batch, dim=0)
elif isinstance(batch[0], (list, tuple)):
return [default_collate(samples) for samples in zip(*batch)]
elif isinstance(batch[0], dict):
return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
elif isinstance(batch[0], int):
return torch.tensor(batch) # 将 int 转换为 Tensor
raise TypeError(f"Unsupported type: {type(batch[0])}")
class CustomDataset(Dataset):
def __init__(self, data_folder, train, transform=None):
self.data_folder = data_folder
self.transform = transform
self.file_list = os.listdir(data_folder)
self.train = train
def __getitem__(self, idx):
img_name = os.path.join(self.data_folder, self.file_list[idx])
original_image = Image.open(img_name)
label_name = img_name.split('_')[-1].split('.')[0]
label_idx = label_dict[label_name]
if self.train:
image = self.transform(original_image)
else:
image = self.transform(original_image)
return image, label_idx
def __len__(self):
return len(self.file_list)
images_dir = "/Users/enzo/Documents/GitHub/dataset/CIFAR/cifar-100-images/train"
dataset = CustomDataset(images_dir, train=True, transform=transforms.ToTensor())
data_loader = DataLoader(dataset,
batch_size=2,
shuffle=True,
collate_fn=default_collate)
for data in data_loader:
image, label = data
RandomSampler 与 shuffle=True 的区别
效果完全没有区别,只是实现方式不一样。
- shuffle=True 的实现方式: 在每个 epoch 开始时将整个数据集打乱,然后按照打乱后的顺序划分 batch。再按照batch_size 个数依次提取数据
- sampler.BatchSampler(random_sampler) 的实现方式:(数据不会打乱)
- step 1、RandomSampler 会生成随机的索引。
- step 2、BatchSampler 根据上面随机出来的索引生成 batch 组。
- step 3、拿着每个batch 组的索引去取 数据
相同点:
- 每个epoch 都会重新打乱
- 都不会重复采样,除非你通过参数指定了可以重复采样
其他说明:
- shuffle=True 的性能更高一些,而 BatchSampler灵活性更高,因为你可以通过 BatchSampler 设计更复杂的采样方式
- 在 Dataloader 中使用 batch_sampler 的常见目的之一,是为了兼容 DistributedSampler,比如:
if args.distributed:
sampler_train = DistributedSampler(dataset_train)
sampler_val = DistributedSampler(dataset_val, shuffle=False)
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True)
data_loader_train = DataLoader(dataset_train,
batch_sampler=batch_sampler_train,
collate_fn=utils.collate_fn,
)
data_loader_val = DataLoader(dataset_val,
args.batch_size,
sampler=sampler_val,
drop_last=False,
collate_fn=utils.collate_fn,
)
跑个小例子,看一下 :
import torch
import torch.utils.data.sampler as sampler
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
def __init__(self):
self.data = [1, 2, 3, 4, 5]
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
dataset = MyDataset()
# =============================================
random_sampler = sampler.RandomSampler(data_source=dataset)
batch_sampler = sampler.BatchSampler(random_sampler, batch_size=2, drop_last=False)
dataloader1 = DataLoader(dataset, batch_sampler=batch_sampler)
for epoch in range(3):
for index, data in enumerate(dataloader1):
print(index, data)
print('*'*30)
# =============================================
dataloader2 = DataLoader(dataset, batch_size=2, shuffle=True)
for epoch in range(3):
for index, data in enumerate(dataloader2):
print(index, data)
数据处理&数据增强
数据预处理 和 数据增强,我们一般都是使用 torchvision.transforms
模块来完成的。
我敢说,当你掌握了 torchvision.transforms
的使用方法之后,一定在数据预处理 和 数据增强 方面毫无压力。
官网地址 :
https://pytorch.org/vision/stable/transforms.html#others
简单使用举例:
1、训练阶段
from torchvision.transforms import transforms
my_transform = transforms.Compose([transforms.RandomResizedCrop(img_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
2、推理阶段
from torchvision.transforms import transforms
my_transform = transforms.Compose([transforms.Resize(original_size*1.143)
transforms.CenterCrop(img_size),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
注:
- 操作顺序 : 几何变换 / 颜色变换 —>>
ToTensor()
—>>Normalize()
ToTensor()
和Normalize()
是数据处理的最后2步- 经过
Normalize()
处理后,得到的数据,一般可直接输入到模型中使用
1、图像尺寸变换 与 裁剪
1)transforms.Resize
官方文档 : 点击跳转
torchvision.transforms.Resize(size,
interpolation=InterpolationMode.BILINEAR,
max_size=None)
作用:将图像按照指定的插值方式,resize到指定的尺寸。
参数:
size
: 输出的图像尺寸。可以是元组 (h, w) ,也可以是单个整数。- 如果 size 是元组,则输出大小将分别匹配 h, w 的大小
- 如果 size 是整数,则图像较小的边将被resize 到此数字,并保持宽高比
interpolation
: 选用如下插值方法将图像 resize 到输出尺寸PIL.Image.NEAREST
最近邻差值PIL.Image.BILINEAR
双线性差值(默认)PIL.Image.BICUBIC
双三次差值
max_size
:输出图像的较长边的最大值。仅当 size 为单个整数时才支持此功能。如果图像的较长边在根据 size 缩放后大于 max_size,则 size 将被覆盖,使较长边等于 max_size,这时较短边会小于 size。
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
original_img = Image.open('image.jpg') # https://p.ipic.vip/7pvisy.jpg
print(original_img.size) # (3280, 1818)
img_1 = transforms.Resize(1500, max_size=None)(original_img)
print(img_1.size)(2706, 1500) # (2706, 1500)
img_2 = transforms.Resize((1500, 1500))(original_img)
print(img_2.size) # (1500, 1500)
img_3 = transforms.Resize(1500, max_size=1600)(original_img)
print(img_3.size) # (1600, 886)
plt.subplot(141)
plt.axis("off")
plt.imshow(original_img)
plt.subplot(142)
plt.axis("off")
plt.imshow(img_1)
plt.subplot(143)
plt.axis("off")
plt.imshow(img_2)
plt.subplot(144)
plt.axis("off")
plt.imshow(img_3)
plt.show()
2)transforms.CenterCrop
官方文档 : 点击跳转
功能:从图片中心裁剪出尺寸为 size 的图片
参数:
size
: 所需裁剪的图片尺寸,即输出图像尺寸
注意:
- 若切正方形,
transforms.CenterCrop(100)
和transforms.CenterCrop((100, 100))
,两种写法,效果一样 - 如果设置的输出的尺寸 大于原图像尺寸,则会在四周补 padding,padding 颜色为黑色(像素值为0)
举例:
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
original_img = Image.open('image.jpg') # https://p.ipic.vip/7pvisy.jpg
print(original_img.size) # (3280, 1818)
img_1 = transforms.CenterCrop(1500)(original_img)
img_2 = transforms.CenterCrop((1500, 1500))(original_img)
img_3 = transforms.CenterCrop((3000, 3000))(original_img)
plt.subplot(141)
plt.axis("off")
plt.imshow(original_img)
plt.subplot(142)
plt.axis("off")
plt.imshow(img_1)
plt.subplot(143)
plt.axis("off")
plt.imshow(img_2)
plt.subplot(144)
plt.axis("off")
plt.imshow(img_3)
plt.show()
3)transforms.RandomCrop
官方文档 : 点击跳转
功能:
- 从图片中随机裁剪出尺寸为 size 的图片
- 如果设置了参数
padding
,先添加 padding,再从padding后的图像中随机裁剪出大小为size的图片
参数: size
:所需裁剪的图片尺寸,即输出图像尺寸padding
: 设置填充大小- 当
padding
值形式式为 a 时,上下左右均填充 a 个像素 - 当
padding
值形式式为 (a, b) 时,左右填充 a 个像素,上下填充 b 个像素 - 当
padding
值形式式为 (a, b, c, d) 时,左上右下分别填充 a,b,c,d
- 当
pad_if_needed
:当原图像尺寸小于设置的输出图像尺寸(由参数size指定),是否填充,默认为 Falsepadding_mode
:若pad_if_needed
设置为 True,则此参数起作用, 默认值为 “constant”"constant"
: 像素值由参数 fill 指定 (默认填充黑色,像素值为0)"edge"
: padding 的像素值 为图像边缘像素值"reflect"
: 镜像填充,最后一个像素不镜像。([1,2,3,4] --> [3,2,1,2,3,4,3,2])"symmetric"
: 镜像填充,最后一个像素也镜像。([1,2,3,4] -->[2,1,1,2,3,4,4,3])
fill
:指定填充像素值,当 padding_mode 为 constant 时起作用,默认填充黑色,像素值为0
注意:
- 同时指定参数
padding_mode
和 参数fill
时,若padding_mode
值不为"constant"
,则 参数fill不起作用。 - 若指定的输出图像尺寸size 大于输入图像尺寸,并且指定参数
pad_if_needed= False
,则会报错类似如下
举例:
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
original_img = Image.open('image.jpg') # https://p.ipic.vip/7pvisy.jpg
print(original_img.size) # (3280, 1818)
img_1 = transforms.RandomCrop(1500, padding=500)(original_img)
img_2 = transforms.RandomCrop(3000, pad_if_needed=True, fill=(255, 0, 0))(original_img)
img_3 = transforms.RandomCrop(3000, pad_if_needed=True, padding_mode="symmetric")(original_img)
plt.subplot(141)
plt.axis("off")
plt.imshow(original_img)
plt.subplot(142)
plt.axis("off")
plt.imshow(img_1)
plt.subplot(143)
plt.axis("off")
plt.imshow(img_2)
plt.subplot(144)
plt.axis("off")
plt.imshow(img_3)
plt.show()
4)transforms.RandomResizedCrop
官方文档 : 点击跳转
torchvision.transforms.RandomResizedCrop(size,
scale=(0.08, 1.0),
ratio=(0.75, 1.3333333333333333),
interpolation=InterpolationMode.BILINEAR)
功能:
- Step 1 : 将图像进行随机裁剪,裁剪出的图像需满足:
- 裁剪后的图像面积 占原图像面积的比例 在指定的范围内
- 裁剪后的图像高宽比 在指定范围内
- Step 2 :将 Step 1 得到的图像通过指定的方式,进行缩放
参数:size
: 输出的图像尺寸scale
: 随机缩放面积比例,默认随机选取 (0.08, 1) 之间的一个数ratio
: 随机长宽比,默认随机选取 (0.75, 1.33333 ) 之间的一个数。超过这个比例范围会有明显的失真interpolation
: 选用如下插值方法将图像 resize 到输出尺寸PIL.Image.NEAREST
最近邻差值PIL.Image.BILINEAR
双线性差值(默认)PIL.Image.BICUBIC
双三次差值
举例:
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
original_img = Image.open('image.jpg') # https://p.ipic.vip/7pvisy.jpg
print(original_img.size) # (3280, 1818)
img = transforms.RandomResizedCrop(1500)(original_img)
plt.subplot(121)
plt.imshow(original_img)
plt.subplot(122)
plt.imshow(img)
plt.show()