当前位置: 首页 > article >正文

第十三站:卷积神经网络(CNN)的优化

前言:在上一期我们构建了基本的卷积神经网络之后,接下来我们将学习一些提升网络性能的技巧和方法。这些优化技术包括 数据增强网络架构的改进正则化技术

1. 数据增强(Data Augmentation)

数据增强是提升深度学习模型泛化能力的一种常见手段。通过对训练数据进行各种随机变换,可以生成更多的训练样本,帮助模型避免过拟合。

常见的数据增强方法:
  1. 旋转(Rotation):随机旋转图像,增强模型对旋转变换的鲁棒性。
  2. 翻转(Flipping):随机水平或垂直翻转图像。
  3. 裁剪(Cropping):随机裁剪图像的某一部分。
  4. 平移(Translation):对图像进行随机平移。
  5. 改变亮度、对比度、饱和度(Brightness, Contrast, Saturation):改变图像的光照和颜色,使模型更加鲁棒。
代码示例:使用 PyTorch 实现数据增强
from torchvision import transforms

# 定义数据增强的变换过程
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # 随机水平翻转
    transforms.RandomRotation(15),  # 随机旋转,角度范围为 -15 到 15 度
    transforms.RandomCrop(32, padding=4),  # 随机裁剪,裁剪大小为 32,边缘加 4 像素的填充
    transforms.ToTensor(),  # 转换为张量
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化
])
  • transforms.RandomHorizontalFlip():随机水平翻转图像,有助于增加训练样本的多样性。
  • transforms.RandomRotation(15):随机旋转图像,旋转角度在 -15 到 15 度之间。
  • transforms.RandomCrop(32, padding=4):随机裁剪图像的部分区域并填充边缘,以获得不同的视角。
  • transforms.ToTensor():将图像从 PIL 格式转换为 PyTorch 的 Tensor 格式。
  • transforms.Normalize():对图像进行标准化,使其均值和标准差分别为 0.5。
2. 网络架构的改进

卷积神经网络可以通过调整网络的层数、卷积核大小、通道数等来改进其性能。以下是一些常见的改进方式:

  1. 增加卷积层的数量

    • 更深的网络能够提取更多的特征信息。通过增加卷积层数,可以让网络学习到更高级别的特征。
  2. 增加卷积核的数量

    • 增加每个卷积层中的卷积核数量(通道数),使得每个卷积层能够提取更多的特征。
  3. 使用较大的卷积核

    • 使用 5x5 或 7x7 的卷积核比 3x3 的卷积核能捕获更大的特征区域,但会增加计算量。
  4. 使用深度可分离卷积(Depthwise Separable Convolution)

    • 深度可分离卷积通过将卷积操作拆解为两步(深度卷积和逐点卷积),减少了参数量和计算量。
  5. 使用更高级的激活函数

    • 例如,Leaky ReLUELU(Exponential Linear Unit) 可以避免 ReLU 激活函数的“死神经元”问题。
代码示例:添加更多卷积层
class ImprovedCNN(nn.Module):
    def __init__(self):
        super(ImprovedCNN, self).__init__()
        # 第一层卷积层
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        # 第二层卷积层
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # 第三层卷积层
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        
        # 池化层
        self.pool = nn.MaxPool2d(2, 2)
        
        # 全连接层
        self.fc1 = nn.Linear(128 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)
    
    def forward(self, x):
        # 卷积层 + 激活函数 + 池化
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        
        # 展平数据
        x = x.view(-1, 128 * 8 * 8)
        
        # 全连接层
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
  • 增加了更多的卷积层:从 16 个卷积核增加到 128 个卷积核,可以捕捉更复杂的特征。
  • 通过池化层减小图像尺寸:每经过一个卷积层后,都通过池化层来降低特征图的维度。
3. 正则化技术(Regularization Techniques)

正则化是防止模型过拟合的关键。以下是几种常见的正则化技术:

  1. Dropout

    • Dropout 随机丢弃一部分神经元,避免模型依赖于某些特定神经元,增加模型的泛化能力。
  2. L2 正则化(权重衰减)

    • 在损失函数中加入权重的平方和(L2 范数),惩罚模型中的大权重,防止模型变得过于复杂。
代码示例:加入 Dropout 层
class CNNWithDropout(nn.Module):
    def __init__(self):
        super(CNNWithDropout, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)
        self.dropout = nn.Dropout(p=0.5)  # Dropout 层,丢弃 50% 的神经元
    
    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 32 * 8 * 8)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)  # 在全连接层前应用 Dropout
        x = self.fc2(x)
        return x
  • self.dropout = nn.Dropout(p=0.5):在全连接层之前添加 Dropout 层,丢弃一半神经元。

注:针对于以上更多修改,大家可以修改参数调试观察更多不同的效果,从而使得自己有一个对优化的大概了解


http://www.kler.cn/a/569606.html

相关文章:

  • Elasticsearch 的分布式架构原理:通俗易懂版
  • Linux的OOM机制
  • Ubuntu 下 nginx-1.24.0 源码分析 - ngx_destroy_pool 函数
  • LSTM预测模型复现笔记和问题记录
  • 第10篇:文件IO与数据持久化(下)(JSON、二进制文件)
  • Junit框架缺点
  • 神经网络之词嵌入模型(基于torch api调用)
  • Vue3 中 defineOptions 学习指南
  • Docker-CE的部署、国内镜像加速
  • Redis(八):Redis分布式锁实现
  • 深入了解 K-Means 聚类算法:原理与应用
  • 介绍 torch-mlir 从 pytorch 生态到 mlir 生态
  • Android Binder 用法详解
  • 智能AI替代专家系统(ES)、决策支持系统(DSS)?
  • SpringDoc和Swagger使用
  • 深入理解并解析C++ stl::vector
  • MySQL 中如何查看 SQL 的执行计划?
  • 部署Joplin私有云服务器postgres版-docker compose
  • 1JVM概念
  • C# 上位机---INI 文件