当前位置: 首页 > article >正文

【AI知识点】机器学习中的常用优化算法(梯度下降、SGD、Adam等)

更多AI知识点总结见我的专栏:【AI知识点】
AI论文精读、项目和一些个人思考见我另一专栏:【AI修炼之路】
有什么问题、批评和建议都非常欢迎交流,三人行必有我师焉😁


1. 什么是优化算法?

在机器学习中优化算法(Optimization Algorithm) 的任务是找到模型参数(如权重、偏置等),使得损失函数(例如均方误差、交叉熵等)最小化。损失函数度量的是模型预测值与真实标签之间的误差。优化算法通过不断调整模型的参数,使损失函数达到全局或局部最小值。

在神经网络中,优化算法需要通过反向传播(Backpropagation)计算每个参数对损失函数的导数(即梯度),并根据这些梯度更新模型的参数。


2. 基于梯度的优化算法

这些算法是深度学习中最常用的优化方法,通过计算梯度来找到损失函数最小的方向。

a. 梯度下降(Gradient Descent)

梯度下降是最基本的优化算法,核心思想是:朝着使损失函数减少的方向更新参数,直到达到最小值。

  • 更新规则
    θ = θ − α ⋅ ∇ θ J ( θ ) \theta = \theta - \alpha \cdot \nabla_\theta J(\theta) θ=θαθJ(θ)
    其中:
    • θ \theta θ 是模型的参数。
    • α \alpha α 是学习率,控制每次更新的步长。
    • ∇ θ J ( θ ) \nabla_\theta J(\theta) θJ(θ) 是损失函数 J ( θ ) J(\theta) J(θ) 对参数 θ \theta θ 的梯度。

b. 随机梯度下降(Stochastic Gradient Descent, SGD)

梯度下降的一个问题是,当数据集很大时,计算所有样本的梯度会很耗时。随机梯度下降(SGD) 是对梯度下降的改进,每次迭代只使用一个数据点来计算梯度,从而大大加快了参数更新。

c. 小批量梯度下降(Mini-batch Gradient Descent)

这是梯度下降和随机梯度下降的折中版本。它通过对一小部分数据(称为mini-batch)进行梯度计算和更新,这样既加快了计算速度,又保持了一定的稳定性。

d. 动量法(Momentum)

SGD 更新参数时每次依赖于当前梯度的方向,但有时可能会在方向上震荡。动量法通过加入“动量”项,积累过去几次梯度的方向,使得优化算法能够更快速地朝着最优解的方向前进。

  • 更新规则
    v t = β v t − 1 + α ∇ θ J ( θ ) v_t = \beta v_{t-1} + \alpha \nabla_\theta J(\theta) vt=βvt1+αθJ(θ)
    θ = θ − v t \theta = \theta - v_t θ=θvt
    其中, β \beta β 是动量项的系数。

e. RMSProp

RMSProp 是另一种改进的优化算法,它对每个参数都使用不同的学习率,通过对每个参数的梯度平方进行平滑加权平均,使得参数的更新步长更加合适。

  • 更新规则
    E [ ∇ θ 2 J ( θ ) ] t = β E [ ∇ θ 2 J ( θ ) ] t − 1 + ( 1 − β ) ∇ θ 2 J ( θ ) E[\nabla_\theta^2 J(\theta)]_t = \beta E[\nabla_\theta^2 J(\theta)]_{t-1} + (1 - \beta) \nabla_\theta^2 J(\theta) E[θ2J(θ)]t=βE[θ2J(θ)]t1+(1β)θ2J(θ)
    θ = θ − α E [ ∇ θ 2 J ( θ ) ] t + ϵ ∇ θ J ( θ ) \theta = \theta - \frac{\alpha}{\sqrt{E[\nabla_\theta^2 J(\theta)]_t + \epsilon}} \nabla_\theta J(\theta) θ=θE[θ2J(θ)]t+ϵ αθJ(θ)
    其中, ϵ \epsilon ϵ 是一个很小的数,用于避免除以零。

f. Adam(Adaptive Moment Estimation)

Adam 是目前深度学习中最常用的优化算法之一,它结合了动量法RMSProp的优点。Adam 同时对一阶和二阶矩进行估计,能够自适应地调整每个参数的学习率。

  • 更新规则
    Adam 分别维护了两个动量变量:

    • 一阶动量(梯度的加权平均): m t m_t mt
    • 二阶动量(梯度平方的加权平均): v t v_t vt

    m t = β 1 m t − 1 + ( 1 − β 1 ) ∇ θ J ( θ ) m_t = \beta_1 m_{t-1} + (1 - \beta_1) \nabla_\theta J(\theta) mt=β1mt1+(1β1)θJ(θ)
    v t = β 2 v t − 1 + ( 1 − β 2 ) ∇ θ 2 J ( θ ) v_t = \beta_2 v_{t-1} + (1 - \beta_2) \nabla_\theta^2 J(\theta) vt=β2vt1+(1β2)θ2J(θ)
    然后对动量进行偏差校正:
    m ^ t = m t 1 − β 1 t , v ^ t = v t 1 − β 2 t \hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} m^t=1β1tmt,v^t=1β2tvt
    最终更新参数:
    θ = θ − α m ^ t v ^ t + ϵ \theta = \theta - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} θ=θαv^t +ϵm^t


3. 梯度下降的图示

图片来源:https://mlpills.dev/machine-learning/gradient-descent/

这张图形象地说明了梯度下降的工作原理:从一个随机的初始参数开始,经过多次迭代更新,逐步逼近最低的损失值,最终找到最佳的模型参数。

  1. 坐标轴

    • 横轴(w):表示模型的参数(权重),它是通过优化调整的变量。
    • 纵轴(Cost):表示模型的成本或损失值,即模型预测与实际结果之间的误差。
  2. 随机初始值

    • 图中左侧标注的“Random initial value”表示算法开始时模型参数的随机初始值。优化过程从这个点开始。
  3. 学习步骤

    • 图中多个蓝色圆点表示算法在每次迭代中的参数值。每个点都对应一个特定的成本值。
    • “Learning step”表示每次迭代中,算法根据当前的梯度(导数)调整参数,以降低成本。每次调整的幅度和方向由学习率决定。
  4. 最小值

    • 图中黄色圆点标识了成本函数的最小值,表示在这个参数值下,模型的预测效果最好,损失最小。
  5. 下降路径

    • 蓝色圆点之间的连接线展示了模型在参数空间中逐步接近最小值的过程。这条路径表明,随着迭代的进行,模型参数不断调整,成本值逐渐降低。

4. 局部最优解和全局最优解

在复杂的损失函数中,可能会存在多个局部最优解。优化算法的目标是找到全局最优解,即损失函数的全局最小值。然而,梯度下降类算法可能会陷入局部最优解,因此一些改进的算法(如动量法、Adam)引入了额外的策略来帮助模型跳出局部最优解。

  • 局部最优解:损失函数的一个小范围内的最小值,但不是全局最小值。
  • 全局最优解:整个损失函数范围内的最小值。

图解:

图片来源:https://easyai.tech/en/ai-definition/gradient-descent/#google_vignette

这张图展示了梯度下降(Gradient Descent) 的概念。图中呈现的三维曲面代表了一个目标函数,通常是损失函数,反映了模型参数与损失之间的关系。黑色箭头表示梯度的方向。不同的点代表不同的参数组合,曲面的高低则表示损失值的大小。其中最低的凹点就是全局最优解,而不是最低点的其他凹点则代表各种局部最优解


5. 优化算法的比较

算法优点缺点
梯度下降简单易懂,适合小规模数据集。计算量大,尤其是大数据集时速度慢。
SGD快速更新参数,适合大规模数据集。收敛不稳定,路径波动大,需要调节学习率。
动量法减少梯度震荡,加快收敛。需要调整动量参数 β \beta β,对不同问题敏感。
RMSProp适应性学习率,避免步长过大或过小,适合深度网络。需要调整超参数 β \beta β,在某些任务上表现不稳定。
Adam自适应学习率,结合动量和 RMSProp 的优点,广泛用于深度学习。需要调整较多的超参数,对学习率敏感,可能导致局部最优解。

三种 Gradient Descent 的形象图示:

在这里插入图片描述
图片来源:https://www.nomidl.com/machine-learning/what-is-gradient-descent-batch-gradient-descent-stochastic-gradient-descent-mini-batch-gradient-descent/


6. 如何选择合适的优化算法

选择优化算法时,需要根据具体的任务需求、数据特点和模型架构来选择合适的算法。以下是一些常见的选择依据:

a. 数据规模

  • 小规模数据集:可以使用标准的梯度下降,因为计算量不大。
  • 大规模数据集:通常使用随机梯度下降(SGD)小批量梯度下降(Mini-batch Gradient Descent)。这些算法对大规模数据更有效。

b. 模型复杂性

  • 浅层模型:如逻辑回归、线性回归等浅层模型,使用SGD动量法 可以取得良好的效果。
  • 深层神经网络:深度学习通常使用AdamRMSProp,它们能够自动调整学习率,适应深度网络中的复杂性和高维性。

c. 收敛速度

  • 如果需要快速收敛,并且可以承受一定的波动性,可以使用SGD动量法
  • 如果需要更加平稳的收敛过程,建议使用AdamRMSProp,这些算法通过自适应调整学习率来保证收敛的平稳性。

7. 超参数调整的重要性

所有的优化算法都有一些关键的超参数,如学习率( α \alpha α)、动量系数( β \beta β)、RMSProp 和 Adam 中的动量参数和二阶动量参数等。这些超参数的选择对于模型性能的影响非常大。

a. 学习率

  • 学习率决定了每次参数更新的步长。学习率太大,可能会导致跳过最优解;学习率太小,模型收敛速度太慢,甚至可能陷入局部最优解。

b. 动量系数

  • 动量系数(通常记为 β \beta β)用于在动量法和 Adam 中,它决定了过去梯度的影响。动量系数过大会导致优化过程“过冲”,而动量系数过小则无法有效加速收敛。

c. 自适应学习率

  • 像 Adam 和 RMSProp 这样的优化算法会根据每个参数的梯度历史自动调整学习率。虽然这些算法自适应学习率,但仍然需要仔细调整初始学习率和其他超参数,才能获得良好的性能。

8. 一些常见的优化技巧

在深度学习和机器学习中,优化算法的性能很大程度上取决于使用者是否有效地结合了各种优化技巧。以下是一些常见的优化技巧:

a. 学习率衰减

  • 在训练的早期使用较大的学习率来加快收敛速度,然后随着训练进行逐渐减小学习率,帮助模型在最优解附近进行更细致的搜索。这可以避免模型在靠近最优解时仍然使用较大的步长导致震荡。

b. 提前停止(Early Stopping)

  • 提前停止是一种防止过拟合的技巧,它会监控模型在验证集上的表现,当验证集上的损失不再降低时,就提前停止训练。这避免了模型过度拟合训练数据,并可以加快训练过程。

c. 批归一化(Batch Normalization)

  • 批归一化在每一层对输入数据进行归一化,使得神经网络中各层的输入数据分布更加稳定,能够加速训练并提高模型的收敛速度。

d. 梯度裁剪(Gradient Clipping)

  • 当模型中梯度过大时,可能会导致梯度爆炸问题,尤其是在深层神经网络或循环神经网络(RNN)中。梯度裁剪将梯度限制在一个固定的范围内,从而防止梯度过大导致不稳定的训练。

9. 总结

优化算法 是机器学习和深度学习的核心工具,它通过调整模型参数,使损失函数最小化,从而提高模型的性能。不同的优化算法适用于不同的数据规模、模型复杂度和任务类型,常见的算法包括梯度下降、动量法、Adam等。选择合适的优化算法和调整超参数是成功训练机器学习模型的关键。结合学习率衰减、提前停止、批归一化等优化技巧,模型的训练效率和效果可以显著提高。


http://www.kler.cn/news/340638.html

相关文章:

  • sqli-labs less-20 less-21 less-22 cookie注入
  • 【JNI】hello world
  • Spring 事务传播机制:深入理解与实践
  • 20241005给荣品RD-RK3588-AHD开发板刷Rockchip原厂的Android12时使用iperf3测网速
  • 某象异形滑块99%准确率方案
  • Springboot 整合 logback 日志框架
  • 校园资源共享新方案:基于SpringBoot的实现
  • 基于SpringBoot+Vue的在线投票系统
  • 【Unity】unity安卓打包参数(个人复习向/有不足之处欢迎指出/侵删)
  • Matter蓝牙解析
  • 06-Cesium 中动态处理与圆形扩散材质相关的属性
  • [nmap] 端口扫描工具的下载及详细安装使用过程(附有下载文件)
  • Java 中的 PO、VO、DAO、BO、DTO、POJO
  • 文件分块上传
  • 黑神话:仙童,数据库自动反射魔法棒
  • 【自动驾驶汽车通讯协议】I2C(IIC)总线通讯技术详解
  • Windows环境安装CentOS7
  • Lumerical 脚本语言——操作实体对象(Manipulating objects)
  • unix进程间通信信号的有效实践
  • 用KLineChart绘制股票行情K线图