当前位置: 首页 > article >正文

选择正确优化器,加速深度学习模型训练


❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈

优化器

(封面图由文心一格生成)

选择正确优化器,加速深度学习模型训练

随着深度学习技术的不断发展和应用,深度学习模型的复杂性也在不断增加。因此,模型训练成为深度学习中最为耗时的过程之一。优化器的选择可以对模型训练的效率和准确性产生巨大影响。在本文中,我们将介绍深度学习中常用的优化器以及其原理,并通过代码实现来说明它们的效果和适用范围。

1. 优化器的选择

深度学习中的优化器可以看做是一种求解损失函数最小值的方法,其主要作用是更新模型的参数。常见的优化器包括随机梯度下降(SGD)、动量(Momentum)、Adagrad、RMSProp、Adam等。不同的优化器在更新参数的方式和速度上存在差异,因此选择合适的优化器可以加速模型的训练过程,提高模型的准确率和泛化能力。

2. SGD优化器

SGD是深度学习中最为基础的优化器之一,SGD优化器的原理比较简单,它通过沿着梯度的负方向进行参数更新,以最小化损失函数。然而,由于SGD的更新规则比较简单,因此容易陷入局部最优解,同时在面对参数空间非凸、梯度值变化较大时,收敛速度也比较慢。

代码实现:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义模型和损失函数
model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()

# 定义SGD优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 进行模型训练
for epoch in range(10):
    # 模拟数据
    inputs = torch.randn(1, 10)
    targets = torch.randint(2, (1,))
    
    # 前向传播
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    
    ## 反向传播和参数更新
	optimizer.zero_grad()
	loss.backward()
	optimizer.step()
	
	print('Epoch: {}, Loss: {:.4f}'.format(epoch+1, loss.item()))

3. Momentum优化器

Momentum优化器是SGD优化器的改进版本之一,它可以通过加入动量项来避免SGD陷入局部最优解。动量项可以看作是上一轮梯度的指数移动平均值,它可以提高梯度更新的稳定性,减少梯度更新的震荡。Momentum优化器的更新规则可以表示为:

v t = γ v t − 1 + α ∇ θ L ( θ t ) v_{t} = \gamma v_{t-1} + \alpha\nabla_{\theta} L(\theta_{t}) vt=γvt1+αθL(θt)
θ t + 1 = θ t − v t \theta_{t+1} = \theta_{t} - v_{t} θt+1=θtvt

其中, v t v_{t} vt表示当前梯度的动量, γ \gamma γ表示动量系数,通常取0.9左右。

Momentum优化器的原理是通过动量的引入来加速参数更新过程,减少参数更新的震荡和方向变化,从而提高模型的收敛速度和准确率。

代码实现:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义模型和损失函数
model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()

# 定义Momentum优化器
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# 进行模型训练
for epoch in range(10):
    # 模拟数据
    inputs = torch.randn(1, 10)
    targets = torch.randint(2, (1,))
    
    # 前向传播
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    
    # 反向传播和参数更新
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print('Epoch: {}, Loss: {:.4f}'.format(epoch+1, loss.item()))

4. Adagrad优化器

Adagrad优化器是一种自适应学习率的优化器,它可以根据每个参数的历史梯度值来自适应地调整学习率大小。Adagrad优化器的原理是针对每个参数自适应地调整学习率大小,将学习率设置为每个参数历史梯度值的加权和的平方根,从而使得每个参数的学习率都可以得到有效的调整。Adagrad优化器可以适用于训练复杂的非凸模型,但可能会导致学习率过小,影响模型的收敛速度。

代码实现:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义模型和损失函数
model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()

# 定义Adagrad优化器
optimizer = optim.Adagrad(model.parameters(), lr=0.01)

# 进行模型训练
for epoch in range(10):
    # 模拟数据
    inputs = torch.randn(1, 10)
    targets = torch.randint(2, (1,))
    
    # 前向传播
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    
    # 反向传播和参数更新
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print('Epoch: {}, Loss: {:.4f}'.format(epoch+1, loss.item()))

5. RMSProp优化器

RMSProp优化器也是一种自适应学习率的优化器,它可以根据每个参数的历史梯度平方和的加权平均值来自适应地调整学习率大小。RMSProp优化器的原理是通过对梯度平方和进行自适应调整学习率大小,可以减少学习率的波动和过大的变化,从而提高模型的收敛速度和准确率。

代码实现:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义模型和损失函数
model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()

# 定义RMSProp优化器
optimizer = optim.RMSprop(model.parameters(), lr=0.01)

# 进行模型训练
for epoch in range(10):
    # 模拟数据
    inputs = torch.randn(1, 10)
    targets = torch.randint(2, (1,))
    
    ## 前向传播
	outputs = model(inputs)
	loss = criterion(outputs, targets)
	
	# 反向传播和参数更新
	optimizer.zero_grad()
	loss.backward()
	optimizer.step()
	
	print('Epoch: {}, Loss: {:.4f}'.format(epoch+1, loss.item()))

6. Adam优化器

Adam优化器是一种综合了动量和自适应学习率的优化器,它可以在保持梯度方向稳定的同时自适应地调整学习率。

Adam优化器的原理是综合动量和自适应学习率的优点,可以在保证梯度更新的稳定性的同时自适应调整学习率,从而提高模型的收敛速度和准确率。

代码实现:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义模型和损失函数
model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()

# 定义Adam优化器
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 进行模型训练
for epoch in range(10):
    # 模拟数据
    inputs = torch.randn(1, 10)
    targets = torch.randint(2, (1,))
    
    # 前向传播
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    
    # 反向传播和参数更新
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print('Epoch: {}, Loss: {:.4f}'.format(epoch+1, loss.item()))

7. AdamW优化器

AdamW优化器是基于Adam优化器改进的一种优化器,主要是为了解决Adam优化器在权重衰减方面的缺陷。

在Adam优化器中,权重衰减是通过对参数进行L2正则化实现的。但是,在实践中,L2正则化可能会导致权重矩阵的方向向量长度不断缩小,从而影响模型的泛化能力。因此,为了解决这个问题,AdamW优化器引入了一种新的权重衰减方式,即对权重的L2范数进行惩罚,而不是对参数进行L2正则化。

通过引入L2权重衰减,AdamW优化器可以在保证权重方向向量长度不会过小的同时,调整权重的更新幅度,从而提高模型的泛化能力和稳定性。

代码实现:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义模型和损失函数
model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()

# 定义AdamW优化器
optimizer = optim.AdamW(model.parameters(), lr=0.01, weight_decay=0.01)

# 进行模型训练
for epoch in range(10):
    # 模拟数据
    inputs = torch.randn(1, 10)
    targets = torch.randint(2, (1,))
    
    # 前向传播
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    
    # 反向传播和参数更新
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print('Epoch: {}, Loss: {:.4f}'.format(epoch+1, loss.item()))

8. 总结

在深度学习中,选择合适的优化器可以大大提高模型的训练效率和准确率。本文介绍了几种常见的优化器,包括SGD、Momentum、Adagrad、RMSProp、Adam和AdamW优化器,并对它们的原理进行了详细介绍,并通过代码实现说明了它们的效果和适用范围。在实际应用中,选择合适的优化器需要根据模型的特性、数据集的规模和复杂度等因素来综合考虑。

最后,希望本文能够对读者了解深度学习中优化器的选择提供一些帮助。


❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈


http://www.kler.cn/a/5791.html

相关文章:

  • trf 4.10安装与使用-生信工具42
  • python学习笔记—17—数据容器之字符串
  • Jenkins触发器--在其他项目执行后构建
  • 【数据库】四、数据库管理与维护
  • 贪心算法(五)
  • 视频编辑最新SOTA!港中文Adobe等发布统一视频生成传播框架——GenProp
  • 谈谈面向对象编程和面向过程编程
  • 容器、虚拟机和 Docker
  • 基于PHP的英语四六级在线模拟考试平台(论文+源码)_kaic
  • Python3 File flush() 方法、 Python3 File write() 方法
  • Chapter9.1:线性系统状态空间基础(上)
  • Java 8 中需要知道的4个函数式接口-Function、Consumer、Supplier、Predicate
  • 吉时利源表出现数据不准怎么办?
  • 通道流量设计
  • 自动驾驶路径规划与控制:讨论自动驾驶车辆的路径规划算法,如A*、Dijkstra等,以及控制策略,如PID控制等
  • 【数据结构与算法】线性表--数组
  • 利用json-server快速在本地搭建一个JSON服务
  • leetcode125:验证回文串
  • JavaScript函数及面向对象
  • GPT-4报告解读
  • 【HDR图像处理】HDR图像,HDRI技术的一些基础概念 | GPT对话记录
  • Appium - 自动化测试框架 - 工作原理、环境搭建
  • SpringBoot整合SpringSecurity权限控制(动态拦截url+单点登录)
  • 多机器人集群网络通信协议分析
  • 设计模式-结构型模式-适配器模式
  • Python字符串常用方法