当前位置: 首页 > article >正文

深度学习11. CNN经典网络 LeNet-5实现CIFAR-10

深度学习11. CNN经典网络 LeNet-5实现CIFAR-10

  • 一、CIFAR-10介绍
  • 二、PyTorch的 transforms 介绍
  • 三、实现步骤
    • 1. 准备数据
    • 2. 模型定义
    • 3. 训练与测试
  • 四、完整代码

在这里插入图片描述

本文在前节程序基础上,实现对CIFAR-10的训练与测试,以加深对LeNet-5网络的理解 。

首先,要了解LeNet-5并不适合训练 CIFAR-10 , 最后的正确率不会太理想 。

一、CIFAR-10介绍

CIFAR-10是一个常用的图像分类数据集,由10类共计60,000张32x32大小的彩色图像组成,每类包含6,000张图像。这些图像被平均分为了5个训练批次和1个测试批次,每个批次包含10,000张图像。CIFAR-10数据集中的10个类别分别为:飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。

相比之下,MNIST是一个手写数字分类数据集,由10个数字(0-9)共计60,000个训练样本和10,000个测试样本组成,每个样本是一个28x28的灰度图像。

与MNIST相比,CIFAR-10更具挑战性,因为它是一个彩色图像数据集,每张图像包含更多的信息和细节,难度更高。此外,CIFAR-10的类别也更加多样化,更加贴近实际应用场景。因此,CIFAR-10更适合用于测试和评估具有更高难度的图像分类模型,而MNIST则更适合用于介绍和入门级别的模型训练和测试。

二、PyTorch的 transforms 介绍

PyTorch中的transforms是用于对数据进行预处理和增强的工具,主要用于图像数据的处理,它可以方便地对数据进行转换,使其符合神经网络的输入要求。

transforms的方法:

  • ToTensor : 将数据转换为PyTorch中的张量格式。
  • Normalize:对数据进行标准化,使其均值为0,方差为1,以便网络更容易训练。
  • Resize:调整图像大小。
  • RandomCrop:随机裁剪图像的一部分。
  • CenterCrop:从图像的中心裁剪出一部分。
  • RandomHorizontalFlip :以一定的概率随机水平翻转图像,以增加训练集的多样性。
  • RandomVerticalFlip:以一定的概率随机垂直翻转图像,以增加训练集的多样性。
  • RandomRotation:以一定的概率随机旋转图像。
  • ColorJitter:随机调整图像的亮度、对比度、饱和度和色调。
  • RandomErasing:随机擦除图像中的一部分区域,以增加训练集的多样性。

使用transforms可以方便地进行数据预处理和增强,提高模型的鲁棒性和泛化能力。在实际应用中,可以根据具体问题和需求进行选择和组合。

三、实现步骤

1. 准备数据

下面定义加载CIFAR-10数据集,首先会对图片进行一些处理:

  • transforms.RandomHorizontalFlip():随机水平翻转图像
  • transforms.RandomCrop(32, padding=4):随机裁剪图像,大小为32x32,边缘填充4个像素
  • transforms.ToTensor():将图像转换为张量,并归一化到[0,1]范围内
  • transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)):将张量标准化,使每个通道的均值为0.5,标准差为0.5

对数据的处理可以增加数据的多样性和丰富性,以提高神经网络的泛化能力和准确率。

transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.RandomCrop(32, padding=4),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=32,
                                         shuffle=False, num_workers=2)

2. 模型定义

class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        # 定义卷积层C1,输入通道数为1,输出通道数为6,卷积核大小为5x5
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5, stride=1)
        # 定义池化层S2,池化核大小为2x2,步长为2
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        # 定义卷积层C3,输入通道数为6,输出通道数为16,卷积核大小为5x5
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)
        # 定义池化层S4,池化核大小为2x2,步长为2
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        # 定义全连接层F5,输入节点数为16x4x4=256,输出节点数为120
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        # 定义全连接层F6,输入节点数为120,输出节点数为84
        self.fc2 = nn.Linear(120, 84)
        # 定义输出层,输入节点数为84,输出节点数为10
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # 卷积层C1
        x = self.conv1(x)
        # 池化层S2
        x = self.pool1(torch.relu(x))
        # 卷积层C3
        x = self.conv2(x)
        # 池化层S4
        x = self.pool2(torch.relu(x))
        # 全连接层F5
        x = x.view(-1, 16 * 5 * 5)
        x = self.fc1(x)
        x = torch.relu(x)
        # 全连接层F6
        x = self.fc2(x)
        x = torch.relu(x)
        # 输出层
        x = self.fc3(x)
        return x

和上节类似 , 这个模型定义了经典的 LeNet-5。它由两个卷积层、两个池化层和三个全连接层组成,层间通过一定的非线性激活函数进行连接。

  • 模型中的第一个卷积层(C1)的输入通道数是3,即输入的是3通道的图像数据;输出通道数为6,表示该层有6个卷积核,每个卷积核可以提取出一种特征,卷积核大小为5x5。
  • 第一个池化层(S2)的池化核大小为2x2,步长为2,可以将特征图的大小降低一半。
  • 卷积层(C3),输入通道数为6,输出通道数为16,卷积核大小为5x5。然后再经过一个池化层(S4),池化核大小为2x2,步长为2,同样可以将特征图的大小降低一半。
  • 接下来是三个全连接层(F5、F6、F7),其中 F5 的输入节点数为16x5x5=400,输出节点数为120;F6 的输入节点数为120,输出节点数为84;输出层的输入节点数为84,输出节点数为10,表示对10个类别进行分类。

最后,网络输出了分类结果。在前向传播过程中,经过卷积、池化、全连接等操作,每层的输出都要经过一定的非线性激活函数,这里使用的是 ReLU 函数(即 Rectified Linear Unit)。

3. 训练与测试

model = LeNet5()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
if __name__ == '__main__':
    # 定义模型保存路径和文件名
    model_path = 'model.pth'
    if os.path.exists(model_path):
        # 存在,直接加载模型
        model.load_state_dict(torch.load(model_path))
        print('Loaded model from', model_path)
    else:
        # 训练模型
        for epoch in range(epochs):
            model.train()
            for images, labels in train_loader:
                # 将数据放入模型
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            # 在测试集上测试模型
            model.eval()
            correct = 0
            with torch.no_grad():
                for images, labels in test_loader:
                    # 将数据放入模型
                    outputs = model(images)
                    _, predicted = torch.max(outputs, 1)
                    correct += (predicted == labels).sum().item()

            accuracy = 100 * correct / len(testset)
            print('Epoch [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'.format(epoch + 1, epochs, loss.item(), accuracy))

        torch.save(model.state_dict(), 'model.pth')

    for i in range(10):
        img, label = next(iter(test_loader))
        img = img[i].unsqueeze(0)

        # 使用模型进行预测
        model.eval()
        with torch.no_grad():
            output = model(img)

        # 解码预测结果
        pred = output.argmax(dim=1).item()
        print(f'Predicted class: {pred}, actual value: {label[i]}')

四、完整代码

import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


# 定义 LeNet-5 模型
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        # 定义卷积层C1,输入通道数为1,输出通道数为6,卷积核大小为5x5
        self.conv1 = nn.Conv2d(3, 6, kernel_size=5, stride=1)
        # 定义池化层S2,池化核大小为2x2,步长为2
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        # 定义卷积层C3,输入通道数为6,输出通道数为16,卷积核大小为5x5
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)
        # 定义池化层S4,池化核大小为2x2,步长为2
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        # 定义全连接层F5,输入节点数为16x4x4=256,输出节点数为120
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        # 定义全连接层F6,输入节点数为120,输出节点数为84
        self.fc2 = nn.Linear(120, 84)
        # 定义输出层,输入节点数为84,输出节点数为10
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # 卷积层C1
        x = self.conv1(x)
        # print('卷积层C1后的形状:', x.shape)
        # 池化层S2
        x = self.pool1(torch.relu(x))
        # print('池化层S2后的形状:', x.shape)
        # 卷积层C3
        x = self.conv2(x)
        # print('卷积层C3后的形状:', x.shape)
        # 池化层S4
        x = self.pool2(torch.relu(x))
        # print('池化层S4后的形状:', x.shape)
        # 全连接层F5
        x = x.view(-1, 16 * 5 * 5)
        x = self.fc1(x)
        # print('全连接层F5后的形状:', x.shape)
        x = torch.relu(x)
        # 全连接层F6
        x = self.fc2(x)
        # print('全连接层F6后的形状:', x.shape)
        x = torch.relu(x)
        # 输出层
        x = self.fc3(x)
        # print('输出层后的形状:', x.shape)
        return x

# 设置超参数
batch_size = 128
learning_rate = 0.01
epochs = 10

# CIFAR-10
# 准备数据
transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.RandomCrop(32, padding=4),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=32,
                                         shuffle=False, num_workers=2)

# 实例化模型和优化器
model = LeNet5()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
if __name__ == '__main__':
    # 定义模型保存路径和文件名
    model_path = 'model.pth'
    if os.path.exists(model_path):
        # 存在,直接加载模型
        model.load_state_dict(torch.load(model_path))
        print('Loaded model from', model_path)
    else:
        # 训练模型
        for epoch in range(epochs):
            model.train()
            for images, labels in train_loader:
                # 将数据放入模型
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            # 在测试集上测试模型
            model.eval()
            correct = 0
            with torch.no_grad():
                for images, labels in test_loader:
                    # 将数据放入模型
                    outputs = model(images)
                    _, predicted = torch.max(outputs, 1)
                    correct += (predicted == labels).sum().item()

            accuracy = 100 * correct / len(testset)
            print('Epoch [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'.format(epoch + 1, epochs, loss.item(), accuracy))

        torch.save(model.state_dict(), 'model.pth')

    for i in range(10):
        img, label = next(iter(test_loader))
        img = img[i].unsqueeze(0)

        # 使用模型进行预测
        model.eval()
        with torch.no_grad():
            output = model(img)

        # 解码预测结果
        pred = output.argmax(dim=1).item()
        print(f'Predicted class: {pred}, actual value: {label[i]}')

最后训练准确率仅有47.13%。


http://www.kler.cn/a/1195.html

相关文章:

  • STL总结
  • 【Python/Opencv】图像权重加法函数:cv2.addWeighted()详解
  • 节流还在用JS吗?CSS也可以实现哦
  • JAVA并发编程(2)——(如何保证原子性,原子类,CAS乐观锁,JUC常用类)
  • 176万,GPT-4发布了,如何查看OpenAI的下载量?
  • 面试官:聊聊你知道的跨域解决方案
  • Linux 路由表说明
  • 剑指 Offer II 031. 最近最少使用缓存
  • Linux:函数指针做函数参数
  • 介绍两款红队常用的信息收集组合工具
  • 【CSS 知识总结】第二篇 - HTML 扩展简介
  • OKHttp 源码解析(二)拦截器
  • 中断控制器
  • 面试官问 : ArrayList 不是线程安全的,为什么 ?(看完这篇,以后反问面试官)
  • 信创办公–基于WPS的PPT最佳实践系列(表格和图标常用动画)
  • 每日算法题
  • Unity学习日记12(导航走路相关、动作完成度返回参数)
  • yolo车牌识别、车辆识别、行人识别、车距识别源码(包含单目双目)
  • Webpack迁移Rspack速攻实战教程(前瞻版)
  • 【OpenCV】车牌自动识别算法的设计与实现