当前位置: 首页 > article >正文

机器学习详解(13):CNN图像数据增强(解决过拟合问题)

在之前的文章卷积神经网络CNN之手语识别代码详解中,我们发现最后的训练和验证损失的曲线的波动非常大,而且验证集的准确率仍然落后于训练集的准确率,这表明模型出现了过拟合现象:在验证数据集测试时,模型对未见过的数据表现不佳。

文章目录

  • 1 数据增强
  • 2 代码流程
    • 2.1 数据准备
    • 2.2 模型创建
      • 2.2.1 卷积块类
      • 2.2.2 使用自定义模块构建模型
    • 2.3 数据增强
      • 2.3.1 RandomResizedCrop
      • 2.3.2 RandomHorizontalFlip
      • 2.3.3 RandomRotation
      • 2.3.4 ColorJitter
      • 2.3.5 Compose
      • 2.3.6 总结
    • 2.4 训练
    • 2.5 保存模型
  • 3 总结

1 数据增强

为了提高模型在新数据上的鲁棒性,我们计划通过编程方式增加数据集的规模和多样性。这被称为数据增强(data augmentation),是深度学习应用中的一种常用技术。

  1. 数据集规模的增加为模型训练提供了更多的图像样本;
  2. 数据的多样性增加有助于模型忽略不重要的特征,仅选择对分类真正重要的特征,从而提高模型的泛化能力。

本篇文章步骤如下:

  • 对ASL(美式手语)数据集进行数据增强;
  • 使用增强后的数据训练改进模型;
  • 保存训练好的模型以便部署使用。

2 代码流程

和之前的文章一样,我们读取ASL手语数据文件,然后进行标签&数据分离、归一化等操作,然后将数据集分为训练和验证两个部分。

使用到的库如下:

import torch.nn as nn
import pandas as pd
import torch
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.v2 as transforms
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt

import utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.is_available()

import torch._dynamo
torch._dynamo.config.suppress_errors = True

2.1 数据准备

这里不再详细说明这里的代码,有不懂的可以参考下面的注释或文章开头提到的文章。

# 定义图像的维度和分类数量
IMG_HEIGHT = 28       # 输入图像的高度
IMG_WIDTH = 28        # 输入图像的宽度
IMG_CHS = 1           # 图像的通道数(1表示灰度图像)
N_CLASSES = 24        # 输出的类别数量

# 从CSV文件加载训练集和验证集
train_df = pd.read_csv("sign_mnist_train.csv")  # 训练集数据
valid_df = pd.read_csv("sign_mnist_valid.csv")  # 验证集数据

# 自定义数据集类
class MyDataset(Dataset):
    def __init__(self, base_df):
        # 拷贝基础数据框并分离标签和特征
        x_df = base_df.copy()                  # 复制数据框,避免修改原始数据
        y_df = x_df.pop('label')               # 提取标签列
        x_df = x_df.values / 255               # 将图像像素值归一化到0到1之间
        x_df = x_df.reshape(-1, IMG_CHS, IMG_WIDTH, IMG_HEIGHT)  # 重塑为适配模型的形状
        self.xs = torch.tensor(x_df).float().to(device)  # 转换为张量并移动到设备(如GPU)
        self.ys = torch.tensor(y_df).to(device)          # 标签转为张量并移动到设备

    def __getitem__(self, idx):
        # 获取指定索引的图像和标签
        x = self.xs[idx]  # 获取图像数据
        y = self.ys[idx]  # 获取标签数据
        return x, y

    def __len__(self):
        # 返回数据集的大小
        return len(self.xs)

# 定义批次大小
n = 32

# 初始化训练数据集和数据加载器
train_data = MyDataset(train_df)                       # 创建训练数据集
train_loader = DataLoader(train_data, batch_size=n, shuffle=True)  # 创建训练数据加载器并打乱顺序
train_N = len(train_loader.dataset)                    # 获取训练集的样本数量

# 初始化验证数据集和数据加载器
valid_data = MyDataset(valid_df)                       # 创建验证数据集
valid_loader = DataLoader(valid_data, batch_size=n)    # 创建验证数据加载器
valid_N = len(valid_loader.dataset)                    # 获取验证集的样本数量

2.2 模型创建

2.2.1 卷积块类

通过前面的学习,我们知道卷积神经网络(CNN)使用重复的层序列。我们可以利用这一模式,创建一个自定义卷积模块,并将其作为一个层嵌入到Sequential模型中。

为了实现这一目标,我们将扩展Module类,并定义两个方法:

  1. __init__ 方法:定义模块的属性,包括我们需要的神经网络层。在这里,我们可以实现“模型中的模型”。
  2. forward 方法:定义模块如何处理来自前一层的输入数据。由于我们使用Sequential模型,可以直接将输入数据传入模块,类似于执行预测。
class MyConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, dropout_p):
        kernel_size = 3  # 卷积核大小
        super().__init__()

        # 定义模块内的层结构
        self.model = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size, stride=1, padding=1),  # 卷积层
            nn.BatchNorm2d(out_ch),  # 批归一化
            nn.ReLU(),               # 激活函数
            nn.Dropout(dropout_p),   # Dropout防止过拟合
            nn.MaxPool2d(2, stride=2)  # 最大池化层
        )

    def forward(self, x):
        # 定义数据流动逻辑
        return self.model(x)

2.2.2 使用自定义模块构建模型

现在我们定义了自定义模块,可以将其应用到实际模型中:

flattened_img_size = 75 * 3 * 3  # 扁平化后图像的大小

# 构建Sequential模型
base_model = nn.Sequential(
    MyConvBlock(IMG_CHS, 25, 0),     # 输入:1x28x28,输出:25x14x14
    MyConvBlock(25, 50, 0.2),       # 输出:50x7x7
    MyConvBlock(50, 75, 0),         # 输出:75x3x3
    nn.Flatten(),                   # 扁平化
    nn.Linear(flattened_img_size, 512),  # 全连接层
    nn.Dropout(0.3),                # Dropout
    nn.ReLU(),                      # 激活函数
    nn.Linear(512, N_CLASSES)       # 输出层
)

损失函数和优化器如下:

loss_function = nn.CrossEntropyLoss() 	   # 定义交叉熵损失函数
optimizer = Adam(base_model.parameters())  # 定义Adam优化器

最后编译模型,PyTorch将对模型做出优化:

model = torch.compile(base_model.to(device))  # 使用PyTorch 2.0编译模型
model

输出:
OptimizedModule(
  (_orig_mod): Sequential(
    (0): MyConvBlock(
      (model): Sequential(
        (0): Conv2d(1, 25, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Dropout(p=0, inplace=False)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (1): MyConvBlock(
      (model): Sequential(
        (0): Conv2d(25, 50, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Dropout(p=0.2, inplace=False)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (2): MyConvBlock(
      (model): Sequential(
        (0): Conv2d(50, 75, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(75, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Dropout(p=0, inplace=False)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (3): Flatten(start_dim=1, end_dim=-1)
    (4): Linear(in_features=675, out_features=512, bias=True)
    (5): Dropout(p=0.3, inplace=False)
    (6): ReLU()
    (7): Linear(in_features=512, out_features=24, bias=True)
  )
)

2.3 数据增强

在定义训练循环之前,需要先设置数据增强(Data Augmentation)。我们将通过 TorchVision 的 Transforms 工具进一步探索数据增强的方法。以下是具体的步骤:

获取测试图像

首先,从数据集中提取一张示例图片用于测试:

row_0 = train_df.head(1)
y_0 = row_0.pop('label')
x_0 = row_0.values / 255
x_0 = x_0.reshape(IMG_CHS, IMG_WIDTH, IMG_HEIGHT)
x_0 = torch.tensor(x_0)
x_0.shape  # 输出torch.Size([1, 28, 28])

image = F.to_pil_image(x_0)
plt.imshow(image, cmap='gray')

图片如下:

在这里插入图片描述

2.3.1 RandomResizedCrop

  • 功能:随机调整图像大小并裁剪为指定尺寸。
  • 原理:根据比例调整输入图像大小后裁剪。

代码示例

# 定义随机调整大小并裁剪的变换
trans = transforms.Compose([
    transforms.RandomResizedCrop((IMG_WIDTH, IMG_HEIGHT), scale=(.7, 1), ratio=(1, 1)),
])

# 应用变换并生成新的张量
new_x_0 = trans(x_0)
# 转换为 PIL 图像以便显示
image = F.to_pil_image(new_x_0)
plt.imshow(image, cmap='gray')  # 显示裁剪后的图像
new_x_0.shape  # 输出torch.Size([1, 28, 28])
  • scale=(.7, 1):指定裁剪区域相对于原始图像面积的比例范围。

    • .7 表示裁剪区域最小可以是原始图像面积的 70%。

    • 1 表示裁剪区域最大可以是原始图像的 100%。

  • ratio=(1, 1):指定裁剪区域的宽高比范围。

    • (1, 1) 表示裁剪区域的宽高比固定为 1:1,即裁剪区域总是正方形

输出如下:

在这里插入图片描述

2.3.2 RandomHorizontalFlip

功能:随机水平翻转图像。

思考

  • 水平翻转可以增强数据多样性,但垂直翻转可能破坏语义。
  • 例如,美国手语(ASL)可以左右手互换,但上下翻转的情况很少出现。
# 定义随机水平翻转的变换
trans = transforms.Compose([
    transforms.RandomHorizontalFlip()
])

# 应用变换并生成新的张量
new_x_0 = trans(x_0)
# 转换为 PIL 图像以便显示
image = F.to_pil_image(new_x_0)
plt.imshow(image, cmap='gray')  # 显示翻转后的图像

输出:

在这里插入图片描述

2.3.3 RandomRotation

功能:随机旋转图像以增加多样性。

注意

  • 旋转角度过大可能导致图像的语义信息被误读。
  • 例如,手语中的“D”可能被错误解读为“G”。因此,需限制旋转范围。

代码解释

# 定义随机旋转的变换,限制在 ±10 度内
trans = transforms.Compose([
    transforms.RandomRotation(10)
])

# 应用变换并生成新的张量
new_x_0 = trans(x_0)
# 转换为 PIL 图像以便显示
image = F.to_pil_image(new_x_0)
plt.imshow(image, cmap='gray')  # 显示旋转后的图像

输出:

在这里插入图片描述

2.3.4 ColorJitter

功能:随机调整图像的亮度和对比度。用来模拟光照条件的变化

参数

  • brightness:控制亮度变化的幅度(0~1)。
  • contrast:控制对比度变化的幅度(0~1)。

代码解释

# 设置亮度和对比度参数
brightness = .2  # 亮度变化范围
contrast = .5  # 对比度变化范围

# 定义亮度和对比度调整的变换
trans = transforms.Compose([
    transforms.ColorJitter(brightness=brightness, contrast=contrast)
])

# 应用变换并生成新的张量
new_x_0 = trans(x_0)
# 转换为 PIL 图像以便显示
image = F.to_pil_image(new_x_0)
plt.imshow(image, cmap='gray')  # 显示调整后的图像

输出:

在这里插入图片描述

2.3.5 Compose

功能:组合多个随机变换为一个序列,使数据增强更加灵活。

应用场景:结合多种增强手段,生成无限多样的数据变体。

代码解释

# 定义一组随机变换
random_transforms = transforms.Compose([
    transforms.RandomRotation(5),  # 小范围旋转
    transforms.RandomResizedCrop((IMG_WIDTH, IMG_HEIGHT), scale=(.9, 1), ratio=(1, 1)),  # 随机裁剪
    transforms.RandomHorizontalFlip(),  # 随机水平翻转
    transforms.ColorJitter(brightness=.2, contrast=.5)  # 调整亮度和对比度
])

# 应用组合变换并生成新的张量
new_x_0 = random_transforms(x_0)
# 转换为 PIL 图像以便显示
image = F.to_pil_image(new_x_0)
plt.imshow(image, cmap='gray')  # 显示综合增强后的图像

输出:

在这里插入图片描述

2.3.6 总结

增强方法功能关键参数作用注意事项
RandomResizedCrop随机调整图像大小并裁剪为指定尺寸- scale=(.7, 1): 裁剪区域面积占原图面积的比例范围, ratio=(1, 1): 裁剪区域宽高比范围 (正方形)生成不同大小的随机裁剪区域,增强图像的局部视角多样性如果设置不当,可能裁剪掉重要的图像信息
RandomHorizontalFlip随机水平翻转图像- p=0.5(默认):翻转的概率增强数据多样性,适合左右对称的场景,如手语、自然景物等不适用于语义上不对称的图像(如文字)
RandomRotation随机旋转图像- degrees=10: 旋转角度范围(±10度)增加图像的角度变化,模拟不同视角的情况旋转角度过大可能引起语义混淆,例如手语符号“D”和“G”可能被误读
ColorJitter调整图像亮度和对比度brightness=0.2: 亮度变化范围, contrast=0.5: 对比度变化范围模拟光照变化条件,增强数据在不同环境下的适应能力如果参数设置过大,可能导致图像失真

2.4 训练

我们的训练流程与之前大致相同,但有一处不同:在将图像输入模型之前,需应用前面我们定义的 random_transforms 进行数据增强。

def train():
    loss = 0  # 初始化损失值
    accuracy = 0  # 初始化准确率

    model.train()  # 设置模型为训练模式
    for x, y in train_loader:  # 遍历训练数据
        output = model(random_transforms(x))  # 应用随机数据增强并将图像输入模型
        optimizer.zero_grad()  # 清除梯度
        batch_loss = loss_function(output, y)  # 计算当前批次的损失
        batch_loss.backward()  # 反向传播计算梯度
        optimizer.step()  # 更新模型权重

        # 累加损失和准确率
        loss += batch_loss.item()
        accuracy += get_batch_accuracy(output, y, train_N)
    # 打印训练结果
    print('Train - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))

下面是训练和验证集的损失曲线和准确率曲线:

在这里插入图片描述

在这里插入图片描述

对比之前的损失曲线,明显在一定程度上优化了过拟合的问题。

2.5 保存模型

现在我们已经拥有一个训练良好的模型,可以将其部署用于对新图像进行推断。当我们对训练的模型感到满意时,通常会将其保存到磁盘。PyTorch 提供了多种方法来实现这一点,这里我们将使用torch.save方法。在下一篇文章中,我们会加载该模型,并使用它来识别新的手语图片。

需要注意的是,PyTorch无法保存已编译的模型(参考文章)

  • 在 PyTorch 中,model(编译后的模型)和base_model(未编译的模型)是共享权重的。因此,无论你对 model 进行训练,还是对其权重进行修改,这些变化都会同步到 base_model,因为它们底层实际上引用了同一组权重。

因此我们将使用以下代码保存模型:

torch.save(base_model, 'model.pth')

3 总结

数据增强通过随机裁剪、翻转、旋转和颜色抖动等手段,生成多样化的训练样本,有效提高了模型的泛化能力,减少了过拟合的风险。在训练阶段引入随机变换,可以模拟不同的场景变化,使模型更具鲁棒性。同时,通过合理设置增强参数,确保数据增强不会破坏图像的语义信息,从而在性能与稳定性之间取得平衡。


http://www.kler.cn/a/467210.html

相关文章:

  • 【数学建模笔记】评价模型-基于熵权法的TOPSIS模型
  • Nginx (40分钟学会,快速入门)
  • 在Linux中,如何查看和修改网络接口配置?
  • 单片机-串转并-74HC595芯片
  • 成都和力九垠科技有限公司九垠赢系统Common存在任意文件上传漏洞
  • 旧服务改造及微服务架构演进
  • [读书日志]从零开始学习Chisel 第一篇:书籍介绍,Scala与Chisel概述,Scala安装运行(敏捷硬件开发语言Chisel与数字系统设计)
  • 博时基金宋和文:以责任之名,共筑公益梦想
  • 每日十题八股-2025年1月4日
  • 第 31 章 - 源码篇 - Elasticsearch 写入流程深入分析
  • 【学习总结|DAY027】JAVA操作数据库
  • Flume拦截器的实现
  • 笔记本电脑扩展的显示器如何左右或上下分屏显示?
  • springboot使用hutool captcha +vue实现图形验证码
  • 智能新纪元:代理AI的崛起与未来
  • ArcGIS JSAPI 高级教程 - 通过RenderNode实现视频融合效果(不借助三方工具)
  • 【GeekBand】C++设计模式笔记24_Visitor_访问器
  • 爬虫案例-爬取某度文档
  • 洛谷B4071 [GESP202412 五级] 武器强化
  • Java 数据库连接 - Sqlite
  • 解决openpyxl操纵带公式的excel或者csv之后,pandas无法读取数值的问题
  • 基于PHP+MySQL实现的web端借还书系统
  • android studio老版本下载教程
  • 【AI学习】Transformer深入学习(二):从MHA、MQA、GQA到MLA
  • 阿里云-通义灵码:在 PyCharm 中的强大助力(下)
  • 急需升级,D-Link 路由器漏洞被僵尸网络广泛用于 DDoS 攻击