
PyTorch | Building GoogleNet from Scratch to Classify CIFAR10


  • CIFAR10 Dataset
  • GoogleNet
      • Architectural Highlights
      • Overall Architecture
      • Feature Map Size Changes
      • Applications and Impact
  • GoogleNet Code Walkthrough
    • Network Code
    • Code Walkthrough
      • The Inception Class
        • Initialization Method
        • Forward Pass: forward
      • The GoogleNet Class
        • Initialization Method
        • Forward Pass: forward
  • Training and Testing
    • Training Code: train.py
    • Test Code: test.py
    • Training Process and Test Results
  • Code Summary
    • googlenet.py
    • train.py
    • test.py

In earlier posts we built AlexNet and Vgg to classify CIFAR10:
Pytorch | Building AlexNet from Scratch to Classify CIFAR10
Pytorch | Building Vgg from Scratch to Classify CIFAR10
In this post we build GoogleNet.

CIFAR10 Dataset

CIFAR-10 is a dataset collected by the Canadian Institute For Advanced Research (CIFAR) and widely used in image recognition research. Its basic facts:

  • Scale: the dataset contains 60,000 color images in 10 classes, with 6,000 images per class. The standard split uses 50,000 images as the training set and 10,000 as the test set for evaluating model performance.
  • Image size: every image is 32×32 pixels. This small size lets models train and run inference relatively quickly, but it also makes classification harder.
  • Classes: plane, car, bird, cat, deer, dog, frog, horse, ship, and truck. These 10 categories are all common real-world objects, making the dataset reasonably representative.
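As a quick reference, the 10 integer labels map to class names as follows (note that torchvision's CIFAR10 uses the slightly different names "airplane" and "automobile" for the first two classes):

```python
# CIFAR-10 class names, in the label order used by torchvision.datasets.CIFAR10
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

label = 3                # an integer label as returned by the dataset
print(classes[label])    # cat
```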

Here are some sample images:
[Figure: CIFAR-10 sample images]

GoogleNet

GoogleNet is a deep convolutional neural network architecture proposed by a team at Google in 2014. A detailed look:

Architectural Highlights

  • Inception module: the core innovation of GoogleNet. An Inception module runs convolutions with different kernel sizes (1×1, 3×3, 5×5) and a pooling operation in parallel, then concatenates their results along the channel dimension, so features at several scales are extracted at once. For example, the 1×1 convolutions reduce or expand the channel count without changing the spatial size, cutting computation, while the 3×3 and 5×5 convolutions capture features over different receptive fields.
  • Depth and width: GoogleNet is deep (22 layers), yet has far fewer parameters than comparable networks of its time, thanks to the efficiency of the Inception design. Its considerable width also lets it learn rich feature representations.
  • Auxiliary classifiers: to ease the vanishing-gradient problem, GoogleNet attaches two auxiliary classifiers to intermediate layers. During training their losses are back-propagated together with the main classifier's, helping gradients reach the shallow layers, speeding up training, and improving generalization. At inference time the auxiliary classifiers are discarded.
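The implementation later in this post omits the auxiliary classifiers for simplicity, but the training-time loss combination can be sketched with stand-in logits (the 0.3 weight is the value used in the original GoogLeNet paper; the tensors here are dummies, not outputs of a real network):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# Stand-in logits for a batch of 4 images over 10 classes:
# one main head and two auxiliary heads
main_logits = torch.randn(4, 10)
aux1_logits = torch.randn(4, 10)
aux2_logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))

# During training, the auxiliary losses are added to the main loss
# with a weight of 0.3; at inference the auxiliary heads are dropped.
loss = (criterion(main_logits, labels)
        + 0.3 * criterion(aux1_logits, labels)
        + 0.3 * criterion(aux2_logits, labels))
print(loss.item())
```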

Overall Architecture

  • Input layer: takes an image of size H×W×3, where H and W are the image height and width and 3 is the number of RGB channels.

  • Convolution and pooling layers: the first few layers consist of convolutions and pooling that extract basic image features, gradually lowering the spatial resolution while increasing the number of channels in the feature maps.

  • Inception module groups: the body of the network is built from several groups of Inception modules. As the network gets deeper, the modules' output channel counts increase so that higher-level features can be learned.
    [Figure: Inception module structure]

  • Pooling and fully connected layers: after the Inception groups, an average-pooling layer shrinks the feature map to 1×1; it is then flattened and fed into a fully connected layer, and finally a Softmax layer outputs the classification result.

Feature Map Size Changes

Assuming a 32×32×3 input image, the feature map sizes change roughly as follows during the forward pass:

  1. First convolution: a 7×7 kernel with stride 2 and padding 3 turns the input into a 16×16×64 feature map.
  2. Max pooling: a 3×3 kernel with stride 2 and padding 1 reduces it to 8×8×64.
  3. Inception module groups: inside the Inception groups the spatial size changes with the convolution and pooling operations. With 1×1, 3×3, and 5×5 convolutions and 3×3 pooling, the maps move between sizes such as 8×8 and 4×4, while the channel count keeps growing.
  4. Average pooling: at the end of the network, an average-pooling layer reduces the feature map to 1×1×1024.
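The first two steps above can be checked directly in PyTorch. A quick sketch (padding=1 on the pooling layer is assumed, since that is what yields the 8×8 result):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 32, 32)          # a 32x32 RGB input

# Step 1: 7x7 conv, stride 2, padding 3 -> 16x16 spatial, 64 channels
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
x = conv1(x)
print(x.shape)   # torch.Size([1, 64, 16, 16])

# Step 2: 3x3 max pool, stride 2, padding 1 -> 8x8 spatial
pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
x = pool1(x)
print(x.shape)   # torch.Size([1, 64, 8, 8])
```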

应用与影响

  • 图像分类:GoogleNet在图像分类任务上取得了非常好的效果,在ILSVRC 2014图像分类竞赛中获得了冠军。它能够准确地识别各种自然图像中的物体类别,如猫、狗、汽车、飞机等。
  • 目标检测:GoogleNet也可以应用于目标检测任务,通过在网络中添加一些额外的检测层和算法,可以实现对图像中物体的定位和检测。
  • 后续研究基础:GoogleNet的成功推动了深度学习领域的发展,其Inception模块的设计思想为后来的许多网络架构提供了灵感,如Inception系列的后续版本以及其他一些基于多分支结构的网络。

GoogleNet Code Walkthrough

Network Code

import torch
import torch.nn as nn


class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3reduc, ch3x3, ch5x5reduc, ch5x5, pool_proj):
        super().__init__()
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, ch1x1, kernel_size=1),
            nn.BatchNorm2d(ch1x1),
            nn.ReLU(inplace=True)
        )

        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, ch3x3reduc, kernel_size=1),
            nn.BatchNorm2d(ch3x3reduc),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch3x3reduc, ch3x3, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        # "5x5" branch: two stacked 3x3 convs cover the same receptive
        # field as a single 5x5 kernel, with fewer parameters
        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channels, ch5x5reduc, kernel_size=1),
            nn.BatchNorm2d(ch5x5reduc),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch5x5reduc, ch5x5, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch5x5),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch5x5, ch5x5, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch5x5),
            nn.ReLU(inplace=True)
        )

        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, pool_proj, kernel_size=1),
            nn.BatchNorm2d(pool_proj),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch3x3 = self.branch3x3(x)
        branch5x5 = self.branch5x5(x)
        branch_pool = self.branch_pool(x)

        return torch.cat([branch1x1, branch3x3, branch5x5, branch_pool], 1)


class GoogleNet(nn.Module):
    def __init__(self, num_classes):
        super(GoogleNet, self).__init__()
        self.prelayers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True)
        )
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.prelayers(x)
        x = self.maxpool2(x)

        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)

        x = self.inception4a(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        x = self.inception4e(x)
        x = self.maxpool4(x)

        x = self.inception5a(x)
        x = self.inception5b(x)

        x = self.avgpool(x)
        x = self.dropout(x)
        x = x.view(x.size()[0], -1)
        x = self.fc(x)

        return x

Code Walkthrough

The following is a detailed explanation of the code above, which uses PyTorch to build the classic GoogleNet (Inception v1) architecture for image classification. The walkthrough proceeds part by part:

The Inception Class

This class defines GoogleNet's Inception module, which extracts features in parallel through operations with different kernel sizes and then concatenates the results along the channel dimension.

Initialization Method
  • Parameters
    • in_channels: the number of channels of the input feature map, i.e., the depth of the input.
    • ch1x1, ch3x3reduc, ch3x3, ch5x5reduc, ch5x5, pool_proj: the channel counts of the convolutions in the different branches, configuring the structure of each branch.
  • Structure
    • self.branch1x1: a sequential block of 1×1 convolution, batch normalization (BatchNorm), and ReLU. The 1×1 convolution adjusts the channel count without changing the spatial size, batch normalization speeds up training and stabilizes the model, and ReLU introduces nonlinearity.
    • self.branch3x3: a 1×1 convolution first reduces the channel count (to cut computation), followed by batch normalization and ReLU, then a 3×3 convolution (padding=1 preserves the spatial size) and a final ReLU.
    • self.branch5x5: somewhat more elaborate: a 1×1 convolution with batch normalization and ReLU, followed by two consecutive 3×3 convolutions (each padded to preserve the spatial size; together they cover the same receptive field as a single 5×5 kernel), interleaved with batch normalization and ReLU, to extract more complex features.
    • self.branch_pool: max pooling first (MaxPool2d with stride 1 and padding 1 so the spatial size is preserved), then a 1×1 convolution to adjust the channel count, followed by batch normalization and ReLU.
Forward Pass: forward
  • The input tensor x is passed through each of the four branches above, producing branch1x1, branch3x3, branch5x5, and branch_pool.
  • torch.cat then concatenates the four branch outputs along the channel dimension (dimension 1, the second argument) to form the output of the whole Inception module.
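For example, with inception3a's configuration (ch1x1=64, ch3x3=128, ch5x5=32, pool_proj=32), the four branch outputs concatenate to 64 + 128 + 32 + 32 = 256 channels. A quick sketch with dummy tensors in place of the real branch outputs:

```python
import torch

# Dummy branch outputs matching inception3a's configuration
branch1x1   = torch.randn(1, 64, 8, 8)
branch3x3   = torch.randn(1, 128, 8, 8)
branch5x5   = torch.randn(1, 32, 8, 8)
branch_pool = torch.randn(1, 32, 8, 8)

# Concatenate along dimension 1 (the channel dimension)
out = torch.cat([branch1x1, branch3x3, branch5x5, branch_pool], 1)
print(out.shape)   # torch.Size([1, 256, 8, 8]), matching inception3b's in_channels
```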

The GoogleNet Class

This is the main network class; it combines the Inception modules with the other required layers to build the complete GoogleNet architecture.

Initialization Method
  • Parameters
    • num_classes: the number of target classes; the final fully connected layer outputs this many prediction scores.
  • Structure
    • self.prelayers: a sequential block of convolutions, batch normalization, and ReLU that performs the initial feature extraction, expanding the 3-channel (RGB) input into a 192-channel feature map.
    • self.maxpool2: a max-pooling layer (stride 2, with padding controlling the output size) that downsamples the feature map and enlarges the receptive field.
    • Next come the Inception modules, such as self.inception3a and self.inception3b, each with its own input channel count and branch configuration. They extract increasingly high-level, complex features as the network deepens, with max-pooling layers (self.maxpool3, self.maxpool4) interleaved for downsampling.
    • self.avgpool: an adaptive average-pooling layer that converts feature maps of any spatial size into a fixed 1×1 size, ready for the fully connected layer.
    • self.dropout: a Dropout layer with probability 0.4 that randomly drops connections during training to prevent overfitting.
    • self.fc: the fully connected layer that maps the extracted features to num_classes scores for the final classification.
Forward Pass: forward
  • The input x first passes through self.prelayers for initial feature extraction and is then downsampled by self.maxpool2.
  • The feature map then flows through the Inception modules in order, with max-pooling layers interleaved for downsampling, continually extracting and aggregating features.
  • After the final Inception module, the feature map is average-pooled by self.avgpool, passed through self.dropout, flattened into a vector by x.view (so the fully connected layer can process it), and finally mapped by self.fc to the class predictions, which are returned.
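The tail of the forward pass (pool, dropout, flatten, fc) can be sketched in isolation with a dummy feature map; the layer sizes below are taken from the model above, and the dummy tensor has the shape inception5b produces for a 32×32 input:

```python
import torch
import torch.nn as nn

avgpool = nn.AdaptiveAvgPool2d((1, 1))
dropout = nn.Dropout(0.4)
fc = nn.Linear(1024, 10)

x = torch.randn(2, 1024, 4, 4)   # dummy inception5b output for a batch of 2
x = avgpool(x)                   # -> (2, 1024, 1, 1)
x = dropout(x)                   # randomly zeroes activations in train mode; identity in eval mode
x = x.view(x.size(0), -1)        # flatten -> (2, 1024)
x = fc(x)                        # -> (2, 10) class scores
print(x.shape)                   # torch.Size([2, 10])
```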

Training and Testing

Training Code: train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import *
import matplotlib.pyplot as plt

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

# Data preprocessing: convert to tensors and normalize with CIFAR-10 channel means and stds
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])

# Load the CIFAR10 training set
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

# Select the device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the model
model_name = 'GoogleNet'
if model_name == 'AlexNet':
    model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':
    model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':
    model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':
    model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':
    model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':
    model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':
    model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':
    model = GoogleNet(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Number of training epochs
epochs = 15

def train(model, trainloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(trainloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

if __name__ == "__main__":
    loss_history, acc_history = [], []
    for epoch in range(epochs):
        train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)
        print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        loss_history.append(train_loss)
        acc_history.append(train_acc)
        # Save model weights to the weights/ folder every 5 epochs
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f'weights/{model_name}_epoch_{epoch + 1}.pth')
    
    # Plot the training loss curve
    plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')
    plt.legend()
    plt.savefig(f'results/{model_name}_train_loss_curve.png')
    plt.close()

    # Plot the training accuracy curve
    plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Training Accuracy Curve')
    plt.legend()
    plt.savefig(f'results/{model_name}_train_acc_curve.png')
    plt.close()

Test Code: test.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *

import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# Data preprocessing: convert to tensors and normalize with CIFAR-10 channel means and stds
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])

# Load the CIFAR10 test set
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

# Select the device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the model
model_name = 'GoogleNet'
if model_name == 'AlexNet':
    model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':
    model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':
    model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':
    model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':
    model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':
    model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':
    model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':
    model = GoogleNet(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()

# Load the trained model weights
weights_path = f"weights/{model_name}_epoch_15.pth"
model.load_state_dict(torch.load(weights_path, map_location=device))

def test(model, testloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(testloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

if __name__ == "__main__":
    test_loss, test_acc = test(model, testloader, criterion, device)
    print(f"================{model_name} Test================")
    print(f"Load Model Weights From: {weights_path}")
    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

Training Process and Test Results

Training loss curve:
[Figure: training loss curve]

Training accuracy curve:
[Figure: training accuracy curve]

Test result:
[Figure: test output]

Code Summary

Project GitHub repository
Project structure:

|--data
|--models
	|--__init__.py
	|--googlenet.py
	|--...
|--results
|--weights
|--train.py
|--test.py

googlenet.py

import torch
import torch.nn as nn


class Inception(nn.Module):
    def __init__(self, in_channels, ch1x1, ch3x3reduc, ch3x3, ch5x5reduc, ch5x5, pool_proj):
        super().__init__()
        self.branch1x1 = nn.Sequential(
            nn.Conv2d(in_channels, ch1x1, kernel_size=1),
            nn.BatchNorm2d(ch1x1),
            nn.ReLU(inplace=True)
        )

        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, ch3x3reduc, kernel_size=1),
            nn.BatchNorm2d(ch3x3reduc),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch3x3reduc, ch3x3, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        # "5x5" branch: two stacked 3x3 convs cover the same receptive
        # field as a single 5x5 kernel, with fewer parameters
        self.branch5x5 = nn.Sequential(
            nn.Conv2d(in_channels, ch5x5reduc, kernel_size=1),
            nn.BatchNorm2d(ch5x5reduc),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch5x5reduc, ch5x5, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch5x5),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch5x5, ch5x5, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch5x5),
            nn.ReLU(inplace=True)
        )

        self.branch_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, pool_proj, kernel_size=1),
            nn.BatchNorm2d(pool_proj),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        branch1x1 = self.branch1x1(x)
        branch3x3 = self.branch3x3(x)
        branch5x5 = self.branch5x5(x)
        branch_pool = self.branch_pool(x)

        return torch.cat([branch1x1, branch3x3, branch5x5, branch_pool], 1)


class GoogleNet(nn.Module):
    def __init__(self, num_classes):
        super(GoogleNet, self).__init__()
        self.prelayers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True)
        )
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.prelayers(x)
        x = self.maxpool2(x)

        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)

        x = self.inception4a(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        x = self.inception4e(x)
        x = self.maxpool4(x)

        x = self.inception5a(x)
        x = self.inception5b(x)

        x = self.avgpool(x)
        x = self.dropout(x)
        x = x.view(x.size()[0], -1)
        x = self.fc(x)

        return x

train.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from models import *
import matplotlib.pyplot as plt

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

# Data preprocessing: convert to tensors and normalize with CIFAR-10 channel means and stds
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])

# Load the CIFAR10 training set
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

# Select the device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the model
model_name = 'GoogleNet'
if model_name == 'AlexNet':
    model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':
    model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':
    model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':
    model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':
    model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':
    model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':
    model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':
    model = GoogleNet(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Number of training epochs
epochs = 15

def train(model, trainloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(trainloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

if __name__ == "__main__":
    loss_history, acc_history = [], []
    for epoch in range(epochs):
        train_loss, train_acc = train(model, trainloader, criterion, optimizer, device)
        print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        loss_history.append(train_loss)
        acc_history.append(train_acc)
        # Save model weights to the weights/ folder every 5 epochs
        if (epoch + 1) % 5 == 0:
            torch.save(model.state_dict(), f'weights/{model_name}_epoch_{epoch + 1}.pth')
    
    # Plot the training loss curve
    plt.plot(range(1, epochs+1), loss_history, label='Loss', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss Curve')
    plt.legend()
    plt.savefig(f'results/{model_name}_train_loss_curve.png')
    plt.close()

    # Plot the training accuracy curve
    plt.plot(range(1, epochs+1), acc_history, label='Accuracy', marker='o')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Training Accuracy Curve')
    plt.legend()
    plt.savefig(f'results/{model_name}_train_acc_curve.png')
    plt.close()

test.py

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from models import *

import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# Data preprocessing: convert to tensors and normalize with CIFAR-10 channel means and stds
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])

# Load the CIFAR10 test set
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

# Select the device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the model
model_name = 'GoogleNet'
if model_name == 'AlexNet':
    model = AlexNet(num_classes=10).to(device)
elif model_name == 'Vgg_A':
    model = Vgg(cfg_vgg='A', num_classes=10).to(device)
elif model_name == 'Vgg_A-LRN':
    model = Vgg(cfg_vgg='A-LRN', num_classes=10).to(device)
elif model_name == 'Vgg_B':
    model = Vgg(cfg_vgg='B', num_classes=10).to(device)
elif model_name == 'Vgg_C':
    model = Vgg(cfg_vgg='C', num_classes=10).to(device)
elif model_name == 'Vgg_D':
    model = Vgg(cfg_vgg='D', num_classes=10).to(device)
elif model_name == 'Vgg_E':
    model = Vgg(cfg_vgg='E', num_classes=10).to(device)
elif model_name == 'GoogleNet':
    model = GoogleNet(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()

# Load the trained model weights
weights_path = f"weights/{model_name}_epoch_15.pth"
model.load_state_dict(torch.load(weights_path, map_location=device))

def test(model, testloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(testloader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc

if __name__ == "__main__":
    test_loss, test_acc = test(model, testloader, criterion, device)
    print(f"================{model_name} Test================")
    print(f"Load Model Weights From: {weights_path}")
    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

