当前位置: 首页 > article >正文

【漫话机器学习系列】119.小批量随机梯度方法

1. 引言

在机器学习和深度学习中,梯度下降(Gradient Descent)是一种常见的优化算法,用于调整模型参数以最小化损失函数。然而,在处理大规模数据集时,使用传统的梯度下降(GD)可能会面临计算成本高、收敛速度慢等问题。因此,引入了小批量随机梯度下降(Mini-Batch Stochastic Gradient Descent,MB-SGD),它结合了全批量梯度下降(Batch GD)和随机梯度下降(SGD)的优点,成为深度学习训练中的标准方法。

本文将详细介绍小批量随机梯度方法的基本概念、数学原理、优缺点及其应用,并通过示例代码演示其实际使用方法。


2. 什么是小批量随机梯度下降?

小批量随机梯度下降(Mini-Batch SGD)是一种改进的梯度下降方法,它在每次参数更新时,只使用数据集中的一个小部分(小批量)来计算梯度,而不是整个数据集。

具体来说,小批量随机梯度下降的工作流程如下:

  1. 从数据集中随机抽取一个小批量(Mini-Batch)样本,大小通常为 32、64、128 等。
  2. 计算该小批量上的梯度,然后更新模型参数。
  3. 重复上述步骤,直到遍历整个数据集(一个 epoch)
  4. 重复多个 epoch,直到模型收敛

这一策略避免了全批量梯度下降计算量过大的问题,同时比单样本的随机梯度下降更稳定。


3. 小批量随机梯度下降的数学原理

3.1. 梯度下降基本公式

梯度下降的核心思想是沿着负梯度方向更新参数,从而最小化损失函数 J(θ)。其基本更新公式如下:

\theta = \theta - \alpha \nabla J(\theta)

其中:

  • θ 是模型参数
  • α 是学习率(learning rate)
  • ∇J(θ) 是损失函数关于参数的梯度

3.2. 全批量梯度下降(Batch Gradient Descent)

全批量梯度下降使用整个数据集来计算梯度:

\nabla J(\theta) = \frac{1}{N} \sum_{i=1}^{N} \nabla J_i(\theta)

其中 N 是数据集的大小。这种方法计算精确,但当数据量过大时,计算开销很高。

3.3. 随机梯度下降(Stochastic Gradient Descent, SGD)

随机梯度下降(SGD)每次只使用一个样本来计算梯度:

\theta = \theta - \alpha \nabla J_i(\theta)

由于仅使用一个样本进行更新,计算速度快,但梯度更新噪声较大,导致收敛不稳定。

3.4. 小批量随机梯度下降(Mini-Batch SGD)

小批量随机梯度下降在每次更新时使用一个小批量 B(包含多个样本)来计算梯度:

\theta = \theta - \alpha \frac{1}{|B|} \sum_{i \in B} \nabla J_i(\theta)

其中∣B∣ 是小批量的大小。该方法在计算效率收敛稳定性之间取得了良好的平衡。


4. 小批量随机梯度下降的优缺点

4.1. 优势

  • 减少计算开销:相比全批量梯度下降,小批量方法可以显著降低计算成本。
  • 提高收敛稳定性:相比随机梯度下降,小批量方法的梯度估计更加稳定,能更快地收敛。
  • 可利用并行计算:可以使用 GPU 进行矩阵运算,提高训练效率。
  • 易于处理大规模数据集:能够在数据量较大的情况下高效训练模型。

4.2. 劣势

  • 超参数敏感:小批量大小(batch size)和学习率的选择会影响模型性能。
  • 计算复杂度仍然较高:虽然比全批量下降快,但仍然比纯随机梯度下降计算量大。
  • 收敛可能不如全批量方法:由于梯度估计存在一定噪声,可能会导致收敛到局部最优解。

5. 代码示例

我们使用 Python 代码来实现小批量随机梯度下降。

5.1. 使用 NumPy 手动实现 Mini-Batch SGD

import numpy as np

# 生成模拟数据
np.random.seed(42)
X = np.random.rand(100, 1)  # 100个样本,1个特征
y = 4 * X + np.random.randn(100, 1) * 0.2  # 线性关系 y = 4x + 噪声

# 初始化参数
theta = np.random.randn(2, 1)
learning_rate = 0.1
epochs = 100
batch_size = 10

# 添加偏置项
X_b = np.c_[np.ones((100, 1)), X]  

# Mini-Batch SGD 训练
for epoch in range(epochs):
    shuffled_indices = np.random.permutation(100)  # 随机打乱数据
    X_b_shuffled = X_b[shuffled_indices]
    y_shuffled = y[shuffled_indices]

    for i in range(0, 100, batch_size):
        X_batch = X_b_shuffled[i:i + batch_size]
        y_batch = y_shuffled[i:i + batch_size]
        gradients = 2 / batch_size * X_batch.T.dot(X_batch.dot(theta) - y_batch)
        theta -= learning_rate * gradients

print(f"训练后的参数: {theta}")

运行结果

训练后的参数: [[0.04320936]
 [3.90884737]]

此代码实现了:

  1. 生成数据集并添加噪声。
  2. 使用 Mini-Batch SGD 进行参数更新。
  3. 训练完成后输出最终的参数值。

5.2. 使用 PyTorch 实现 Mini-Batch SGD

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 生成数据
X = torch.rand(100, 1)
y = 4 * X + torch.randn(100, 1) * 0.2

# 构建数据集
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

# 定义模型
model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

# 训练
epochs = 100
for epoch in range(epochs):
    for batch_X, batch_y in dataloader:
        optimizer.zero_grad()
        predictions = model(batch_X)
        loss = loss_fn(predictions, batch_y)
        loss.backward()
        optimizer.step()

print(f"训练后的权重: {model.weight.data}, 偏置: {model.bias.data}")

运行结果

训练后的权重: tensor([[3.9055]]), 偏置: tensor([0.0890])

PyTorch 实现更加简洁,并且支持自动求导和 GPU 加速。


6. 结论

小批量随机梯度下降(Mini-Batch SGD)是一种高效且稳定的优化方法,它结合了全批量梯度下降的稳定性和随机梯度下降的计算效率,是深度学习训练中的标准方法。在实际应用中,需要通过调整学习率、批量大小和优化策略来获得最佳性能。

对于大规模数据集和深度学习任务,小批量方法能够显著提高训练速度,并支持并行计算,使得它成为现代机器学习的核心优化算法之一。


http://www.kler.cn/a/573439.html

相关文章:

  • 机器学习中的优化方法:从局部探索到全局约束
  • Measuring short-form factuality in large language models (SimpleQA) 论文简介
  • mybatis日期格式与字符串不匹配bug
  • 解锁前端表单数据的秘密旅程:从后端到用户选择!✨
  • 微服务通信:用gRPC + Protobuf 构建高效API
  • Java+SpringBoot+Vue+数据可视化的百草园化妆服务平台(程序+论文+讲解+安装+调试+售后)
  • 年后寒假总结及计划安排
  • linux安装Kafka以及windows安装Kafka和常见问题解决
  • 迷你世界脚本对象库接口:ObjectLib
  • Oracle CBD结构和Non-CBD结构区别
  • 微软官宣5 月 5 日关闭 Skype,赢者通吃法则依然有效
  • 解锁网络防御新思维:D3FEND 五大策略如何对抗 ATTCK
  • 如何快速的用pdfjs建立一个网页可以在线阅读你的PDF文件
  • 加密算法学习与SpringBoot实践
  • Java 多态:代码中的通用设计模式
  • 第七节:基于Winform框架的串口助手小项目---协议解析《C#编程》
  • 【数据结构初阶】---时间复杂度和空间复杂度了解及几道相关OJ题
  • Ubuntu20.04 在离线机器上安装 NVIDIA Container Toolkit
  • 【我的 PWN 学习手札】House of Emma
  • Python:简单的爬虫程序,从web页面爬取图片与标题并保存MySQL