当前位置: 首页 > article >正文

深度学习基础(2024-11-02更新到图像尺寸变换 与 裁剪)

1. 名词解释

FFN

  • FFN : Feedforward Neural Network,前馈神经网络
  • 馈神经网络是一种基本的神经网络架构,也称为多层感知器(Multilayer Perceptron,MLP)
  • FFN 一般主要是包括多个全连接层(FC)的网络,其中,全连接层间可以包含 : 激活层、BN层、Dropout 层。

MLP 与 FFN 的区别

在机器学习和深度学习中,MLP(多层感知机)和 FFN(前馈神经网络)在很大程度上可以视为同义词,都指代了一个具有多个层的前馈神经网络结构。

  • MLP(多层感知机)更偏向于表达网络结构(多个全连接层)
  • FFN(前馈神经网络)更偏向于表达数据以前馈的方式流动

MLP 和 FFN 通常指的是只包含全连接层 和激活函数的神经网络结构。这两者都是基本的前馈神经网络类型,没有包含卷积层或其他复杂的结构。

Logit

“Logit” 通常指的是神经网络中最后一个隐藏层的输出,经过激活函数之前的值。比如:

  • 对于二分类问题,logit 是指网络输出的未经过 sigmoid 函数处理的值
  • 对于多分类问题,logit 是指网络输出的未经过 softmax 函数处理的值

NLL

NLL 是 Negative Log-Likelihood(负对数似然)的缩写。
在深度学习中,特别是在分类问题中,NLL 经常与交叉熵损失(Cross-Entropy Loss)等价使用。

Anchor Box 与 Anchor Point

  • Anchor box 通常表示 一个包含位置和大小信息的四元组 ( x , y , w , h ) (x, y, w, h) (x,y,w,h),而 Anchor point 通常表示 一个二元组 ( x , y ) (x, y) (x,y)。 其中, x x x y y y表示框的中心坐标, w w w h h h表示框的宽度和高度。
  • Anchor box 是目标检测中用于定义目标位置和大小的一种方式。而 Anchor point 主要用于在图像上生成 anchor box 的位置,生成的 anchor box 会在 anchor point 的周围不同尺寸和宽高比的情况下进行缩放,形成一系列不同形状的框。

parameter efficient

参数效率高,指的是网络在达到良好性能的同时所使用的参数数量较少。

Deep Supervision

Deep Supervision 是一种训练策略,旨在提高网络的梯度流动,并促使网络更快地收敛,并且有助于缓解梯度消失问题。Deep Supervision 的核心思想是在网络的不同层中引入额外的监督信号,而不仅仅在最后一层输出进行监督训练。具体来说:Deep Supervision 会使用网络的中间层输出,计算出一部分损失函数,然后和网络最后一层的损失函数一起,对网络的参数进行优化。

DP 与 DDP

DP : DataParallel,数据并行
DDP :Distributed Data Parallel,分布式数据并行

感受野(Receptive Field)

1、介绍

感受野(receptive field)是卷积神经网络输出特征图上的像素点在原始图像上所能看到的(映射的)区域的大小,它决定了该像素对输入图像的感知范围(获取信息的范围)。较小的感受野可以捕捉到更细节的特征,而较大的感受野可以捕捉到更全局的特征。
在这里插入图片描述
如果连续进行 2次卷积操作,卷积核大小都为 3x3,stride=1, padding=0, 如下图,layer3上的每一个像素点在 layer1上的感受野 为 5x5
在这里插入图片描述

2、感受野计算公式

感受野计公式 : F ( i ) = ( F ( i + 1 ) − 1 ) × S t r i d e + K s i z e F(i)=(F(i+1)-1)\times Stride + Ksize F(i)=F(i+1)1×Stride+Ksize F i n = ( F o u t − 1 ) × S t r i d e + K s i z e F_{in}=(F_{out}-1)\times Stride + Ksize Fin=Fout1×Stride+Ksize
其中:

  • F ( i ) F(i) F(i) :在第 i i i层的感受野
  • S t r i d e Stride Stride:第 i i i层步距
  • K s i z e Ksize Ksize:第 i i i层卷积或池化的 kernel size

3、计算举例

求 :layer3 上的每个像素在 layer1 上的感受野。
在这里插入图片描述
1)先来计算 layer3 上的一个像素( F ( 3 ) = 1 F(3)=1 F(3)=1)在 layer2 上的感受野 :
F ( 2 ) = ( F ( 3 ) − 1 ) × S t r i d e + K s i z e = ( 1 − 1 ) × 2 + 2 = 2 F(2) = (F(3)-1) \times Stride + Ksize = (1 -1) \times 2 + 2 = 2 F(2)=(F(3)1)×Stride+Ksize=(11)×2+2=2

2)计算 layer3 上的一个像素( F ( 3 ) = 1 ,    F ( 2 ) = 2 F(3)=1, \; F(2)=2 F(3)=1F(2)=2 )在 layer1 上的感受野 :
F ( 1 ) = ( F ( 2 ) − 1 ) × S t r i d e + K s i z e = ( 2 − 1 ) × 2 + 3 = 5 F(1)=(F(2)-1)\times Stride + Ksize =(2 -1)\times 2 + 3 = 5 F(1)=(F(2)1)×Stride+Ksize=(21)×2+3=5

如果仅计算 layer2 上的一个像素( F(2)=1 )在 layer1 上的感受野 :
F ( 1 ) = ( F ( 2 ) − 1 ) × S t r i d e + K s i z e = ( 1 − 1 ) × 2 + 3 = 3 F(1)=(F(2)-1)\times Stride + Ksize = (1 -1)\times 2 + 3 = 3 F(1)=F(2)1×Stride+Ksize=11×2+3=3

2. tensor 相关

tensor 内部存储结构

1、数据区域和元数据

PyTorch 中的 tensor 内部结构通常包含了 数据区域(Storage) 和 元数据(Metadata) :

  • 数据区域 : 存储了 tensor 的实际数据,且数据被保存为连续的数组。比如: a = torch.tensor([[1, 2, 3], [4, 5, 6]]),它的数据在存储区的保存形式为 [1, 2, 3, 4, 5, 6]
  • 元数据 :包含了 tensor 的一些描述性信息,比如 : 尺寸(Size)、步长(Stride)、数据类型(Data Type) 等信息

占用内存的主要是 数据区域,且取决于 tensor 中元素的个数, 而元数据占用内存较少。
采用这种 【数据区域 + 元数据】 的数据存储方式,主要是因为深度学习的数据动辄成千上万,数据量巨大,所以采取这样的存储方式以节省内存
在这里插入图片描述


2、查看 tensor 的存储区数据: storage()

虽然 .storage() 方法即将被弃用,而是改用 .untyped_storage(),但为了笔记中展示方便,我们仍然使用 .storage() 方法。.untyped_storage() 方法的输出太长了,不方便截图放在笔记中。

a = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

print(a.storage())

在这里插入图片描述


3、查看 tensor 的步长: stride()

stride() : 在指定维度 (dim) 上,存储区中的数据元素,从一个元素跳到下一个元素所必须的步长

a = torch.randn(3, 2)
print(a.stride())  # (2, 1)

解读:
在这里插入图片描述
在第 0 维,想要从一个元素跳到下一个元素,比如从 a[0][0] 到 a[1][0] ,需要经过 2个元素,步长是 2
在第 1 维,想要从一个元素跳到下一个元素,比如从 a[0][0] 到 a[0][1], 需要经过 1个元素,步长是 1

4、查看 tensor 的偏移量:storage_offset()

表示 tensor 的第 0 个元素与真实存储区的第 0 个元素的偏移量

a = torch.tensor([1, 2, 3, 4, 5])
b = a[1:]   # tensor([2, 3, 4, 5])
c = a[3:]   # tensor([4, 5])
print(b.storage_offset())   # 1
print(c.storage_offset())   # 3
  • b 的第 0 个元素与 a 的第 0 个元素之间的偏移量是 1
  • c 的第 0 个元素与 a 的第 0 个元素之间的偏移量是 3

5、代码举例

  • 一般来说,一个 tensor 有着与之对应的 storage, storage 是在 data 之上封装的接口。

  • 不同 tensor 的元数据一般不同,但却可能使用相同的 storage。

  • data_ptr()

    • 返回的是张量数据 (storage 数据)存储的实际内存地址,确切来说是张量数据的起始内存地址。
    • data_ptr 中的 ptr 是 pointer(指针)的缩写,对应于 C 语言中的指针,因为 Python 的底层就是由 C 实现的
  • id(a)

    • 返回的是 a 在 Python 内存管理系统中的唯一标识符。虽然这个标识符通常与对象的内存地址有关,但它并不直接表示内存地址。

1)观察一

import torch

a = torch.arange(0, 6)
print('a = {}\n'.format(a))
print('tensor a 存储区的数据内容 :{}\n'.format(a.storage()))
print('tensor a 相对于存储区数据的偏移量 :{}\n'.format(a.storage_offset()))

print('*'*20, '\n')

b = a.view(2,3)
print('b = {}\n'.format(b))
print('tensor b 存储区的数据内容 :{}\n'.format(b.storage()))
print('tensor b 相对于存储区数据的偏移量 :{}\n'.format(b.storage_offset()))

在这里插入图片描述
2)观察二

import torch

a = torch.tensor([1, 2, 3, 4, 5, 6])
b = a.view(2, 3)

print(a.data_ptr())   # 140623757700864
print(b.data_ptr())   # 140623757700864

print(id(a))   # 4523755392
print(id(b))   # 4602540464

在这里插入图片描述

  • a.data_ptr()b.data_ptr() 一样,说明 tensor a 和 tensor b 共享相同的存储区,即,它们指向相同的底层数据存储对象。
  • id(a)id(b) 不一样,是因为虽然 a 和b 共享storage 数据,但是 它们 有不同的 size 或者 strides 、 storage_offset 等其他属性

3)观察三

import torch

a = torch.tensor([1, 2, 3, 4, 5, 6])
c = a[2:]

print(c.storage())

print('\n', '*'*20, '\n')

print('tensor a 首元素的内存地址 : {}'.format(a.data_ptr()))
print('tensor c 首元素的内存地址 : {}'.format(c.data_ptr()))
print(c.data_ptr() - a.data_ptr())

print('\n', '*'*20, '\n')

c[0] = -100
print(a)

在这里插入图片描述

  • data_ptr() 返回 tensor 首元素的内存地址
  • c 和 a 的首元素内存地址相差 16,每个元素占用 8 个字节(LongStorage), 也就是首元素相差两个元素
  • 改变 c 的首元素, a 对应位置的元素值也被改变

6、总结

  1. 由上可知,绝大多数操作并不修改 tensor 的数据,只是修改了 tensor 的元数据,比如修改 tensor 的 offset 、stride 和 size ,这种做法更节省内存,同时提升了处理速度。
  2. 有些操作会导致 tensor 不连续,这时需要调用 torch.contiguous 方法将其变成连续的数据,该方法会复制数据到新的内存,不再与原来的数据共享 storage。

3.Dataset 与 DataLoader

  • Dataset 作用 :
    • 定义和管理如何获取单个数据样本及其标签
    • 包含 加载数据/读取数据、预处理数据、图像增强 等一系列操作
    • 返回 单个数据样本 的处理结果
  • DataLoader 作用 :
    • 指定数据读取规则,一般通过 sampler 指定
    • 指定 batch数据的打包规则,通过 collate_fn 指定
    • 每次迭代,返回的是 一个batch 的数据

在这里插入图片描述

生成 Dataset 方式一 :自定义Dataset

所谓的 自定义 dataset ,即:我们自己去写一个 Dataset 类 :

  • 一般需要继承 torch.utils.data.Dataset
    • 继承 torch.utils.data.Dataset 主要是为了与 DataLoader 保持兼容,确保数据集遵循 DataLoader 的接口标准,方便后续使用 PyTorch 提供的工具,比如 :批量加载、打乱数据、并行处理等功能
  • 并且满足和 DataLoader 进行交互的规范 :
    • DataLoader 会调用 Dataset 的 len() 和 getitem() 方法,所以自定义 Dataset 类必须实现这两个方法,如此才能保证 DataLoader 可以正确地加载和操作你的数据集

1、自定义 Dataset 的三个重要方法

创建自定义 Dataset 时,必须实现的3个方法 :init()、len()、 getitem()。
这些方法定义了数据集的基本结构和行为,也是 DataLoader 可以正确的从 Dataset 中读取数据的基础。
1)init 方法

  • 参数: 根据需要传递一些参数,例如文件路径、数据转换等。
  • 作用: 可以在这里进行一些初始化工作,例如:设置文件路径、定义数据转换transforms 等。
def __init__(self, data_folder, train, transform=None):
    self.data_folder = data_folder
    self.transform = transform
    self.file_list = os.listdir(data_folder)
    self.train = train

2)len 方法

def __len__(self):
    return len(self.file_list)
  • 返回值: 需返回数据集中的样本的总数。
  • 作用:
    • 方便通过调用 len(dataset) 来获取数据量,其中 dataset 为 Dataset 对象
    • Dataloader 会用它 和 batch_size 一起来计算一个epoch 要迭代多少个 steps: s t e p s = l e n ( d a t a s e t ) b a t c h s i z e steps = \frac{len(dataset)}{batchsize} steps=batchsizelen(dataset)

3)getitem 方法

def __getitem__(self, idx):
    img_name = os.path.join(self.data_folder, self.file_list[idx])
    original_image = Image.open(img_name)
    label = img_name.split('_')[-1].split('.')[0]

    if self.train:
        image = self.transform(original_image)
    else:
        image = self.transform(original_image)

    return image, label
  • 参数: index 是样本的索引。
  • 返回值: 返回数据集中索引指定的样本。通常是一个包含输入数据和对应标签的元组。这里可以根据自己的需求,进行自定义。
  • 作用: 根据给定的索引返回数据集中的一个样本。这是用于获取数据集中单个样本的方法。
    比如,可以通过 dataset[0] 来获取 dataset 中的索引为 0 的样本

以上这三个方法一起定义了 PyTorch 中的 dataset 类,并支持使用 torch.utils.data.DataLoader 来加载数据并进行训练。


2、使用举例

用 CIRFAR-100 数据集生成 Dataset
在这里插入图片描述

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os


class CustomDataset(Dataset):
    def __init__(self, data_folder, train, transform=None):
        self.data_folder = data_folder
        self.transform = transform
        self.file_list = os.listdir(data_folder)
        self.train = train

    def __getitem__(self, idx):
        img_name = os.path.join(self.data_folder, self.file_list[idx])
        original_image = Image.open(img_name)
        label = img_name.split('_')[-1].split('.')[0]

        if self.train:
            image = self.transform(original_image)
        else:
            image = self.transform(original_image)

        return image, label

    def __len__(self):
        return len(self.file_list)


images_dir = "/Users/enzo/Documents/GitHub/dataset/CIFAR/cifar-100-images/train"
dataset = CustomDataset(images_dir, train=True, transform=transforms.ToTensor())

print(len(dataset))

# 通过 dataset 对象,获取索引为 0 的样本
sample_image, sample_label = dataset[0]
print("Sample Image Shape:", sample_image.shape)
print("Sample Label:", sample_label)

输出:
在这里插入图片描述

生成 Dataset 方式二 :torchvision.datasets 模块

1、pytorch 官方支持下载的数据集

官网地址 : 点击查看
在这里插入图片描述
注 :

  • 对于一部分数据集,提供下载功能
  • 对于一部分数据集,不提供下载功能 (具体情况取决于数据集的来源和许可协议)

2、torchvision.datasets 模块

以获取 MNIST 数据集为例 (pytorch官方文档地址 : 点击查看)
MNIST 全称:mixed national institute of standards and technology database

train_dataset = torchvision.datasets.MNIST(root,    
                                           train=True,               
                                           transform=None,  
                                           target_transform= None  
                                           download=True)

参数 :

  • root :数据集存放的路径
  • train:如果是True, 下载训练集 trainin.pt; 如果是False,下载测试集 test.pt。 默认是True
  • transform:一系列作用在PIL图片上的转换操作
  • download:是否下载数据集,默认为 False
    • 若设置 download=True
      • root 目录下没有该数据集,数据集将会被下载到 root 指定的位置。
      • root 目录下已经存在该数据集,则不会重新下载,而是会直接使用已存在的数据,以节省时间
    • 若设置 download=False,程序将会在 root 指定的位置查找数据集,如果数据集不存在,则会抛出错误。

3、举例 1:torchvision.datasets.MNIST

  • 因为是单通道,所以 transforms.Normalize 的均值和标准差 仅指定了一个值
  • 记得把数据集的下载地址换掉,换成你想要它下载到的位置
import torchvision
from torchvision.transforms import transforms
import torch.utils.data as data
import matplotlib.pyplot as plt

batch_size = 5

my_transform = transforms.Compose([transforms.ToTensor(),
                                   transforms.Normalize(mean=[0.5],  # mean=[0.485, 0.456, 0.406]
                                                        std=[0.5])])  # std=[0.229, 0.224, 0.225]

train_dataset = torchvision.datasets.MNIST(root="./",
                                           train=True,
                                           transform=my_transform,
                                           download=True)

val_dataset = torchvision.datasets.MNIST(root="./",
                                         train=False,
                                         transform=my_transform,
                                         download=True)

train_loader = data.DataLoader(train_dataset,
                               batch_size=batch_size,
                               shuffle=True)

val_loader = data.DataLoader(val_dataset,
                             batch_size=batch_size,
                             shuffle=True)

print(len(train_dataset))
print(len(val_dataset))

image, label = next(iter(train_loader))
print(image.shape)
print(label)

for i in range(batch_size):
    plt.subplot(1, batch_size, i + 1)
    plt.title(label[i].item())
    plt.axis("off")
    plt.imshow(image[i].permute(1, 2, 0))

plt.show()

输出:
在这里插入图片描述


4、举例 2:torchvision.datasets.CocoDetection

官方文档 : 点击查看

torchvision.datasets.CocoDetection 不支持 COCO 数据集下载
在使用 torchvision.datasets.CocoDetection 之前,需要确保已经下载并淮备好COCO数
据集的图像和标注文件。然后使用 torchvision.datasets.CocoDetection 类来加载 COCO数据集。

torchvision.datasets.CocoDetection(root, 
                                   annFile, 
                                   transform=None, 
                                   target_transform=None, 
                                   transforms=None)

参数 :

  • root : 指定图片地址 (本地已经下载下来的图像地址)
  • annFile : 指定标注文件地址( 本地已经下载下来的标注文件地址)
  • transform : 图像处理 (用于PIL)
  • target_transform : 标注处理
  • transforms : 图像和标注的处理

使用举例:

  • 记得把数据集的下载地址换掉,换成你的 COCO数据集地址
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.transforms import functional as F
import random


def collate_fn_coco(batch):
    return tuple(zip(*batch))

coco_det = datasets.CocoDetection(root="./COCO2017/train2017",
                                  annFile="./COCO2017/annotations/instances_train 2017.json")


sampler = torch.utils.data.SequentialSampler(coco_det)  # RandomSampler
batch_sampler = torch.utils.data.BatchSampler(sampler, 1, drop_last=True)
data_loader = torch.utils.data.DataLoader(coco_det,
                                          batch_sampler=batch_sampler,
                                          collate_fn=collate_fn_coco)

# 可视化
iterator = iter(data_loader)
imgs, gts = next(iterator)
img,  gts_one_img = imgs[0], gts[0]

bboxes = []
ids = []
for gt in gts_one_img:
    bboxes.append([gt['bbox'][0],
                   gt['bbox'][1],
                   gt['bbox'][2],
                   gt['bbox'][3]
                   ])
    ids.append(gt['category_id'])

fig, ax = plt.subplots()
for box, id in zip(bboxes, ids):
    x = int(box[0])
    y = int(box[1])
    w = int(box[2])
    h = int(box[3])
    rect = plt.Rectangle((x, y), w, h, edgecolor='r', linewidth=2, facecolor='none')
    ax.add_patch(rect)
    ax.text(x, y, id, backgroundcolor="r")

plt.axis("off")
plt.imshow(img)
plt.show()

输出效果:
在这里插入图片描述

DataLoader

1、torch.utils.data.DataLoader

官方文档 :点击查看

from torch.utils.data import DataLoader

data_loader = DataLoader(dataset, 
                         batch_size=1, 
                         shuffle=None, 
                         sampler=None, 
                         batch_sampler=None, 
                         num_workers=0, 
                         collate_fn=None, 
                         pin_memory=False, 
                         drop_last=False,
                         timeout=0
                         )

参数:

  • dataset : 加载数据的数据集
  • batch size : 每批返回的数据量,默认值是 1
  • shuffle:是否在每个 epoch 内将数据打乱顺序。默认值为False
  • sampler :从数据集中提取的样本序列。可以用来自定义样本的采样策路。默认值为None
  • batch_sampler :与sampler类似,但是一次返回一个 batch的索引,用于自定义 batch。它与 与 batch size、shuffle、sampler 和 drop last 互斥
  • num workers : 用于数据加载的子进程数。0表示主进程加载。默认值为0
  • collate_fn: 用于指定如何组合样本数据。如果为None,那么将默认使用默认的组合方法
  • drop_last : 如果数据集的大小不能被 batch _size 整除,那么是否丢弃最后一个数据批次。默认值为 False
  • pin_memory : 将数据固定在内存的锁页内存中,加速数据读取的速度。默认值为False.
  • timeout : workers :等待 collect 一个 batch 的数据的超时时间。默认为 0,表示一直等待

2、常用参数图示

dataset 对 Dataloader 有 2个作用 :

  • 通过 dataset 的 length 方法,dataloader 可以知道数据量,从而根据数据量生成相应的索引列表
  • dataloader 会将索引,传给 dataset 的 getitem 方法,通过 getitem 方法对数据进行处理,并返回处理好的数据

在这里插入图片描述


3、Dataset 与 Dataloader 的内部交互细节 举例

在这里插入图片描述

num_workers 与 pin_memory

1、参数 num_workers

参数 num_workers 参数用于指定 加载数据的子进程的数量

  • num_workers=0 :(默认值) 表示只有主进程去加载 batch数据,这个可能会是一个瓶颈。
  • num_workers=1 :表示只有一个子进程加载数据,主进程不参与,这仍可能导致速度慢。
  • num_workers>0 :表示指定数量的子进程并行加载数据,且主进程不参与。

增加num_workers可以提高加载速度,但也会增加 CPU 和 内存的使用。
通常建议将 num_workers 参数设置为等于或小于 CPU 核心数,以有效平衡数据加载效率和系统资源占用率。
进程之间是动态调度的,谁先做完一个样本:

batch_size = 16
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])   # number of workers
train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               num_workers=nw,
                                               shuffle=True,
                                               pin_memory=True,
                                               collate_fn=collate_fn)

2、数据加载过程是如何并行的

一个进程处理一个 batch 的数据,假设设置 num_workers=2 ,则 进程1 处理一个batch 的数据,进程2 处理另一个batch 的数据

在这里插入图片描述
并行工作流程 :

  • 初始化:创建 DataLoader 实例时,通过参数 num_workers 指定并行加载的子进程数量
  • 子进程加载数据:子进程独立于主进程运行,每个子进程的拿着一个batch 的索引,并行的到 dataset 的 getitem 中预处理数据
  • 数据准备:处理好的数据,放入缓冲区以备主进程请求
  • 数据请求:主进程在 for 循环中请求下一个 batch
  • 数据传输:主进程请求数据时,从缓冲区获取已经准备好的 batch
  • 循环迭代:主进程不断请求数据,子进程并行的处理后续的 batch 数据

3、pin_memory

  • 若设置 pin_memory=True,数据会被加载到CPU的锁定内存中,从而提高数据从 CPU 到 GPU 的传输效率

这是因为锁定的内存(pinned memory)可以更快地被复制到GPU,因为它是连续的,并且已经准备好被传输。

  • 若设置 pin_memory=False ,则数据是被存放在分页内存(pageable memory)中,当我们想要把数据从 cpu 移动到 gpu 上 (执行 .to('cuda') 的时候), 需要先将数据从分页内存中 移动到锁页内存中,然后再传输到 GPU 上

所以,设置 pin_memory=True ,节省的是 将数据从 分页内存移动到锁页内存中 的这段时间

如果你的训练完全在CPU上进行,不涉及GPU,那就没有必要设置 pin_memory=True。因为在这种情况下,数据不需要被传输到GPU,因此不需要使用锁定内存来加速这一过程。可以将 pin_memory 设置为 False,以简化内存管理。
在这里插入图片描述

sampler 与 batch_sampler

1、sampler

torch.utils.data.DataLoader 的参数 sampler 接收的通常是一个实现了 Sampler 接口的对象,比如 :

sampler = SequentialSampler(dataset)   # 使用 SequentialSampler
dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)

通过 sampler 对象来控制数据集的索引顺序,从而影响数据从数据集中的抽取方式
1)pytorch 提供的,可以直接使用的几种 sampler

# 顺序抽样,按照数据集的顺序逐个抽取样本
torch.utils.data.sampler.SequentialSampler()

# 随机抽样,数据集中的样本以随机顺序被抽取
torch.utils.data.sampler.RandomSampler()

# 从指定的样本索引子集内进行随机抽样
torch.utils.data.sampler.SubsetRandomSampler()

# 根据样本的权重随机抽样,不同样本有不同的抽样概率
torch.utils.data.sampler.WeightedRandomSampler()

2)可以自定义 sampler,比如以下是 yolov5 中自定义的 sampler :
在这里插入图片描述
参数 sampler 有一部分功能,是和 参数 shuffle 是重叠的:

  • SequentialSampler 效果等价于 shuffle=False
  • RandomSampler 效果等价于 shuffle=Ture
    Pytorch 提供 sampler 参数,主要是为提升灵活性,支持用户更灵活地设计数据加载的方式

下面我们主要介绍 SequentialSampler 和 RandomSampler, 只要大家通过 SequentialSampler 、RandomSampler 掌握了 sampler 的工作原理,便可以愉快的自定义的去设计 sampler 了。


1)顺序采样 SequentialSampler

作用 :接收一个 Dataset 对象,输出数据包中样本量的顺序索引

举例 1

import torch.utils.data.sampler as sampler

data = list([17, 22, 3, 41, 8])
seq_sampler = sampler.SequentialSampler(data_source=data)

for index in seq_sampler:
    print("index: {}".format(index))

在这里插入图片描述
相关源码

class SequentialSampler(Sampler):
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)
  • init 接收参数:Dataset 对象
  • iter 返回一个可迭代对象(返回的是索引值),因为 SequentialSampler 是顺序采样,所以返回的索引是顺序数值序列
  • len 返回 dataset 中数据个数

举例 2

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SequentialSampler


class myDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# 示例数据 :0 到 19 的整数
data = [i for i in range(20)]
dataset = myDataset(data)

# 使用 SequentialSampler
sampler = SequentialSampler(dataset)

# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)

# 使用 DataLoader 迭代数据
for data in dataloader:
    print(data)

在这里插入图片描述


2)随机采样 RandomSampler

作用 :接收一个 Dataset 对象,输出数据包中样本量的随机索引 (可指定是否可重复)。

举例 1

import torch.utils.data.sampler as sampler

data = list([17, 22, 3, 41, 8])
seq_sampler = sampler.RandomSampler(data_source=data)

for index in seq_sampler:
    print("index: {}".format(index))

在这里插入图片描述
相关源码 (删减版本)

class RandomSampler(Sampler):
    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        
    def num_samples(self):
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples
    
    def __len__(self):
        return self.num_samples
               
    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            # 生成的随机数是可能重复的
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        # 生成的随机数是不重复的
        return iter(torch.randperm(n).tolist())

查看 torch.randperm() 的使用 :

  • init 参数 :
    • data_source (Dataset): 采样的 Dataset 对象
    • replacement (bool): 如果为 True,则抽取的样本是有放回的。默认为 False
    • num_samples (int): 抽取样本的数量,默认是len(dataset)。当 replacement 是 True 时,应被实例化
  • iter 返回一个可迭代对象(返回的是索引),因为 RandomSampler 是随机采样,所以返回的索引是随机的数值序列 (当 replacement=False 时,生成的排列是无重复的)
  • len 返回 dataset 中样本量

举例 2

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import RandomSampler


class myDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# 示例数据 :0 到 19 的整数
data = [i for i in range(20)]
dataset = myDataset(data)

# 使用 SequentialSampler
sampler = RandomSampler(dataset)

# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)

# 使用 DataLoader 迭代数据
for data in dataloader:
    print(data)

在这里插入图片描述


2、sampler 与 shuffle 的互斥

  • 参数 sampler 与 参数 shuffle 是互斥的,不要同时使用 sampler 和 shuffle
  • 因为 shuffle 的默认值为 False,所以代码会兼容 shuffle 等于默认值 False 的情况,即 :
    • 当同时设置了 shuffle 与 sampler,且 shuffle=True,会报错
    • 当同时设置了 shuffle 与 sampler,且 shuffle=False,具体逻辑按照 sampler

3、批采样 BatchSampler

官方文档 :

https://pytorch.org/docs/stable/data.html#torch.utils.data.BatchSampler

torch.utils.data.DataLoaderde 的参数 batch_sample, 接收的一般是 torch.utils.data.BatchSampler 对象,
torch.utils.data.BatchSampler 的作用 : 包装另一个采样器,生成一个小批量索引采样器

torch.utils.data.BatchSampler(sampler, batch_size, drop_last)

举例 1

import torch.utils.data.sampler as sampler
data = list([17, 22, 3, 41, 8])

seq_sampler = sampler.SequentialSampler(data_source=data)
batch_sampler = sampler.BatchSampler(seq_sampler, 2, False )

for index in batch_sampler:
    print(index)

在这里插入图片描述
相关源码 (删减版本)

class BatchSampler(Sampler):
    def __init__(self, sampler, batch_size, drop_last):、
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            # 如果采样个数和batch_size相等则本次采样完成
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # for 结束后在不需要剔除不足batch_size的采样个数时返回当前batch        
        if len(batch) > 0 and not self.drop_last:
            yield batch
            
    def __len__(self):
        # 在不进行剔除时,数据的长度就是采样器索引的长度
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
  • 参数 :
    • sampler : 其他采样器实例
    • batch_size :批量大小
    • drop_last :为 “True”时,如果最后一个batch 采样得到的数据个数小于batch_size,则抛弃最后一个batch的数据

举例 2

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SequentialSampler, BatchSampler


class myDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# 示例数据 :# 生成 0 到 19 的整数
data = [i for i in range(20)]
dataset = myDataset(data)

# 使用 SequentialSampler 顺序采样
sequential_sampler = SequentialSampler(dataset)

# 使用 BatchSampler 将 SequentialSampler 和 batch_size 结合
batch_sampler = BatchSampler(sequential_sampler, batch_size=8, drop_last=False)

# 创建 DataLoader,使用 BatchSampler
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

# 使用 DataLoader 迭代数据
for data in dataloader:
    print(data)

4、BatchSampler 与 其他参数的互斥

如果你在 DataLoader(dataset, batch_sampler=batch_sampler) 中指定了参数 batch_sampler, 那么就不能再指定参数 batch_size、shuffle、sampler、和 drop_last 了,他们互斥。

因为:

  • 你在生成torch.utils.data.sampler.BatchSampler() 的时候,就已经制定过 batch_size、sampler、和 drop_last 这些参数了,
  • batch_sampler 与 shuffle 作用一致,所以也互斥

比如,如下代码就会报错,因为在 DataLoader 中重复指定了 batch_size

random_sampler = sampler.RandomSampler(data_source=dataset)
batch_sampler = sampler.BatchSampler(random_sampler, batch_size=2, drop_last=False)
dataloader = DataLoader(dataset, batch_size=2, batch_sampler=batch_sampler)

在这里插入图片描述

重写 collate_fn 实例

1、collate_fn 函数作用

在使用 torch.utils.data.dataset 时,参数 collate_fn 接受一个函数,该函数的函数名通常就定义为: collate_fn
collate_fn 函数的作用 :将多个 经过 dataset.getitem() 处理好的 样本数据,组合成一个 batch 的数据。
在这里插入图片描述
相关代码见最后【4、附】部分


2、默认 collate_fn 函数

简易实现版本 :

def default_collate(batch):
    # 检查样本类型并处理
    if isinstance(batch[0], torch.Tensor):
        return torch.stack(batch, dim=0)
    
    elif isinstance(batch[0], (list, tuple)):
        return [default_collate(samples) for samples in zip(*batch)]
    
    elif isinstance(batch[0], dict):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    
    elif isinstance(batch[0], int):
        return torch.tensor(batch)  # 将 int 转换为 Tensor
    
    raise TypeError(f"Unsupported type: {type(batch[0])}")


3、自定义 collate_fn 函数

1)常见场景

举例 :一个 batch 中的 多张图片,经过 dataset.getitem() 方法,得到的图像输出尺寸不一样 (比如,可能因为 图像增强 使用 的 transforms ,设计的 最后一步处理方式是范围内的随机裁剪)

因为 网络要求输入数据的尺寸形式为 (batch_size, channel, high,width), 为了将多张图像数据打包成一个batch 的数据形式,需要将图像加上padding,保证所有图像尺寸一致,进而组成 batch 的数据形式

在这里插入图片描述
collate_fn 函数中需要处理的内容为 :

  • 对比 batch 中,所有图像的宽和高,找到最长的宽度 和 最长的高度
  • 将所有的图像都 padding 到最长的宽度 和 最长的高度
  • 处理的得到 mask 数据,用于标注 : 哪些位置是 有效像素,哪些位置是 padding
  • 将所有数据处理成 batch 的格式,进行返回
2)相关代码实现

相关代码 : 点击跳转

data_loader_train = DataLoader(dataset_train, 
                               batch_sampler=batch_sampler_train,
                               collate_fn=utils.collate_fn, 
                               num_workers=args.num_workers,
                               pin_memory=True)
                               
data_loader_val = DataLoader(dataset_val, 
                             args.batch_size, 
                             sampler=sampler_val,
                             drop_last=False, 
                             collate_fn=utils.collate_fn, 
                             num_workers=args.num_workers,
                             pin_memory=True)

相关代码 :点击跳转

def collate_fn(batch):
    batch = list(zip(*batch))
    batch[0] = nested_tensor_from_tensor_list(batch[0])
    return tuple(batch)


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)

4、附

在这里插入图片描述
注 :更换 cifar-100 在你本地的路径

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os


torch.manual_seed(121)
torch.cuda.manual_seed(121)

label_dict = {
    'apple': 0,
    'aquarium_fish': 1,
    'baby': 2,
    'bear': 3,
    'beaver': 4,
    'bed': 5,
    'bee': 6,
    'beetle': 7,
    'bicycle': 8,
    'bottle': 9,
    'bowl': 10,
    'boy': 11,
    'bridge': 12,
    'bus': 13,
    'butterfly': 14,
    'camel': 15,
    'can': 16,
    'castle': 17,
    'caterpillar': 18,
    'cattle': 19,
    'chair': 20,
    'chimpanzee': 21,
    'clock': 22,
    'cloud': 23,
    'cockroach': 24,
    'couch': 25,
    'crab': 26,
    'crocodile': 27,
    'cup': 28,
    'dinosaur': 29,
    'dolphin': 30,
    'elephant': 31,
    'flatfish': 32,
    'forest': 33,
    'fox': 34,
    'girl': 35,
    'hamster': 36,
    'house': 37,
    'kangaroo': 38,
    'keyboard': 39,
    'lamp': 40,
    'lawn_mower': 41,
    'leopard': 42,
    'lion': 43,
    'lizard': 44,
    'lobster': 45,
    'man': 46,
    'maple_tree': 47,
    'motorcycle': 48,
    'mountain': 49,
    'mouse': 50,
    'mushroom': 51,
    'oak_tree': 52,
    'orange': 53,
    'orchid': 54,
    'otter': 55,
    'palm_tree': 56,
    'pear': 57,
    'pickup_truck': 58,
    'pine_tree': 59,
    'plain': 60,
    'plate': 61,
    'poppy': 62,
    'porcupine': 63,
    'possum': 64,
    'rabbit': 65,
    'raccoon': 66,
    'ray': 67,
    'road': 68,
    'rocket': 69,
    'rose': 70,
    'sea': 71,
    'seal': 72,
    'shark': 73,
    'shrew': 74,
    'skunk': 75,
    'skyscraper': 76,
    'snail': 77,
    'snake': 78,
    'spider': 79,
    'squirrel': 80,
    'streetcar': 81,
    'sunflower': 82,
    'sweet_pepper': 83,
    'table': 84,
    'tank': 85,
    'telephone': 86,
    'television': 87,
    'tiger': 88,
    'tractor': 89,
    'train': 90,
    'trout': 91,
    'tulip': 92,
    'turtle': 93,
    'wardrobe': 94,
    'whale': 95,
    'willow_tree': 96,
    'wolf': 97,
    'woman': 98,
    'worm': 99
}


def default_collate(batch):
    # 检查样本类型并处理
    if isinstance(batch[0], torch.Tensor):
        return torch.stack(batch, dim=0)

    elif isinstance(batch[0], (list, tuple)):
        return [default_collate(samples) for samples in zip(*batch)]

    elif isinstance(batch[0], dict):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}

    elif isinstance(batch[0], int):
        return torch.tensor(batch)  # 将 int 转换为 Tensor

    raise TypeError(f"Unsupported type: {type(batch[0])}")


class CustomDataset(Dataset):
    def __init__(self, data_folder, train, transform=None):
        self.data_folder = data_folder
        self.transform = transform
        self.file_list = os.listdir(data_folder)
        self.train = train

    def __getitem__(self, idx):
        img_name = os.path.join(self.data_folder, self.file_list[idx])
        original_image = Image.open(img_name)
        label_name = img_name.split('_')[-1].split('.')[0]
        label_idx = label_dict[label_name]

        if self.train:
            image = self.transform(original_image)
        else:
            image = self.transform(original_image)

        return image, label_idx

    def __len__(self):
        return len(self.file_list)


images_dir = "/Users/enzo/Documents/GitHub/dataset/CIFAR/cifar-100-images/train"
dataset = CustomDataset(images_dir, train=True, transform=transforms.ToTensor())

data_loader = DataLoader(dataset,
                         batch_size=2,
                         shuffle=True,
                         collate_fn=default_collate)


for data in data_loader:
    image, label = data

RandomSampler 与 shuffle=True 的区别

效果完全没有区别,只是实现方式不一样。

  • shuffle=True 的实现方式: 在每个 epoch 开始时将整个数据集打乱,然后按照打乱后的顺序划分 batch。再按照batch_size 个数依次提取数据
  • sampler.BatchSampler(random_sampler) 的实现方式:(数据不会打乱)
    • step 1、RandomSampler 会生成随机的索引。
    • step 2、BatchSampler 根据上面随机出来的索引生成 batch 组。
    • step 3、拿着每个batch 组的索引去取 数据

相同点:

  1. 每个epoch 都会重新打乱
  2. 都不会重复采样,除非你通过参数指定了可以重复采样

其他说明:

  1. shuffle=True 的性能更高一些,而 BatchSampler灵活性更高,因为你可以通过 BatchSampler 设计更复杂的采样方式
  2. 在 Dataloader 中使用 batch_sampler 的常见目的之一,是为了兼容 DistributedSampler,比如:
if args.distributed:
    sampler_train = DistributedSampler(dataset_train)
    sampler_val = DistributedSampler(dataset_val, shuffle=False)
else:
    sampler_train = torch.utils.data.RandomSampler(dataset_train)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)

batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True)

data_loader_train = DataLoader(dataset_train,
                               batch_sampler=batch_sampler_train,
                               collate_fn=utils.collate_fn,
                               )
data_loader_val = DataLoader(dataset_val,
                             args.batch_size,
                             sampler=sampler_val,
                             drop_last=False,
                             collate_fn=utils.collate_fn,
                             )

跑个小例子,看一下 :

import torch
import torch.utils.data.sampler as sampler
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        self.data = [1, 2, 3, 4, 5]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


dataset = MyDataset()

# =============================================
random_sampler = sampler.RandomSampler(data_source=dataset)
batch_sampler = sampler.BatchSampler(random_sampler, batch_size=2, drop_last=False)
dataloader1 = DataLoader(dataset, batch_sampler=batch_sampler)

for epoch in range(3):
    for index, data in enumerate(dataloader1):
        print(index, data)
print('*'*30)

# =============================================
dataloader2 = DataLoader(dataset, batch_size=2, shuffle=True)

for epoch in range(3):
    for index, data in enumerate(dataloader2):
        print(index, data)

在这里插入图片描述

数据处理&数据增强

数据预处理 和 数据增强,我们一般都是使用 torchvision.transforms 模块来完成的。
我敢说,当你掌握了 torchvision.transforms 的使用方法之后,一定在数据预处理 和 数据增强 方面毫无压力。
官网地址 :

https://pytorch.org/vision/stable/transforms.html#others

在这里插入图片描述


简单使用举例:
1、训练阶段

from torchvision.transforms import transforms

my_transform = transforms.Compose([transforms.RandomResizedCrop(img_size),
                                   transforms.RandomHorizontalFlip(),
                                   transforms.ToTensor(),
                                   transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                                        std=[0.229, 0.224, 0.225])])

2、推理阶段

from torchvision.transforms import transforms

my_transform = transforms.Compose([transforms.Resize(original_size*1.143)
                                   transforms.CenterCrop(img_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                                        std=[0.229, 0.224, 0.225])])

注:

  1. 操作顺序 : 几何变换 / 颜色变换 —>> ToTensor() —>> Normalize()
  2. ToTensor()Normalize() 是数据处理的最后2步
  3. 经过 Normalize()处理后,得到的数据,一般可直接输入到模型中使用

1、图像尺寸变换 与 裁剪

1)transforms.Resize

官方文档 : 点击跳转

torchvision.transforms.Resize(size, 
                              interpolation=InterpolationMode.BILINEAR, 
                              max_size=None)

作用:将图像按照指定的插值方式,resize到指定的尺寸。
参数:

  • size: 输出的图像尺寸。可以是元组 (h, w) ,也可以是单个整数。
    • 如果 size 是元组,则输出大小将分别匹配 h, w 的大小
    • 如果 size 是整数,则图像较小的边将被resize 到此数字,并保持宽高比
  • interpolation: 选用如下插值方法将图像 resize 到输出尺寸
    • PIL.Image.NEAREST 最近邻差值
    • PIL.Image.BILINEAR 双线性差值(默认)
    • PIL.Image.BICUBIC 双三次差值
  • max_size :输出图像的较长边的最大值。仅当 size 为单个整数时才支持此功能。如果图像的较长边在根据 size 缩放后大于 max_size,则 size 将被覆盖,使较长边等于 max_size,这时较短边会小于 size。
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

original_img = Image.open('image.jpg')  # https://p.ipic.vip/7pvisy.jpg
print(original_img.size)   # (3280, 1818)

img_1 = transforms.Resize(1500, max_size=None)(original_img)
print(img_1.size)(2706, 1500)   # (2706, 1500)

img_2 = transforms.Resize((1500, 1500))(original_img)
print(img_2.size)   # (1500, 1500)

img_3 = transforms.Resize(1500, max_size=1600)(original_img)
print(img_3.size)   # (1600, 886)


plt.subplot(141)
plt.axis("off")
plt.imshow(original_img)

plt.subplot(142)
plt.axis("off")
plt.imshow(img_1)

plt.subplot(143)
plt.axis("off")
plt.imshow(img_2)

plt.subplot(144)
plt.axis("off")
plt.imshow(img_3)

plt.show()

在这里插入图片描述


2)transforms.CenterCrop

官方文档 : 点击跳转

功能:从图片中心裁剪出尺寸为 size 的图片
参数:

  • size: 所需裁剪的图片尺寸,即输出图像尺寸

注意:

  • 若切正方形,transforms.CenterCrop(100) transforms.CenterCrop((100, 100)),两种写法,效果一样
  • 如果设置的输出的尺寸 大于原图像尺寸,则会在四周补 padding,padding 颜色为黑色(像素值为0)

举例:

from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

original_img = Image.open('image.jpg')  # https://p.ipic.vip/7pvisy.jpg
print(original_img.size)   # (3280, 1818)

img_1 = transforms.CenterCrop(1500)(original_img)
img_2 = transforms.CenterCrop((1500, 1500))(original_img)
img_3 = transforms.CenterCrop((3000, 3000))(original_img)

plt.subplot(141)
plt.axis("off")
plt.imshow(original_img)

plt.subplot(142)
plt.axis("off")
plt.imshow(img_1)

plt.subplot(143)
plt.axis("off")
plt.imshow(img_2)

plt.subplot(144)
plt.axis("off")
plt.imshow(img_3)

plt.show()

在这里插入图片描述


3)transforms.RandomCrop

官方文档 : 点击跳转
功能:

  • 从图片中随机裁剪出尺寸为 size 的图片
  • 如果设置了参数 padding,先添加 padding,再从padding后的图像中随机裁剪出大小为size的图片
    参数:
  • size :所需裁剪的图片尺寸,即输出图像尺寸
  • padding : 设置填充大小
    • padding值形式式为 a 时,上下左右均填充 a 个像素
    • padding值形式式为 (a, b) 时,左右填充 a 个像素,上下填充 b 个像素
    • padding值形式式为 (a, b, c, d) 时,左上右下分别填充 a,b,c,d
  • pad_if_needed :当原图像尺寸小于设置的输出图像尺寸(由参数size指定),是否填充,默认为 False
  • padding_mode :若 pad_if_needed设置为 True,则此参数起作用, 默认值为 “constant”
    • "constant" : 像素值由参数 fill 指定 (默认填充黑色,像素值为0)
    • "edge" : padding 的像素值 为图像边缘像素值
    • "reflect" : 镜像填充,最后一个像素不镜像。([1,2,3,4] --> [3,2,1,2,3,4,3,2])
    • "symmetric" : 镜像填充,最后一个像素也镜像。([1,2,3,4] -->[2,1,1,2,3,4,4,3])
  • fill :指定填充像素值,当 padding_mode 为 constant 时起作用,默认填充黑色,像素值为0

注意:

  • 同时指定参数padding_mode 和 参数fill 时,若 padding_mode 值不为 "constant" ,则 参数fill不起作用。
  • 若指定的输出图像尺寸size 大于输入图像尺寸,并且指定参数 pad_if_needed= False,则会报错类似如下
    在这里插入图片描述

举例:

from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

original_img = Image.open('image.jpg')  # https://p.ipic.vip/7pvisy.jpg
print(original_img.size)   # (3280, 1818)

img_1 = transforms.RandomCrop(1500, padding=500)(original_img)
img_2 = transforms.RandomCrop(3000, pad_if_needed=True, fill=(255, 0, 0))(original_img)
img_3 = transforms.RandomCrop(3000, pad_if_needed=True, padding_mode="symmetric")(original_img)

plt.subplot(141)
plt.axis("off")
plt.imshow(original_img)

plt.subplot(142)
plt.axis("off")
plt.imshow(img_1)

plt.subplot(143)
plt.axis("off")
plt.imshow(img_2)

plt.subplot(144)
plt.axis("off")
plt.imshow(img_3)

plt.show()

在这里插入图片描述

4)transforms.RandomResizedCrop

官方文档 : 点击跳转

torchvision.transforms.RandomResizedCrop(size, 
                                         scale=(0.08, 1.0), 
                                         ratio=(0.75, 1.3333333333333333), 
                                         interpolation=InterpolationMode.BILINEAR)

功能:

  • Step 1 : 将图像进行随机裁剪,裁剪出的图像需满足:
    • 裁剪后的图像面积 占原图像面积的比例 在指定的范围内
    • 裁剪后的图像高宽比 在指定范围内
  • Step 2 :将 Step 1 得到的图像通过指定的方式,进行缩放
    参数:
    • size: 输出的图像尺寸
    • scale: 随机缩放面积比例,默认随机选取 (0.08, 1) 之间的一个数
    • ratio: 随机长宽比,默认随机选取 (0.75, 1.33333 ) 之间的一个数。超过这个比例范围会有明显的失真
    • interpolation: 选用如下插值方法将图像 resize 到输出尺寸
      • PIL.Image.NEAREST 最近邻差值
      • PIL.Image.BILINEAR 双线性差值(默认)
      • PIL.Image.BICUBIC 双三次差值

举例:

from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

original_img = Image.open('image.jpg')  # https://p.ipic.vip/7pvisy.jpg
print(original_img.size)   # (3280, 1818)

img = transforms.RandomResizedCrop(1500)(original_img)

plt.subplot(121)
plt.imshow(original_img)

plt.subplot(122)
plt.imshow(img)

plt.show()

在这里插入图片描述


http://www.kler.cn/a/380011.html

相关文章:

  • 数据集-目标检测系列- 电话 测数据集 call_phone >> DataBall
  • 透明部署、旁路逻辑串联的区别
  • Windows 安装 Docker 和 Docker Compose
  • 单元测试MockitoExtension和SpringExtension
  • 闲谭SpringBoot--ShardingSphere分库分表探究
  • 成为LabVIEW自由开发者
  • js实现漂亮的注册页面(js动态注册页面)
  • 使用 Nginx 部署 Python 项目
  • 【系统设计】高效的分布式系统:使用 Spring Boot 和 Kafka 实现 Saga 模式
  • 【STM32】STM32G431RBT6单片机的几种烧录方式
  • golang函数类型Function Types
  • 废品回收小程序搭建,互联网回收行业的特点
  • 如何更改Android studio的项目存储路径
  • 强网杯-PWN-baby_heap
  • 清单文件 AndroidManifest.xml
  • 操作系统同步机制(锁、信号量等)
  • 基于大数据的热门旅游景点数据分析系统的设计与实现
  • 2-ARM Linux驱动开发-设备树平台驱动
  • 在Android开发中,如何获取手机设备中的所有文件信息?
  • CubeIDE BUG-project‘hello‘has no explict encoding set hello
  • Windows SEH异常处理讨论
  • 【软考】反规范化技术
  • 代码训练营 day55|卡码网98
  • Jenkins找不到maven构建项目
  • H7-TOOL的CAN/CANFD助手增加帧发送成功标识支持, 继续加强完善功能细节
  • 【GESP】C++一级真题练习(202303)luogu-B3835,每月天数