当前位置: 首页 > article >正文

梯度下降算法优化—随机梯度下降、小批次、动量、Adagrad等方法pytorch实现

现有不足

现有调整网络的方法是借助成本函数的梯度下降方法,也就是给函数作切线,不断逼近最优点,即成本函数为零的点。
梯度下降的一般公式为:
梯度下降公式
即根据每个节点成本函数的梯度进行更新,使用该方法有一些问题:
**1,计算量大,耗时长。**我们的训练数据往往是成千上万的,每条都反向传播,计算梯度再调整参数,这等计算量就算是计算机也吃不消,更何况现在是大数据的时代,耗费的时间更是要呈指数上升。
**2,易掉进局部最优的陷阱。**根据梯度下降,我们找到的往往是一个极值点,而非最值点,如何找到方法跳出局部最优的陷阱而找到最优解也是目前的一个不足。

传统梯度下降使用pytorch实现的一般思路是:
1,获取数据
2,定义损失函数
3,定义优化器
4,计算损失,并反向传播计算梯度
5,更新模型参数

定义一个最简单的线性模型y=w*x+b,损失函数为预测值和实际的差,训练模型的具体代码如下:

import torch
import matplotlib.pyplot as plt
import time

start_time=time.time()
# 定义训练数据
X = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
Y = torch.tensor([2, 4, 6, 8], dtype=torch.float32)

# 初始化模型参数
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

# 定义损失函数
loss_fn = torch.nn.MSELoss()

# 定义优化器
optimizer = torch.optim.SGD([w, b], lr=0.01)
epochs = 100

# 批量梯度下降
for epoch in range(epochs):
    # 前向传播
    Y_pred = w * X + b
    # 计算损失
    loss = loss_fn(Y_pred, Y)
    # 反向传播
    loss.backward()
    # 更新参数
    optimizer.step()
    # 清空梯度
    optimizer.zero_grad()

# 输出结果
print(f"w = {w.item()}, b = {b.item()}")

# 时间统计
total_time=time.time()-start_time
print(f"time = {total_time}s")
# 绘制拟合直线
plt.scatter(X.numpy(), Y.numpy())
plt.plot(X.numpy(), (w * X + b).detach().numpy(), 'r')
plt.show()

输出结果如图:
批梯度下降输出结果

优化算法

上述问题的解决方案主要有两种:
1,通过数学或工程方法减少计算量。
2,优化路径或步长,使模型在获得最优解前尽量不走弯路(其实本质也是减少计算量)。

随机梯度下降—SGD

与传统批量梯度下降不同,随机梯度下降只选择一个节点的数据进行更新模型参数,该方法相比批量梯度下降不那么准确,但在循环过程中大致是朝着最优方向前进的,但相比批量梯度下降方法,大大提升了训练效率,算法是空间和时间的平衡,那我们现在就是准确性和时间的平衡。

在代码实现中,我们只需要在模型训练中选择一个数据参与更新参数过程即可,通过设置数据量,在模型训练中加载计算,区别的详细代码如下:

# 设置数据量为1
batch_size = 1
epochs = 100

# 随机梯度下降
for epoch in range(epochs):
    # 创建DataLoader,选择数据集中的一个
    loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(X, Y), batch_size=batch_size, shuffle=True)
    for x_batch, y_batch in loader:
        # 前向传播
        y_pred = w * x_batch + b
        # 计算损失
        loss = loss_fn(y_pred, y_batch)

该方法虽然减少了计算量,但因为其只选择一个数据作为更新依据,形成了一种对数据敏感的模型,在下降过程中易反复震荡,增加步数,但在震荡过程中造成其收敛方向的不确定性,一定程度上也能缓解算法陷入局部极值的陷阱。

小批次梯度下降—Mini-batch

该方法介于批量梯度下降和随机梯度下降之间,即每次选择数据集中的一部分作为更新模型的依据,该方法收敛速度快于批量梯度下降,收敛性波动小于随机梯度下降,代码实现中只需更改batch_size的大小即可。

动量梯度下降—MGD

批梯度下降每次参数的更新仅与上一次的梯度有关,而梯度往往是一种趋势,为了利用这种趋势加快步长(类似人在下坡时步子会变大),结合梯度的历史数据创造了动量梯度下降,实现在相关方向上的加速,并一定程度上抑制抖动的效果。

动量公式如下:
动量公式
参数更新公式
其中β代表历史数据的占比大小,在编码中我们可以通过设置momentum参数实现动量梯度下降,修改的代码如下:

momentum = 0.5
# 优化器中设置动量大小
optimizer = torch.optim.SGD([w, b], lr=0.01, momentum=momentum)

在该实验中我们可以推断,w越接近2,b越接近0说明训练效果越好,从该实验结果可以看出动量可以加速训练过程。
训练结果
且有了动量法,如果参数设置的够大,一定程度上我们也能跳过一些小的极值点,避免了陷入局部最小的陷阱中。

AdaGrad自适应梯度下降法

该方法针对学习力α,实现步长的自动调整,其学习率的更新公式为:自适应梯度下降
其中w表示特征梯度,epsilon表示很小的一个数,用于防止分母为零,即通过梯度平方和来调整学习率,使用该方法可以实现梯度大的位置减小学习速率,梯度小的位置增大学习率,而加了求和公式,使得其在稀疏特征中表现更好。

编码中,我们可以使用optimizer = torch.optim.Adagrad([w, b], lr=0.1)将优化器设置为Adagrad实现自适应梯度下降。

但该方法由于其积累平方和,必然导致后期学习率变小,可能造成难以收敛的后果。

RMSProp

为了解决AdamGrad算法后期学习率太小的问题,RMSProp 通过引入一个衰减系数来解决这个问题,使得历史信息能够指数级衰减,将算法A中梯度的求和改为衰减系数
该方法在编码中可使用optimizer = torch.optim.RMSprop([w, b], lr=0.01,weight_decay=0.9)设置RMSProp优化器并指定衰减系数。

Adam

该算法结合了动量法和RMSProp的思想,梯度上结合动量,而学习率结合RMSProp,具体公式如下:
在这里插入图片描述
最终的更新公式为:
在这里插入图片描述
该方法同时考虑了基于自身变化趋势的更新,和学习率根据梯度的变化,β1β2分别控制这两步的权重,项目中可以通过optimizer = torch.optim.Adam([w, b], lr=0.01)设置优化器实现,该算法在当前阶段可以说基本完美,只是执行过程中一定程度依赖于衰减系数,并且因为存储多步计算,空间上占用可能稍多,但可以说根本不是问题。

总结

本文介绍了几种梯度下降优化算法,包括其原理和代码实现,但并不一定说最好的算法就一定产生最好的结果,每种算法都有其适应的领域,在实际中还是要具体问题具体分析。


http://www.kler.cn/a/354186.html

相关文章:

  • TORCH_CUDA_ARCH_LIST
  • 精准提升:从94.5%到99.4%——目标检测调优全纪录
  • acme ssl证书自动续签 nginx
  • StartAI图生图局部重绘,让画面细节焕发新生!!
  • 【每日学点鸿蒙知识】AVCodec、SmartPerf工具、web组件加载、监听键盘的显示隐藏、Asset Store Kit
  • Debian 12 安装配置 fail2ban 保护 SSH 访问
  • pico+Unity交互开发教程——手指触控交互(Poke Interaction)
  • 如何利用OpenCV和yolo实现人脸检测
  • 如何利用边缘计算网关进行工厂设备数据采集?天拓四方
  • Linux创建sh脚本,实现全局调用
  • 可编辑73页PPT | 企业智慧能源管控平台建设方案
  • 机器学习【教育系统改善及其应用】
  • 线性代数基本知识
  • Web 搜索引擎优化
  • k8s部署Kafka集群超详细讲解
  • C#高级编程核心知识点
  • 智慧供排水管网在线监测为城市安全保驾护航
  • Mysql(4)—数据库索引
  • 数据结构实验十二 图的遍历及应用
  • 在FastAPI网站学python:虚拟环境创建和使用
  • 特斯拉智驾路线影响国内OEM组织架构变革,Robotaxi重塑汽车定位搅动风云
  • 【云从】五、负载均衡CLB
  • 使用 Docker compose 部署 Nacos(达梦数据库)
  • MongoDB如何查找数据以及条件运算符使用的详细说明
  • 比肩vercel的nuxt自动化部署,nuxthub+github+cloudflare
  • web网页---QQ注册页面的实现