当前位置: 首页 > article >正文

Pytorch | 对比Pytorch中的十种优化器:基于CIFAR10上的ResNet分类器

Pytorch | 对比Pytorch中的十种优化器:基于CIFAR10上的ResNet分类器

  • CIFAR10数据集
  • ResNet
    • 提出背景
    • 网络结构特点
    • 工作原理
    • 优势
  • 代码实现分析
    • utils.py
    • main.py
      • 导入必要的库
      • 设备选择与数据预处理定义
      • 加载训练集和测试集
      • 主函数部分
        • 训练部分
        • 测试部分
  • 结果
    • 10种优化器对应的训练损失下降曲线
      • Adadelta
      • Adagrad
      • Adam
      • Adamax
      • AdamW
      • NAdam
      • RMSprop
      • Rprop
      • SGD
      • SparseAdam
    • 测试结果
  • 代码汇总
    • utils.py
    • main.py

上篇文章中实现了十种优化算法:Python | 从零实现10种优化算法并比较
这篇文章我们用Pytorch上不同的优化器在CIFAR10数据集上训练ResNet模型,比较不同优化器的效果。

CIFAR10数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)收集整理的用于图像识别研究的常用数据集,基本信息如下:

  • 数据规模:该数据集包含60,000张彩色图像,分为10个不同的类别,每个类别有6,000张图像。通常将其中50,000张作为训练集,用于模型的训练;10,000张作为测试集,用于评估模型的性能。
  • 图像尺寸:所有图像的尺寸均为32×32像素,这相对较小的尺寸使得模型在处理该数据集时能够相对快速地进行训练和推理,但也增加了图像分类的难度。
  • 类别内容:涵盖了飞机(plane)、汽车(car)、鸟(bird)、猫(cat)、鹿(deer)、狗(dog)、青蛙(frog)、马(horse)、船(ship)、卡车(truck)这10个不同的类别,这些类别都是现实世界中常见的物体,具有一定的代表性。

下面是一些示例样本:
在这里插入图片描述

ResNet

ResNet(Residual Network)即残差网络,是由微软研究院的何恺明等人在2015年提出的一种深度卷积神经网络架构,在图像识别等计算机视觉任务中取得了巨大成功。

提出背景

随着神经网络深度的增加,出现了梯度消失/爆炸以及网络退化等问题,导致训练难度增大,精度饱和甚至下降。ResNet通过引入残差连接(shortcut connection)有效地解决了这些问题,使得训练极深的网络成为可能。

网络结构特点

  • 残差块(Residual Block):这是ResNet的核心结构。它由多个卷积层组成,并且在卷积层之间引入了shortcut connection。一个基本的残差块包含两个3×3卷积层,中间有一个ReLU激活函数,其输入可以直接跳过这两个卷积层与输出相加,这种结构使得网络能够学习到残差函数,即输入与输出之间的差异,而不是直接学习输出本身。
    在这里插入图片描述

  • 多种层数的网络结构:ResNet有多种不同层数的架构,如ResNet-18、ResNet-34、ResNet-50、ResNet-101和ResNet-152等,其中数字表示网络的层数。层数越深,模型的表示能力越强,但计算成本也越高,训练难度也相应增大。
    在这里插入图片描述

  • 瓶颈结构(Bottleneck):在较深的ResNet架构如ResNet-50及以上中,使用了瓶颈结构来减少计算量。它由1×1、3×3和1×1三个卷积层组成,1×1卷积层用于降低输入特征图的通道数,3×3卷积层进行主要的特征提取,最后1×1卷积层用于恢复通道数。

工作原理

在正向传播时,输入特征图通过残差块中的卷积层进行特征提取,得到输出特征图。然后将输入特征图与输出特征图相加,得到最终的输出。如果残差块中的卷积层没有学到有用的特征,那么它们的输出接近于零,此时最终的输出就近似等于输入,即网络可以学习到恒等映射。在反向传播时,由于shortcut connection的存在,梯度可以直接通过捷径传播到较早的层,避免了梯度消失或爆炸的问题,使得网络能够更容易地训练深层网络。

优势

  • 有效解决梯度消失和退化问题:使得训练非常深的网络成为可能,能够提取更高级的图像特征,从而提高了模型的准确性和泛化能力。
  • 降低模型训练难度:残差连接使得网络在训练过程中更容易收敛,减少了对超参数调整的依赖,提高了训练效率。
  • 模型具有很强的可扩展性:可以通过增加残差块的数量来构建更深的网络,以适应不同的任务和数据集。

代码实现分析

utils.py

该文件中我们预先为调用不同的优化器做好准备,对于不同的优化器参数,为了方便,这里只设置学习率,其余参数可根据需要自己设置。

import torch
import torch.optim as optim

def get_optimizer(optimizer_name, model_parameters, lr=0.001, **kwargs):
    """
    根据传入的优化器名称返回对应的优化器实例。

    参数:
    - optimizer_name: 优化器名称,如 "Adadelta", "Adagrad" 等。
    - model_parameters: 模型的可训练参数,通常通过 model.parameters() 获取。
    - lr: 学习率,默认值为 0.001。
    - **kwargs: 其他特定优化器需要的额外参数。

    返回:
    - optimizer: 对应的优化器实例。
    """
    if optimizer_name == "Adadelta":
        return optim.Adadelta(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "Adagrad":
        return optim.Adagrad(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "Adam":
        return optim.Adam(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "Adamax":
        return optim.Adamax(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "AdamW":
        return optim.AdamW(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "NAdam":
        return optim.NAdam(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "RMSprop":
        return optim.RMSprop(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "Rprop":
        return optim.Rprop(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "SGD":
        return optim.SGD(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "SparseAdam":
        return optim.SparseAdam(model_parameters, lr=lr, **kwargs)
    else:
        raise ValueError(f"不支持的优化器名称: {optimizer_name}")

main.py

本文这里使用Pytorch自带的ResNet18模型。

以下是对上述代码的分块讲解:

导入必要的库

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from utils import get_optimizer

import warnings
warnings.filterwarnings("ignore")

import ssl

ssl._create_default_https_context = ssl._create_unverified_context
  • from utils import get_optimizer 从前面的 utils 模块中导入 get_optimizer 函数,用于调用不同的优化器;
  • warnings.filterwarnings("ignore") 用于忽略一些可能出现的警告信息,让代码运行时输出更简洁;
  • ssl 相关的代码是为了解决在下载数据集时可能出现的SSL验证问题,通过创建一个默认的不验证SSL证书的上下文来允许数据正常下载。

设备选择与数据预处理定义

# 检查cuda是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义数据预处理操作,将图像转换为张量并进行归一化
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))]
    )
  • 检查 gpu(cuda) 是否可用;
  • transforms.Normalize 对图像张量进行归一化操作,传入的两个元组分别表示每个通道的均值和标准差

加载训练集和测试集

# 加载训练集,设置batch_size等参数
# 这里batch_size设为128,可按需调整
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

# 加载测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

这部分代码用于加载CIFAR-10数据集,它是一个常用的图像分类数据集,包含10个不同类别的图像。

对于训练集:

  • torchvision.datasets.CIFAR10 函数用于创建训练集对象 trainset,其中 root='./data' 指定了数据集下载和保存的根目录(如果不存在会自动下载到该目录下),train=True 表示加载的是训练集部分,download=True 表示如果数据集不存在则自动下载,transform=transform 表示应用前面定义好的数据预处理操作。
  • torch.utils.data.DataLoader 用于将数据集 trainset 包装成一个可迭代的数据加载器 trainloader,设置 batch_size=128 意味着每次迭代会返回一个包含128张图像及其对应标签的批次数据,shuffle=True 会在每个训练轮次开始时打乱数据顺序,num_workers=2 表示使用2个子进程来并行加载数据,加快数据读取速度。

对于测试集:

  • 同样使用 torchvision.datasets.CIFAR10 函数创建 testset,不过 train=False 表示加载的是测试集部分,其他参数作用和训练集加载时类似。
  • 再通过 torch.utils.data.DataLoader 创建 testloader,只是 shuffle=False 因为测试集一般不需要打乱顺序。

主函数部分

训练部分

训练部分实现了使用10种优化器训练,共训练10个epoch,并记录损失下降情况以可视化,具体可参考下面代码和其中注释:

# 定义类别标签,CIFAR-10有10个类别
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
if __name__ == "__main__":
    # 使用不同的优化器
    optimizer_names = ["Adadelta", "Adagrad", "Adam", "Adamax", "AdamW", 
                       "NAdam", "RMSProp", "RProp", "SGD", "SparseAdam"]
	for optimizer_name in optimizer_names:
	    print(f"========= Optimizer: {optimizer_name} ==========")
	    # 使用PyTorch自带的ResNet18模型,修改全连接层输出维度为10(对应CIFAR-10的类别数)
	    model = torchvision.models.resnet18(pretrained=False)
	    num_ftrs = model.fc.in_features
	    model.fc = nn.Linear(num_ftrs, 10)
	    model = model.to(device)
	
	    # 定义交叉熵损失函数,常用于分类任务
	    criterion = nn.CrossEntropyLoss()
	    # 定义优化器
	    optimizer = get_optimizer(optimizer_name, model.parameters(), lr=0.001)
	    # 训练轮数,设为10轮,可根据实际情况更改
	    num_epochs = 10
	    loss_history = []
	    for epoch in range(num_epochs):
	        running_loss = 0.0
	        for i, data in enumerate(trainloader, 0):
	            # 获取输入数据和标签
	            inputs, labels = data
	            inputs = inputs.to(device)
	            labels = labels.to(device)
	            # 梯度清零
	            optimizer.zero_grad()
	
	            # 前向传播 + 计算损失
	            outputs = model(inputs)
	            loss = criterion(outputs, labels)
	
	            # 反向传播并更新权重
	            loss.backward()
	            optimizer.step()
	
	            running_loss += loss.item()
	            loss_history.append(loss.item())
	            if i % 100 == 99:    # 每100个小批次打印一次平均损失
	                print(f'Epoch: {epoch + 1}  Batch: {i + 1}  loss: {running_loss / 100}')
	                running_loss = 0.0
	    
	    # 绘制并保存损失下降曲线
	    plt.plot(loss_history)
	    plt.xlabel('Iteration')
	    plt.ylabel('Loss')
	    plt.title(f'Loss Curve by {optimizer_name}')
	    plt.savefig(f'results\\loss_curve_{optimizer_name}.png')  # 保存图像为 loss_curve.png 文件,可根据需求修改文件名和路径
	    plt.close()
测试部分

测试部分使用已经训练好的模型对测试集进行预测,具体参考下面代码:

    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Optimizer: {optimizer_name} -- Accuracy of the network on the 10000 test images: {100 * correct / total}%')

结果

10种优化器对应的训练损失下降曲线

Adadelta

在这里插入图片描述

Adagrad

在这里插入图片描述

Adam

在这里插入图片描述

Adamax

在这里插入图片描述

AdamW

在这里插入图片描述

NAdam

在这里插入图片描述

RMSprop

在这里插入图片描述

Rprop

在这里插入图片描述

SGD

SparseAdam

本文训练到这里时,报错:RuntimeError: SparseAdam does not support dense gradients, please consider Adam instead .
原因: SparseAdam 优化器主要是设计用于处理稀疏梯度的场景,也就是梯度张量中大部分元素为零的情况(比如在处理稀疏的文本数据表示等情况时)。而在当前的代码应用场景中,很可能模型计算得到的梯度是密集的(即梯度张量中元素大多是非零值),这就导致 SparseAdam 优化器无法正常处理这样的梯度,进而抛出这个错误提示,建议你考虑使用 Adam 优化器来替代 SparseAdam 优化器。
----因此这里不再给出SparseAdam的优化结果----

测试结果

在这里插入图片描述
从测试结果来看,刨除 SparseAdamAdam 优化器的训练效果最优,Adadelta 优化器的训练效果最差.

代码汇总

项目结构:

|--data
|--results
|--utils.py
|--main.py

utils.py

import torch
import torch.optim as optim

def get_optimizer(optimizer_name, model_parameters, lr=0.001, **kwargs):
    """
    根据传入的优化器名称返回对应的优化器实例。

    参数:
    - optimizer_name: 优化器名称,如 "Adadelta", "Adagrad" 等。
    - model_parameters: 模型的可训练参数,通常通过 model.parameters() 获取。
    - lr: 学习率,默认值为 0.001。
    - **kwargs: 其他特定优化器需要的额外参数。

    返回:
    - optimizer: 对应的优化器实例。
    """
    if optimizer_name == "Adadelta":
        return optim.Adadelta(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "Adagrad":
        return optim.Adagrad(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "Adam":
        return optim.Adam(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "Adamax":
        return optim.Adamax(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "AdamW":
        return optim.AdamW(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "NAdam":
        return optim.NAdam(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "RMSprop":
        return optim.RMSprop(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "Rprop":
        return optim.Rprop(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "SGD":
        return optim.SGD(model_parameters, lr=lr, **kwargs)
    elif optimizer_name == "SparseAdam":
        return optim.SparseAdam(model_parameters, lr=lr, **kwargs)
    else:
        raise ValueError(f"不支持的优化器名称: {optimizer_name}")

main.py

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from utils import get_optimizer

import warnings
warnings.filterwarnings("ignore")

import ssl

ssl._create_default_https_context = ssl._create_unverified_context

# 检查cuda是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 定义数据预处理操作,将图像转换为张量并进行归一化
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))]
    )

# 加载训练集,设置batch_size等参数
# 这里batch_size设为128,可按需调整
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=4)

# 加载测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=4)

# 定义类别标签,CIFAR-10有10个类别
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


if __name__ == "__main__":
    # 使用不同的优化器
    optimizer_names = ["Adadelta", "Adagrad", "Adam", "Adamax", "AdamW", 
                       "NAdam", "RMSprop", "Rprop", "SGD", "SparseAdam"]
    
    for optimizer_name in optimizer_names:
        print(f"========= Optimizer: {optimizer_name} ==========")
        # 使用PyTorch自带的ResNet18模型,修改全连接层输出维度为10(对应CIFAR-10的类别数)
        model = torchvision.models.resnet18(pretrained=False)
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, 10)
        model = model.to(device)

        # 定义交叉熵损失函数,常用于分类任务
        criterion = nn.CrossEntropyLoss()
        # 定义优化器
        optimizer = get_optimizer(optimizer_name, model.parameters(), lr=0.001)
        # 训练轮数,设为10轮,可根据实际情况更改
        num_epochs = 10
        loss_history = []
        for epoch in range(num_epochs):
            running_loss = 0.0
            for i, data in enumerate(trainloader, 0):
                # 获取输入数据和标签
                inputs, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)
                # 梯度清零
                optimizer.zero_grad()

                # 前向传播 + 计算损失
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                # 反向传播并更新权重
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                loss_history.append(loss.item())
                if i % 100 == 99:    # 每100个小批次打印一次平均损失
                    # print(f'Epoch: {epoch + 1}  Batch: {i + 1}  loss: {running_loss / 100}')
                    running_loss = 0.0

        # 绘制并保存损失下降曲线
        plt.plot(loss_history)
        plt.xlabel('Iteration')
        plt.ylabel('Loss')
        plt.title(f'Loss Curve by {optimizer_name}')
        plt.savefig(f'results\\loss_curve_{optimizer_name}.png')  # 保存图像为 loss_curve.png 文件,可根据需求修改文件名和路径
        plt.close()

        correct = 0
        total = 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(f'Optimizer: {optimizer_name} -- Accuracy of the network on the 10000 test images: {100 * correct / total}%')

http://www.kler.cn/a/444456.html

相关文章:

  • Java中的方法重写:深入解析与最佳实践
  • 【libuv】Fargo信令1:client发connect消息给到server
  • UE UMG 多级弹出菜单踩坑
  • 提炼关键词的力量:AI驱动下的SEO优化策略
  • 【开源项目】数字孪生轨道~经典开源项目数字孪生智慧轨道——开源工程及源码
  • 拼多多电子面单接入:常见问题及专业解决方案
  • 美创科技完成新一轮融资!
  • Linux-Profile工具
  • java全栈day19--Web后端实战(java操作数据库3)
  • mac uniapp 转为微信小程序开发
  • Python构造方法:对象的“开机启动程序”
  • windows C#-方法概述(上)
  • HCIE-day7
  • 大数据治理实战
  • 小鹏“飞行汽车”上海首飞,如何保障智能出行的安全性?
  • 社区版 IDEA 开发webapp 配置tomcat
  • C# 方法的参数主要有四种类型:值参数、引用参数ref 、输出参数out、可变参数params
  • React 项目引入 svg 图片为 undefined 情况
  • SpringBoot自己写的maven项目-配置文件提示
  • java Kafka批量消费和单个消费消息
  • SQL 查询方式比较:子查询与自连接
  • LabVIEW与PLC点位控制及OPC通讯
  • 如何处理对象的状态变化?如何实现工厂模式?
  • 如何实现一套完整的CI/CD?
  • 当我用影刀AI Power做了一个旅游攻略小助手
  • 【Javaweb】第一篇上,什么是web?