当前位置: 首页 > article >正文

【基于深度学习的验证码识别】---- part3数据加载、模型等API介绍(1)

一、MNIST数据集

MNIST(Modified National Institute of Standards and Technology)数据集是计算机视觉和机器学习领域最经典的入门级数据集之一,主要用于手写数字识别任务。

使用示例(以PyTorch为例)

from torchvision.datasets import MNIST
mnist_train = MNIST(root='./MNIST_data', train=True, download=True)

在这里插入图片描述

from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
mnist_train = MNIST(root='./MNIST_data', train=True, download=True)

# 训练集长度
print(len(mnist_train))
# 取第一个图片
print(mnist_train[0])
image = mnist_train[5000][0]
# 打印出图片
plt.imshow(image)
plt.show()
print(mnist_train[5000][1])

二、数据加载

在PyTorch中,使用DataLoader加载MNIST数据集时,参数的合理配置直接影响训练效率和模型性能。以下是核心参数的详细说明及其在MNIST场景中的应用:

from torch.utils.data import DataLoader

参数:batch_size、shuffle、num_workers、pin_memory、drop_last

1、batch_size(批次大小)
  • 定义:每个批次包含的样本数量。例如,batch_size=64表示每次迭代加载64张图像。
  • 作用:定义每个批次包含的样本数量。例如,若batch_size=64,则每次迭代从数据集中加载64张手写数字图像。
  • MNIST应用
    MNIST图像尺寸为28x28,单个样本数据量小,通常可设置较大的batch_size(如64或128)以充分利用显存并加速训练。
    显存不足时需减小batch_size,否则会引发内存错误(OOM)
2、 shuffle(数据打乱)
  • 定义:是否在每个训练周期(epoch)开始时随机打乱数据顺序。
  • 作用
    • 防止模型偏见:避免模型学习到数据顺序特征(如MNIST训练集需设为True)。
    • 测试集处理:测试时通常设为False以保持评估结果一致性。
  • MNIST应用
    # 训练集打乱,测试集不打乱
    train_loader = DataLoader(..., shuffle=True)
    test_loader = DataLoader(..., shuffle=False)
    
3、 num_workers(子进程数)
  • 定义:用于并行加载数据的子进程数量。默认为0(主进程加载)。
  • 作用
    • 加速数据加载:多进程并行读取数据(建议设为CPU核心数的2~4倍,如4或8)。
    • 资源平衡:MNIST数据量小,过高值可能导致内存溢出(需实验调优)。
  • MNIST应用
    # 使用4个子进程加载数据
    train_loader = DataLoader(..., num_workers=4)
    
4、pin_memory(内存锁定)
  • 定义:是否将数据复制到CUDA固定内存(pinned memory)。
  • 作用
    • 加速GPU传输:启用后,数据从CPU到GPU的传输速度更快(GPU训练时强烈建议设为True)。
    • 资源占用:仅对GPU有效,CPU训练时可忽略。
  • MNIST应用
  # GPU训练时启用内存锁定
  train_loader = DataLoader(..., pin_memory=True)
5、 drop_last(丢弃末批)
  • 定义:当数据集大小无法被batch_size整除时,是否丢弃最后一个不完整批次。
  • 作用
    • 避免小批次影响:丢弃末尾样本(如MNIST训练集60000样本,batch_size=64时最后一个批次含16样本)。
    • 分布式训练对齐:需所有批次大小一致时启用。
  • MNIST应用
    # 丢弃不完整批次
    train_loader = DataLoader(..., batch_size=64, drop_last=True)
    

代码示例

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# 创建 DataLoader
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)

三、图片处理 transform

在深度学习中,图像数据通常需要进行预处理(如缩放、裁剪、归一化等)以适应模型的输入要求。PyTorch 提供了 torchvision.transforms 模块,用于定义和实现这些图像处理操作。

transforms 的作用

transforms 是一个用于图像预处理的工具集,可以将一系列图像处理操作组合在一起,形成一个处理流水线(pipeline)。这些操作通常包括:

  • 数据增强:增加数据的多样性,防止模型过拟合。
  • 数据标准化:将数据转换为模型所需的格式(如归一化到特定范围)。
  • 数据转换:将图像转换为张量(Tensor)格式,以便输入模型。
常用 transforms 操作
1、基础操作
  • Resize: 调整图像大小。
transforms.Resize((height, width))  # 将图像调整为指定大小
  • CenterCrop: 从图像中心裁剪指定大小的区域。
transforms.CenterCrop(size)  # 裁剪大小为 (size, size)
  • RandomCrop: 随机裁剪图像。
transforms.RandomCrop(size)  # 随机裁剪大小为 (size, size)
  • RandomHorizontalFlip: 随机水平翻转图像。
transforms.RandomHorizontalFlip(p=0.5)  # 以 50% 的概率水平翻转
  • RandomRotation: 随机旋转图像。
transforms.RandomRotation(degrees=30)  # 随机旋转 ±30 度
2、 颜色变换
  • ColorJitter: 随机改变图像的亮度、对比度、饱和度和色调。
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
  • Grayscale: 将图像转换为灰度图。
transforms.Grayscale(num_output_channels=1)  # 转换为单通道灰度图
3、 归一化和标准化
  • ToTensor: 将图像(PIL 或 NumPy 格式)转换为 PyTorch 张量(Tensor),并将像素值从 [0, 255] 缩放到 [0, 1]。
transforms.ToTensor()

在使用 transforms.ToTensor() 处理图像后,PyTorch 会将图像的通道维度移动到最前面。
transforms.ToTensor() 的作用
1.将图像转换为张量:
输入的图像通常是 PIL 图像或 NumPy 数组,形状为 (H, W, C),其中:
H 是图像的高度(Height)。
W 是图像的宽度(Width)。
C 是图像的通道数(Channels,例如 RGB 图像为 3,灰度图像为 1)。
transforms.ToTensor() 会将图像转换为 PyTorch 张量(Tensor),并将像素值从 [0, 255] 缩放到 [0, 1]。

2通道维度的变化:
转换后的张量形状为 (C, H, W),即通道维度被移动到最前面。
这种格式是 PyTorch 的标准输入格式,便于后续的模型处理。

  • Normalize: 对图像进行标准化处理(减去均值,除以标准差)。
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

这里的均值和标准差通常是根据数据集计算的(例如 ImageNet 的均值和标准差)。

4、 组合操作
  • Compose: 将多个操作组合成一个流水线。
 transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
示例代码
from torchvision import datasets, transforms

# 定义 transforms 流水线
transform = transforms.Compose([
    transforms.Resize((32, 32)),          # 调整图像大小为 32x32
    transforms.RandomHorizontalFlip(),    # 随机水平翻转
    transforms.ToTensor(),                # 转换为张量,并缩放到 [0, 1]
    transforms.Normalize((0.5,), (0.5,))  # 归一化到 [-1, 1]
])

# 加载 MNIST 数据集并应用 transforms
train_dataset = datasets.MNIST(
    root='./data', 
    train=True, 
    download=True, 
    transform=transform
)

# 查看处理后的图像
image, label = train_dataset[0]
print(image.shape)  # 输出: torch.Size([1, 32, 32])
总结
操作作用
Resize调整图像大小。
CenterCrop从图像中心裁剪指定大小的区域。
RandomCrop随机裁剪图像。
RandomHorizontalFlip随机水平翻转图像。
RandomRotation随机旋转图像。
ColorJitter随机改变图像的亮度、对比度、饱和度和色调。
Grayscale将图像转换为灰度图。
ToTensor将图像转换为张量,并缩放到 [0, 1]。
Normalize对图像进行标准化处理(减去均值,除以标准差)。
Compose将多个操作组合成一个流水线。
如何在数据加载过程中看到图片的样子

先轴交换,再利用make_grid合并再处理成数组.numpy()后,就可以展示出来

from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid

my_transforms = transforms.Compose(
    [transforms.PILToTensor(),
     ]
)
mnist_train = MNIST(root='./MNIST_data', train=True, download=True, transform=transforms.PILToTensor())
dataloader = DataLoader(mnist_train, batch_size=5, shuffle=True) #DataLoader 初始化
for (image, label) in dataloader:# 遍历 DataLoader
    print(image.shape) #torch.Size([5, 1, 28, 28])
    print(label) #tensor([3, 1, 2, 8, 3])
    print(make_grid(image).shape)  #torch.Size([3, 32, 152])  使用 make_grid 将图像拼接成网格
    image = make_grid(image).permute(1,2,0).numpy()#调整网格图像的维度并转换为 NumPy 数组
    plt.imshow(image) #使用 Matplotlib 显示图像
    plt.show()
    exit()

http://www.kler.cn/a/592851.html

相关文章:

  • Java:Apache HttpClient中HttpRoute用法的介绍
  • Unity导出WebGL,无法加载,data文件无法找到 404(NotFound)
  • 「数据会说话」:让AI成为你的数据分析魔法师 ✨
  • 从零开始写C++3D游戏引擎(开发环境VS2022+OpenGL)之十一点二五 光照贴图(lighting maps)的实现 细嚼慢咽逐条读代码系列
  • 实现鼠标键盘动作录制与重复播放的工具
  • SQL Server数据库慢SQL调优
  • 从 Snowflake 到 Databend Cloud:全球游戏平台借助 Databend 实现实时数据处理
  • 【K8S】ImagePullBackOff状态问题排查。
  • 网络安全——SpringBoot配置文件明文加密
  • 如何把视频转成动态图?
  • 计算机网络-IPv6
  • 2025/03/19 Cursor使用方法(Java方向,适合Java后端把家从idea搬家到cursor)
  • DeepSORT 目标追踪算法详解
  • 数据结构-------栈
  • Java学习打卡-Day19-Set、HashSet、LinkedHashSet
  • C++学习之QT实现取证小软件首页
  • 施耐德PLC仿真软件Modbus tcp通讯测试
  • Python实现爬虫:天气数据抓取(+折线图)
  • 【软件工程】02_软件生命周期模型
  • 【C++入门】数组:从基础到实践