当前位置: 首页 > article >正文

深度学习(2):梯度下降

文章目录

  • 梯度下降
    • 梯度是什么
    • 常见梯度下降算法
  • 代码实现
    • 批量梯度下降

梯度下降

梯度是什么

类似y = ax + b这种单变量的函数来说,导数就是它的斜率,这种情况下可以说梯度就是导数。
但在多变量函数中,梯度是一个向量,其分量是各个单一变量的偏导数。这个向量指向函数增长最快的方向,其向量的模(大小)表示在那个方向上的最大变化率。

所以我们沿着梯度的反方向走,这就是下降最快的方向,这样就能够使得损失函数最快的下降了

常见梯度下降算法

  1. 批量梯度下降(Batch Gradient Descent)

    • 原理:在每次迭代中,使用整个训练集计算损失函数的梯度,然后更新参数。
    • 优点
      • 全局最优:每次更新都基于全数据,方向稳定。
      • 收敛稳定:梯度的方向一致,易于收敛。
    • 缺点
      • 计算开销大:对于大型数据集,计算梯度耗时长。
      • 缺乏在线学习能力:无法实时更新模型。
  2. 随机梯度下降(Stochastic Gradient Descent, SGD)

    • 原理:在每次迭代中,只使用一个样本计算梯度并更新参数。
    • 优点
      • 计算效率高:每次更新只需计算一个样本的梯度。
      • 在线学习:适合流式数据处理。
    • 缺点
      • 收敛不稳定:梯度受单个样本影响,可能产生较大波动。
      • 可能陷入局部最优:由于更新方向不稳定。
  3. 小批量梯度下降(Mini-batch Gradient Descent)

    • 原理:在每次迭代中,使用一小批样本(如32、64、128个)计算梯度并更新参数。
    • 优点
      • 权衡计算效率和稳定性:比批量方法快,比随机方法稳。
      • 利用矩阵运算:可充分利用GPU等硬件加速。
    • 缺点
      • 需要选择合适的批量大小:批量过小或过大会影响性能。
  4. 动量法(Momentum)

    • 原理:在参数更新中引入动量项,累计之前梯度的指数加权平均,公式为:

在这里插入图片描述

  • 优势
    • 加速收敛:在一致的梯度方向上加速移动。
    • 减少振荡:在梯度变化方向上抑制波动。
  1. Nesterov加速梯度(Nesterov Accelerated Gradient, NAG)

    • 原理:在动量法的基础上,先对参数进行一步预估,然后计算预估位置的梯度,公式为:
      在这里插入图片描述

    • 优势

      • 提前感知:对未来位置的梯度进行评估,提高了更新的准确性。
      • 更快收敛:比标准动量法具有更好的收敛速度。
  2. AdaGrad(Adaptive Gradient)

    • 原理:为每个参数适应性地调整学习率,累积梯度的平方和,公式为:

在这里插入图片描述

  • 优势

    • 自适应学习率:对频繁更新的参数降低学习率,稀疏参数仍保持较大学习率。
    • 适合稀疏数据:在自然语言处理等领域表现良好。
  • 缺点

    • 学习率单调递减:可能导致后期学习过慢或停止。
  1. RMSProp

    • 原理:对AdaGrad进行改进,采用梯度平方的指数加权移动平均,公式为:

在这里插入图片描述

  • 优势
    • 防止学习率过快下降:保持学习率在较为稳定的范围内。
    • 适合非平稳目标:在处理递归神经网络等问题时表现良好。
  1. Adam(Adaptive Moment Estimation)

    • 原理:结合Momentum和RMSProp,同时计算梯度的一阶和二阶矩的估计,公式为:

在这里插入图片描述

  • 优势
    • 自适应学习率:对每个参数进行动态调整。
    • 快速收敛:在实践中表现出优秀的性能和稳定性。
    • 广泛适用:已成为深度学习中最常用的优化算法之一。

代码实现

批量梯度下降

import matplotlib.pyplot as plt

# 准备数据集,线性关系
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# 随机的初始化权重 
w = 1.0


# 找线性模型
def forward(x):
    return x * w


# 损失函数MSE
def loss(xs, ys):
    cost = 0 # 储存loss ^ 2的和
    for x, y in zip(xs, ys):
        y_pred = forward(x)
        cost += (y_pred - y) ** 2
    return cost / len(xs) # MSE


# 批量梯度下降:选取所有的样本做梯度下降
# 获取当前的梯度是多少
def gradient(xs, ys):
    grad = 0
    for x, y in zip(xs, ys):
        grad += 2 * x * (x * w - y)
    return grad / len(xs)


epoch_list = []
loss_list = []
print('predict (before training)', 4, forward(4))
for epoch in range(100):
    cost_val = loss(x_data, y_data)
    grad_val = gradient(x_data, y_data)
    w -= 0.01 * grad_val  # 0.01 学习率
    print('epoch:', epoch, 'w=', w, 'loss=', cost_val)
    epoch_list.append(epoch)
    loss_list.append(cost_val)

print('predict (after training)', 4, forward(4))
plt.plot(epoch_list, loss_list)
plt.ylabel('cost')
plt.xlabel('epoch')
plt.show()


http://www.kler.cn/a/318489.html

相关文章:

  • [Linux] Linux信号捕捉
  • 阿里云通义大模型团队开源Qwen2.5-Coder:AI编程新纪元
  • 基于微信小程序的乡村研学游平台设计与实现,LW+源码+讲解
  • Linux基础1
  • 字节跳动Android面试题汇总及参考答案(80+面试题,持续更新)
  • qt QKeySequence详解
  • Windows系统使用PHPStudy搭建Cloudreve私有云盘公网环境远程访问
  • OTTO奥托机器人开发总结
  • 2024java高频面试-数据库相关
  • 将python代码文件转成Cython 编译问题集
  • python中实用工具与自动化脚本
  • typename、非类型模板参数、模板参数的特化、模板类成员函数声明和定义分离、继承等的介绍
  • 滚雪球学SpringCloud[6.3讲]: 分布式日志管理与分析
  • 常见统计量与其抽样分布
  • python异步处理
  • [SDX35+WCN6856]SDX35 + WCN6856 WiFi 起来之后,使用终端连接会导致系统重启
  • dotnet4.0编译问题
  • 【系统架构设计师】专题:系统质量属性和架构评估
  • 康养为松,智能为鹤:华为全屋智能画出的松鹤长春图
  • 2024.9.24 数据分析
  • 努比亚z17努比亚NX563j原厂固件卡刷包下载_刷机ROM固件包下载-原厂ROM固件-安卓刷机固件网
  • 智慧城市主要运营模式分析
  • [附源码]宠物领养管理系统+SpringBoot
  • css实现居中的方法
  • C++ prime plus-4-编程练习
  • vue echarts tooltip使用动态模板