当前位置: 首页 > article >正文

Pytorch | 从零构建AlexNet对CIFAR10进行分类

Pytorch | 从零构建AlexNet对CIFAR10进行分类

  • CIFAR10数据集
  • AlexNet
    • 网络结构
    • 技术创新点
    • 性能表现
    • 影响和意义
  • AlexNet结构代码详解
    • 结构代码
    • 代码详解
      • 特征提取层 self.features
      • 分类部分self.classifier
      • 前向传播forward
  • 训练过程和测试结果
  • 代码汇总
    • alexnet.py
    • train.py
    • test.py

CIFAR10数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:

  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
  • 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。

下面是一些示例样本:
在这里插入图片描述

AlexNet

AlexNet是由Alex Krizhevsky、Ilya Sutskever和Geoffrey Hinton在2012年提出的一种深度卷积神经网络,在ImageNet图像识别挑战赛中取得了巨大成功,推动了深度学习在计算机视觉领域的快速发展。以下是对它的详细介绍:

网络结构

  • 卷积层:包含5个卷积层,这些卷积层通过不同的卷积核大小、步长和填充方式,逐步提取图像的特征。
  • 池化层:有3个最大池化层,用于减小特征图的尺寸,同时保留关键特征,减少计算量和过拟合风险。
  • 全连接层:包括3个全连接层,用于对提取的特征进行分类,最后一层输出分类结果。
    在这里插入图片描述
    上图为AlexNet原文中的网络结构(针对ImageNet,图片尺寸为224×224),本文是针对CIFAR10,其尺寸为32×32,因此结构不太相同,比如卷积核的大小,具体可以参考下面的代码。

技术创新点

  • ReLU激活函数:使用ReLU(Rectified Linear Unit)作为激活函数,解决了传统激活函数在深度网络中梯度消失的问题,加快了训练速度。
  • Dropout正则化:在全连接层中使用了Dropout技术,随机丢弃部分神经元,防止过拟合,提高模型的泛化能力。
  • 重叠池化:采用重叠池化(Overlapping Pooling),即池化窗口之间有重叠,有助于提取更多的特征信息,提升模型的性能。
  • 多GPU训练:首次利用多GPU进行并行训练,大大提高了训练速度,使得在大规模数据集上训练深度网络成为可能。

性能表现

  • 在ImageNet数据集上,AlexNet的top-5错误率大幅降低至15.3%,相比之前的方法有了显著提升,展示了其强大的图像识别能力。
  • 能够学习到丰富的图像特征,对不同类别的物体具有很好的区分能力,在实际应用中取得了很好的效果。

影响和意义

  • 推动深度学习发展:AlexNet的成功引起了学术界和工业界对深度学习的广泛关注,激发了更多研究人员对深度神经网络的研究兴趣,推动了深度学习技术的快速发展。
  • 开启卷积神经网络新时代:为后续的卷积神经网络研究提供了重要的参考和借鉴,许多新的网络结构和技术都是在AlexNet的基础上发展而来的。
  • 拓展应用领域:由于其在图像识别任务上的出色表现,AlexNet及其改进模型被广泛应用于计算机视觉的各个领域,如目标检测、图像分割、人脸识别等。

AlexNet结构代码详解

结构代码

import torch
import torch.nn as nn


class AlexNet(nn.Module):
    def __init__(self, num_classes):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            # input size: (B, 3, 32, 32)   (Batch_size, Channel, Height, Width)
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), # (B, 64, 16, 16)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),    # (B, 64, 8, 8)
            nn.Conv2d(64, 192, kernel_size=3, padding=1),   # (B, 192, 8, 8)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),    # (B, 192, 4, 4)
            nn.Conv2d(192, 384, kernel_size=3, padding=1),  # (B, 384, 4, 4)
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # (B, 256, 4, 4)
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),  # (B, 256, 4, 4)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),    # (B, 256, 2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 2 * 2, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 2 *2)
        x = self.classifier(x)
        return x

代码详解

以下是对上述AlexNet代码的详细解释:

特征提取层 self.features

这部分构建了AlexNet的特征提取层,是一个由多个层组成的顺序结构(通过nn.Sequential来定义)。
- 第一个卷积层
nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)表示输入图像的通道数为3(通常对应RGB图像的红、绿、蓝三个通道),输出的通道数为64(即卷积核的数量为64,意味着会生成64个不同的特征图),卷积核大小是3×3,步长为2(在空间维度上每次移动2个像素),填充为1(在图像边缘进行1个像素的填充,这样可以保证输入输出图像尺寸在卷积操作下能按预期变化),经过这个卷积层后,输入尺寸为(B, 3, 32, 32)的图像数据会变成(B, 64, 16, 16)
- 激活函数层
nn.ReLU(inplace=True)是使用修正线性单元(Rectified Linear Unit)作为激活函数,inplace=True表示直接在输入的张量上进行修改(节省内存空间),对经过卷积后的特征图进行非线性变换,增强网络的表达能力。
- 池化层
nn.MaxPool2d(kernel_size=2)是最大池化层,池化核大小为2×2,它会在每个2×2的窗口内选取最大值作为输出,起到下采样的作用,减少数据量同时保留重要特征,比如经过第一次池化后特征图尺寸从(B, 64, 16, 16)变为(B, 64, 8, 8)

后续依次重复卷积、激活、池化等操作,不断提取图像的特征,逐步降低特征图的尺寸同时增加特征图的深度(通道数),最终经过这一系列操作后得到尺寸为(B, 256, 2, 2)的特征图。

分类部分self.classifier

self.classifier = nn.Sequential(
    nn.Dropout(),
    nn.Linear(256 * 2 * 2, 4096),
    nn.ReLU(inplace=True),
    nn.Dropout(),
    nn.Linear(4096, 4096),
    nn.ReLU(inplace=True),
    nn.Linear(4096, num_classes)
)

这部分构建了AlexNet的分类器,同样是顺序结构。
- Dropout层
nn.Dropout()是一种正则化技术,在训练过程中以一定概率(默认0.5)随机将神经元的输出设置为0,防止过拟合,提高模型的泛化能力。这里使用了两次Dropout,分别在不同的全连接层之前。
- 全连接层
第一个nn.Linear(256 * 2 * 2, 4096)表示将经过特征提取后展平的特征向量(尺寸为256 * 2 * 2,因为前面特征提取部分最后得到的特征图尺寸是(B, 256, 2, 2),展平后维度就是256 * 2 * 2)映射到一个4096维的向量空间,后面接着激活函数nn.ReLU(inplace=True)进行非线性变换。然后又是一个Dropout层和一个同样输出维度为4096的全连接层以及相应的激活函数,最后通过nn.Linear(4096, num_classes)将4096维的向量映射到指定的类别数(num_classes)维度,得到最终的分类预测结果。

前向传播forward

def forward(self, x):
    x = self.features(x)
    x = x.view(x.size(0), 256 * 2 *2)
    x = self.classifier(x)
    return x

forward方法定义了数据在网络中的前向传播过程。

  • 特征提取
    首先x = self.features(x),将输入数据x送入到之前定义的特征提取部分(features),按照特征提取层中定义的卷积、激活、池化等操作依次对输入数据进行处理,得到提取后的特征图。
  • 特征图展平
    x = x.view(x.size(0), 256 * 2 *2)这行代码将特征图进行展平操作,使其变成一个二维张量,其中第一维对应批次大小(x.size(0)表示批次中的样本数量),第二维就是展平后的特征向量长度(由前面特征提取最后得到的特征图尺寸计算得出),这样才能输入到后面的全连接层中进行分类处理。
  • 分类预测
    最后x = self.classifier(x)将展平后的特征向量送入分类器部分(classifier),经过全连接层、激活函数、Dropout等操作逐步得到最终的分类预测结果,然后通过return x返回这个预测结果。

训练过程和测试结果

训练过程损失函数变化曲线:
在这里插入图片描述
在这里插入图片描述
训练过程准确率变化曲线:
在这里插入图片描述
测试结果:
在这里插入图片描述

代码汇总

项目github地址
项目结构:

|--data
|--models
	|--__init__.py
	|--alexnet.py
|--results
|--weights
|--train.py
|--test.py

alexnet.py

import torch
import torch.nn as nn


class AlexNet(nn.Module):
    def __init__(self, num_classes):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            # input size: (B, 3, 32, 32)   (Batch_size, Channel, Height, Width)
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), # (B, 64, 16, 16)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),    # (B, 64, 8, 8)
            nn.Conv2d(64, 192, kernel_size=3, padding=1),   # (B, 192, 8, 8)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),    # (B, 192, 4, 4)
            nn.Conv2d(192, 384, kernel_size=3, padding=1),  # (B, 384, 4, 4)
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # (B, 256, 4, 4)
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),  # (B, 256, 4, 4)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),    # (B, 256, 2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 2 * 2, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), 256 * 2 *2)
        x = self.classifier(x)
        return x

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import AlexNet
import matplotlib.pyplot as plt

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

# 定义数据预处理操作
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])

# 加载CIFAR10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 实例化模型
model = AlexNet(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练轮次
epochs = 15

def train(model, trainloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(trainloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

if __name__ == "__main__":
    loss_history, acc_history = [], []
    for epoch in range(epochs):
        train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)
        print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        loss_history.append(train_loss)
        acc_history.append(train_acc)
        # 保存模型权重,每5轮次保存到weights文件夹下
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f'weights/alexnet_epoch_{epoch + 1}.pth')
    # 绘制损失曲线
    plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')
    plt.legend()
    plt.savefig('results\\train_loss_curve.png')
    plt.close()

    # 绘制准确率曲线
    plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Training Accuracy Curve')
    plt.legend()
    plt.savefig('results\\train_acc_curve.png')
    plt.close()

test.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import AlexNet

import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 定义数据预处理操作
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])

# 加载CIFAR10测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

# 定义设备(GPU优先,若可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 实例化模型
model = AlexNet(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()

# 加载模型权重
weights_path = "weights/alexnet_epoch_15.pth"  
model.load_state_dict(torch.load(weights_path, map_location=device))

def test(model, testloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(testloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

if __name__ == "__main__":
    test_loss, test_acc = test(model, testloader, criterion, device)
    print("================AlexNet Test================")
    print(f"Load Model Weights From: {weights_path}")
    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

http://www.kler.cn/a/444582.html

相关文章:

  • uniapp自定义树型结构数据弹窗,给默认选中的节点,禁用所有子节点
  • 基于LSB最低有效位的音频水印嵌入提取算法FPGA实现,包含testbench和MATLAB对比
  • YOLOv8全解析:高效、精准的目标检测新时代——创新架构与性能提升
  • 【AI驱动的数据结构:包装类的艺术与科学】
  • 微信小程序:轻应用的未来与无限可能
  • 使用计算机创建一个虚拟世界
  • 鸿蒙项目云捐助第十五讲云数据库的初步使用
  • linux CentOS系统上卸载Kubernetes(k8s)
  • druid与pgsql结合踩坑记
  • js 算法
  • Excel根据身份证号,计算退休日期和剩余天数!
  • Qt-Advanced-Docking-System配置及使用、心得
  • 第十二课 Unity 内存优化_内存工具篇(Memory)详解
  • 【论文阅读】Trigger Hunting with a Topological Prior for Trojan Detection
  • PostgreSQL17.x数据库备份命令及语法说明
  • Facebook 对社交互动的革新与启示
  • 使用Flinkcdc 采集mysql数据
  • Swift 的动态性
  • package.json中版本管理的标识有哪些
  • 欢乐堡游乐园信息管理系统的设计与实现(Django Python MySQL)+文档
  • Express (nodejs) 相关
  • 手机无法连接电脑,如何解决(快速排除手机与电脑连接问题的方法)
  • 【2024版】超详细Python+Pycharm安装保姆级教程,Python环境配置和使用指南,看完这一篇就够了
  • 深度学习之目标检测篇——残差网络与FPN结合
  • 007 Qt_按钮类控件
  • docker如何学习与使用入门