当前位置: 首页 > article >正文

【深度学习】图像分类数据集

图像分类数据集

MNIST数据集是图像分类中广泛使用的数据集之一,但作为基准数据集过于简单。
我们将使用类似但更复杂的Fashion-MNIST数据集

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

d2l.use_svg_display()#设置图表大小,具体实现过程及其底层逻辑见微积分一节

读取数据集

我们可以[通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中]。

# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,

# 并除以255使得所有像素的数值均在0~1之间

trans = transforms.ToTensor()

mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
    
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

这段代码的主要目的是从 torchvision 库中下载并加载 Fashion - MNIST 数据集,同时对数据进行预处理,将图像转换为 PyTorch 张量。
代码主要分为三个部分:定义图像预处理操作、加载训练集数据、加载测试集数据。下面逐行进行详细解释。

1. 定义图像预处理操作

trans = transforms.ToTensor()

  • 功能:创建一个图像预处理的转换对象 transtransforms.ToTensor()torchvision.transforms 模块里的一个类,专门用于将 PIL(Python Imaging Library)图像或者 NumPy 数组(一般是 uint8 类型)转换为 torch.FloatTensor 类型的张量。
  • 转换细节
    - 在转换过程中,会把图像的像素值归一化到 [0.0, 1.0] 范围。例如,原始图像像素值范围是 [0, 255],经过该转换后,像素值会除以 255,变成 [0.0, 1.0] 之间的浮点数。
    - 同时,转换后张量的维度也会发生变化。对于单通道的灰度图像,会从 (H, W)(高度和宽度)变为 (1, H, W);对于三通道的彩色图像,会从 (H, W, C) 变为 (C, H, W),这里 C 代表通道数。

2. 加载训练集数据

mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)

  • 功能:创建一个 FashionMNIST 数据集对象 mnist_train,用于加载 Fashion - MNIST 数据集的训练集部分。
  • 参数解释
    - root="../data":指定数据集的存储路径。若该路径下没有数据集,下载的数据会存于此;若已存在,则直接从该路径加载数据。
    - train=True:表明要加载的是训练集数据。Fashion - MNIST 数据集包含 60,000 张训练图像和 10,000 张测试图像,通过此参数区分加载的是训练集还是测试集。
    - transform=trans:指定对图像数据进行的预处理操作。这里使用之前创建的 trans 对象,即对每个图像应用 ToTensor() 变换,将其转换为张量
    - download=True:如果指定路径下未找到数据集,会自动从网络下载 Fashion - MNIST 数据集。

3. 加载测试集数据

mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

  • 功能:创建一个 FashionMNIST 数据集对象 mnist_test,用于加载 Fashion - MNIST 数据集的测试集部分。
  • 参数解释:与加载训练集的代码基本相同,唯一区别在于 train=False,表示加载的是测试集数据。

Fashion-MNIST由10个类别的图像组成,每个类别由训练数据集(train dataset)中的6000张图像
测试数据集(test dataset)中的1000张图像组成。
因此,训练集和测试集分别包含60000和10000张图像。测试数据集不会用于训练,只用于评估模型性能。

len(mnist_train), len(mnist_test)

在这里插入图片描述
每个输入图像的高度和宽度均为28像素。
数据集由灰度图像组成,其通道数为1。
为了简洁起见,将高度 h h h像素、宽度 w w w像素图像的形状记为 h × w h \times w h×w或( h h h, w w w)。

mnist_train[0][0].shape

在这里插入图片描述
[两个可视化数据集的函数]

Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。
以下函数用于在数字标签索引及其文本名称之间进行转换。

def get_fashion_mnist_labels(labels):  #@save
    """返回Fashion-MNIST数据集的文本标签"""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

列表推导式
[expression for item in iterable]

  • expression:对每个 item 进行操作后得到的结果,它将成为新列表中的一个元素。
  • item:从 iterable 中取出的单个元素。
  • iterable:一个可迭代对象,如列表、元组、字符串等。

示例代码

text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
               'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
labels = [0, 2, 4]
result = [text_labels[int(i)] for i in labels]
print(result)  # 输出: ['t-shirt', 'pullover', 'coat']

我们现在可以创建一个函数来可视化这些样本。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """绘制图像列表"""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 图片张量
            ax.imshow(img.numpy())
        else:
            # PIL图片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

子图坐标轴对象
在 matplotlib 中,一个图形(Figure)可以包含多个子图(Axes),每个子图就是一个独立的绘图区域,子图坐标轴对象(Axes 对象)就代表了这些独立的绘图区域。它可以被看作是一个 “画布”,你可以在这个 “画布” 上进行各种绘图操作,比如绘制线条、散点、柱状图等,还可以设置坐标轴的范围、标签、标题等。

以下是对 show_images 函数的详细解释:

  • def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    • 定义了一个名为 show_images 的函数,用于将一组图像以网格形式展示出来。
    • imgs:是一个包含图像的列表,这些图像可以是 PyTorch 张量,也可以是 PIL(Python Imaging Library)图像对象。
    • num_rows:指定了要展示的图像网格的行数。
    • num_cols:指定了要展示的图像网格的列数。
    • titles:是一个可选参数,类型为列表,用于为每个图像设置对应的标题。如果不提供该参数,则默认不显示标题。
    • scale:同样是可选参数,是一个浮点数,用于调整图像显示的缩放比例,默认值为 1.5。
  • figsize = (num_cols * scale, num_rows * scale):
    • 这行代码根据 num_cols(列数)、num_rows(行数)和 scale(缩放比例)计算出整个图像展示窗口的大小。
    • figsize 是一个元组,第一个元素是窗口的宽度,由列数乘以缩放比例得到;第二个元素是窗口的高度,由行数乘以缩放比例得到。
  • _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
    • num_rowsnum_cols 分别指定了子图的行数和列数,也就是图像网格的布局。
    • figsize=figsize 表示使用之前计算好的窗口大小。
    • subplots 函数返回两个值,第一个是 Figure 对象,这里用 _ 占位表示我们不关心这个返回值;第二个是一个包含所有子图坐标轴对象的数组,赋值给 axes
  • axes = axes.flatten()
    • axes 原本是一个二维数组,因为它对应着 num_rows 行和 num_cols 列的子图布局。
    • flatten 方法将这个二维数组转换为一维数组,这样在后续遍历图像和子图时会更加方便。
  • for i, (ax, img) in enumerate(zip(axes, imgs))
    • zip(axes, imgs)axes 数组(包含所有子图坐标轴对象)和 imgs 列表(包含所有要展示的图像)中的元素一一对应地组合起来。
    • enumerate 函数用于为组合后的元素添加索引,i 就是当前元素的索引。
    • 在每次循环中,ax 代表当前子图的坐标轴对象,img 代表当前要展示的图像。
        if torch.is_tensor(img):
            # 图片张量
            ax.imshow(img.numpy())
        else:
            # PIL图片
            ax.imshow(img)
  • torch.is_tensor(img) 用于判断当前的 img 是否为 PyTorch 张量。
  • 如果是张量,使用 img.numpy() 将其转换为 NumPy 数组,因为 matplotlibimshow 函数更适合处理 NumPy 数组。然后使用 ax.imshow 函数在当前子图上显示图像。
  • 如果不是张量,说明 img 可能是 PIL 图像对象,直接使用 ax.imshow 函数显示该图像。
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
  • ax.axes.get_xaxis() 获取当前子图的 x 轴对象,set_visible(False) 方法将 x 轴设置为不可见。
  • 同理,ax.axes.get_yaxis() 获取当前子图的 y 轴对象,set_visible(False) 方法将 y 轴设置为不可见。这样可以使图像显示更加简洁,只专注于图像内容。
        if titles:
            ax.set_title(titles[i])
  • if titles: 检查是否提供了 titles 列表。
  • 如果提供了,使用 ax.set_title 方法为当前子图设置对应的标题,标题从 titles 列表中根据当前索引 i 取出。
    return axes
  • 最后,函数返回 axes 数组,这个数组包含了所有子图的坐标轴对象。返回它的目的是方便在调用该函数后,对图形进行进一步的操作,例如修改坐标轴属性等。

以下是训练数据集中前[几个样本的图像及其相应的标签]。

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

在这里插入图片描述

读取小批量

为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。
回顾一下,在每次迭代中,数据加载器每次都会[读取一小批量数据,大小为batch_size]。
通过内置数据迭代器,我们可以随机打乱了所有样本,从而无偏见地读取小批量。

batch_size = 256

def get_dataloader_workers():  #@save
    """使用4个进程来读取数据"""
    return 4
#shuffle表示在每个训练周期开始时,对数据集进行随机打乱
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
                             num_workers=get_dataloader_workers())

我们看一下读取训练数据所需的时间。

timer = d2l.Timer()
for X, y in train_iter:
    continue
f'{timer.stop():.2f} sec'

整合所有组件

现在我们[定义load_data_fashion_mnist函数],用于获取和读取Fashion-MNIST数据集。
这个函数返回训练集和验证集的数据迭代器。
此外,这个函数还接受一个可选参数resize,用来将图像大小调整为另一种形状。

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    #trans初始化为一个包含transforms.ToTensor()的列表
    
    if resize:
        trans.insert(0, transforms.Resize(resize))
        #在 trans 列表的开头插入 transforms.Resize(resize) 操作
        
    trans = transforms.Compose(trans)
    #将 trans 列表中的所有变换操作组合成一个完整的变换序列 trans
    
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
        
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))

下面,我们通过指定resize参数来测试load_data_fashion_mnist函数的图像大小调整功能。

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)#X.shape表示张量 X 的形状,X.dtype表示张量 X 中元素的数据类型
    break

在这里插入图片描述


http://www.kler.cn/a/525199.html

相关文章:

  • Visual Studio Code修改terminal字体
  • 架构技能(六):软件设计(下)
  • Ubuntu介绍、与centos的区别、基于VMware安装Ubuntu Server 22.04、配置远程连接、安装jdk+Tomcat
  • 新年快乐!给大家带来了一份 python 烟花代码!
  • Cannot resolve symbol ‘XXX‘ Maven 依赖问题的解决过程
  • Java坦克大战
  • Kafa分区策略实现
  • fpga系列 HDL:XILINX Vivado Vitis 高层次综合(HLS) 实现 EBAZ板LED控制(下)
  • 前端力扣刷题 | 2:hot100之 双指针
  • Web3 如何赋能元宇宙,实现虚实融合的无缝对接
  • 论“0是不存在的”
  • H3CNE-27-链路聚合(L3)
  • 使用shell命令安装virtualbox的虚拟机并导出到vagrant的Box
  • 正则表达式入门
  • DeepSeek的崛起与全球科技市场的震荡
  • C++并发编程指南03
  • 【JavaWeb】利用IntelliJ IDEA 2024.1.4 +Tomcat10 搭建Java Web项目开发环境(图文超详细)
  • 商品信息管理自动化测试
  • 落地基于特征的图像拼接
  • 研发的立足之本到底是啥?
  • 跨平台物联网漏洞挖掘算法评估框架设计与实现文献综述之Gemini
  • 我的求职之路合集
  • zsh安装插件
  • Vue演练场基础知识(七)插槽
  • sentence_transformers安装
  • BGP分解实验·15——路由阻尼(抑制/衰减)实验