当前位置: 首页 > article >正文

PyTorch使用教程(5)-优化器

在深度学习中,优化器扮演着至关重要的角色。它们负责根据损失函数计算出的梯度更新模型的参数,从而达到最小化损失的目的。在PyTorch中的优化器实现在torch.optim模块。torch.optim模块提供了多种优化算法,每种算法都有其独特的特点和适用场景。本文将详细介绍这些优化器、它们的使用方法、基类Optimizer以及学习率调整的策略。

1. 常用的优化器算法的介绍

1.1SGD(Stochastic Gradient Descent,随机梯度下降)

SGD是最基本的优化算法之一。它每次从训练集中随机选择一个样本来计算梯度,并更新参数。SGD的优点是计算速度快,但缺点是可能会因为样本的随机性导致梯度更新不稳定,收敛速度慢,甚至可能陷入局部最优。

torch.optim.SGD(params, lr=0.001, momentum=0, dampening=0, weight_decay=0, nesterov=False, *, maximize=False, foreach=None, differentiable=False, fused=None)

参数

  • lr(learning rate):学习率,控制参数更新的步长。
  • momentum:动量,用于加速SGD在相关方向上的收敛,同时抑制震荡。
  • dampening:动量的衰减系数,通常和momentum一起使用。
  • weight_decay:权重衰减,用于防止过拟合。
  • nesterovs:是否使用Nesterov动量。

1.2 Adam(Adaptive Moment Estimation)

Adam是一种自适应学习率的优化算法,结合了动量法和RMSprop算法的优点。它为每个参数维护一阶矩估计(动量)和二阶矩估计(未中心化的方差),从而动态调整每个参数的学习率。

torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False, *, foreach=None, maximize=False, capturable=False, differentiable=False, fused=None)

参数

  • lr:学习率。
  • betas:用于计算梯度及其平方的指数移动平均的系数,常用值为(0.9, 0.999)。
  • eps:防止分母为零的项,添加到分母中以提高数值稳定性。
  • weight_decay:权重衰减。
  • amsgrad:是否使用AMSGrad变体,该变体在某些情况下可以提高Adam的稳定性。

1.3 Adagrad(Adaptive Gradient Algorithm)

Adagrad是一种自适应学习率的优化算法,它根据每个参数的梯度历史来调整学习率。对于稀疏数据,Adagrad表现良好,但对于非稀疏数据,随着训练的进行,学习率会逐渐减小到零,导致模型无法继续学习。
参数

  • lr:学习率。
  • lr_decay:学习率的衰减系数。
  • weight_decay:权重衰减。
  • initial_accumulator_value:累加器的初始值,用于防止分母为零。
  • eps:提高数值稳定性的项。

1.4 RMSprop(Root Mean Square Propagation)

RMSprop是Adagrad的一种变体,旨在解决Adagrad学习率过早衰减的问题。它使用指数衰减的平均来计算梯度的平方,从而避免了梯度下降时过早减小学习率。
参数

  • lr:学习率。
  • alpha:梯度平方的指数衰减率。
  • eps:防止分母为零的项。
  • weight_decay:权重衰减。
  • momentum:动量。
  • centered:是否计算中心化的RMSprop。

2. 如何使用优化器

2.1 构建优化器

在PyTorch中,构建优化器非常简单。首先,你需要有一个模型,然后选择一个优化器算法,并将模型的参数传递给优化器。

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的模型
model = nn.Sequential(
    nn.Linear(10, 5),
    nn.ReLU(),
    nn.Linear(5, 1)
)

# 选择优化器,这里以SGD为例
optimizer = optim.SGD(model.parameters(), lr=0.01)

2.2 参数的选项

优化器支持为不同的参数组设置不同的选项。这可以通过传递一个包含字典的列表来实现,每个字典定义一个参数组及其对应的选项。

# 为不同的参数组设置不同的学习率
optimizer = optim.SGD([
    {'params': model.layer1.parameters(), 'lr': 0.01},
    {'params': model.layer2.parameters(), 'lr': 0.001}
], lr=0.01)  # 注意这里的lr是默认的学习率,可以被参数组中的lr覆盖

2.3 优化步骤
在训练过程中,你需要按照以下步骤来使用优化器:

  • 清除梯度:在每次迭代开始时,使用optimizer.zero_grad()清除所有参数的梯度。
  • 前向传播:将输入数据传递给模型,计算输出。
  • 计算损失:使用损失函数计算输出与目标之间的损失。
  • 反向传播:调用loss.backward()计算损失对每个参数的梯度。
  • 更新参数:使用optimizer.step()更新模型的参数。

下面是一个完整的训练循环示例:

# 假设我们有一个数据集和标签
data = torch.randn(100, 10)
labels = torch.randn(100, 1)

# 定义损失函数
criterion = nn.MSELoss()

# 训练循环
for epoch in range(100):  # 假设训练100个epoch
    # 清除梯度
    optimizer.zero_grad()
    # 前向传播
    outputs = model(data)
    # 计算损失
    loss = criterion(outputs, labels)
    # 反向传播
    loss.backward()
    # 更新参数
    optimizer.step()
    # 打印损失(可选)
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

3. 基类.Optimizer的介绍

torch.optim.Optimizer是所有优化器的基类。它提供了优化器的基本框架和接口,包括参数的注册、梯度的清零、参数的更新等。当你实现一个自定义的优化器时,你需要继承这个基类,并实现step()和zero_grad()方法。
Optimizer类的主要属性和方法包括:

  • param_groups:一个列表,包含所有参数组及其对应的选项。
  • add_param_group():向优化器中添加一个新的参数组。
  • zero_grad():清除所有参数的梯度。
  • step():更新所有参数的梯度。
  • state_dict():获取优化器的状态字典,用于保存和加载。
  • load_state_dict():从状态字典中加载优化器的状态。

4. 学习率调整的常用方式

在PyTorch中,学习率调整是优化过程中的关键环节,它对于模型的收敛速度、训练稳定性以及最终性能都有着至关重要的影响。PyTorch的torch.optim.lr_scheduler模块为我们提供了多种学习率调整策略,这些策略可以根据训练的需要动态地调整学习率。接下来,我将详细介绍几种常用的学习率调整方式,并提供相应的Python代码示例。

4.1 StepLR

StepLR是一种简单的学习率调整策略,它在每个epoch(或指定的step_size步)之后,将学习率乘以一个固定的因子gamma(通常小于1),从而实现学习率的逐步衰减。

import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# 假设我们已经有了一个优化器optimizer
optimizer = optim.SGD(model.parameters(), lr=0.1)

# 定义StepLR优化器,每30个epoch学习率乘以0.1
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

# 训练循环
for epoch in range(100):
    # ...(训练代码,包括前向传播、计算损失、反向传播等)
    
    # 在每个epoch之后调整学习率
    scheduler.step()
    
    # 打印当前学习率(可选)
    print(f'Epoch {epoch+1}, Learning Rate: {scheduler.get_last_lr()}')

在这里插入图片描述

4.2 MultiStepLR

MultiStepLR允许我们在指定的epoch(或step)列表处调整学习率,每次调整时都将学习率乘以一个固定的因子gamma。

from torch.optim.lr_scheduler import MultiStepLR
# 定义MultiStepLR优化器,在指定的epoch处调整学习率
scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

在这里插入图片描述

4.3 LinearLR

LinearLR 是 PyTorch 中的一个学习率调度器,它按照线性方式调整每个参数组的学习率。具体来说,它会在指定的迭代次数(total_iters)内,从初始学习率(通过 start_factor 调整后的学习率)线性地过渡到结束学习率(通过 end_factor 调整后的学习率)。

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LinearLR

# 假设我们有一个模型 model
# 创建一个优化器实例
optimizer = optim.SGD(model.parameters(), lr=0.1)

# 创建一个 LinearLR 调度器实例
scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=20)

# 训练循环
for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    
    # 在每个epoch后调用调度器的step方法
    scheduler.step()
    
    # 打印当前学习率(可选)
    for param_group in optimizer.param_groups:
        print(f'Epoch {epoch+1}, Learning Rate: {param_group["lr"]}')

在这里插入图片描述

4.4 CosineAnnealingLR

CosineAnnealingLR使用余弦退火策略调整学习率,它模拟了余弦函数在一个周期内的变化,通常用于模拟学习率在训练过程中的周期性变化。

Copy Code
from torch.optim.lr_scheduler import CosineAnnealingLR
# 定义CosineAnnealingLR优化器,学习率按余弦退火策略变化
scheduler = CosineAnnealingLR(optimizer, T_max=100)

在这里插入图片描述

4.5 LambdaLR

LambdaLR允许用户自定义一个lambda函数来调整学习率,这个函数以当前epoch数为输入,并输出一个学习率的缩放因子。

from torch.optim.lr_scheduler import LambdaLR
# 定义一个lambda函数,根据epoch数调整学习率
lambda_fn = lambda epoch: 0.95 ** epoch
# 定义LambdaLR优化器,使用自定义的lambda函数调整学习率
scheduler = LambdaLR(optimizer, lr_lambda=lambda_fn)

在这里插入图片描述

5.小结

PyTorch的优化器是核心模块之一,在了解到基本模块分布后,需根据模型的实际情况,选择相应的优化器算法和学习率的配置方式。


http://www.kler.cn/a/514280.html

相关文章:

  • 1/20赛后总结
  • 基于Redis实现短信验证码登录
  • 偏序关系.
  • 目标跟踪算法发展简史
  • Spring Boot 项目启动报错 “找不到或无法加载主类” 解决笔记
  • Linux——入门基本指令汇总
  • Android 绘图基础:Canvas,Paint,RectF,Paint类
  • 25/1/21 算法笔记<ROS2> 话题通信接口
  • LabVIEW太赫兹二维扫描成像系统
  • 【Nacos】Nacos快速上手
  • Models如何使用Gorm与数据库进行交互?
  • 利用Kubespray安装生产环境的k8s集群-准备篇
  • centos哪个版本建站好?centos最稳定好用的版本
  • 音频入门(二):音频数据增强
  • 【Elasticsearch】inference ingest pipeline
  • 缓存之美:万文详解 Caffeine 实现原理(上)
  • 多线程杂谈:惊群现象、CAS、安全的单例
  • 一文大白话讲清楚webpack基本使用——6——热更新及其原理
  • Bash语言的安全开发
  • 设计模式Python版 GOF设计模式
  • 【大厂面试题】软件测试面试题整理(附答案)
  • 消息队列篇--原理篇--RabbitMQ和Kafka对比分析
  • Git【将本地代码推送到远程仓库】--初学者必看
  • 2025美赛数学建模B题思路+模型+代码+论文
  • 电脑开机出现Bitlock怎么办
  • solidity基础 -- 内联汇编