当前位置: 首页 > article >正文

深度学习中batch_size

Batch size调整和epoch/iteration的关系

训练数据集总共有1000个样本。若batch_size=10,那么训练完全体样本集需要100次迭代,1次epoch。
训练样本10000条,batchsize设置为20,将所有的训练样本在同一个模型中训练5遍,则epoch=5,batchsize=20, iteration=10000/20=500(即迭代次数表示有多个个batch会过模型)

分布式训练时的batch_size设置:需要将batch_size/num_process。divide the batch size by the number of replicas in order to maintain the overall batch size of 需要的值.[PyTorch:模型训练-分布式训练]

Batch size设置经验

1 一定条件下,batchsize越大训练效果越好。但是batchsize越大,内存gpu消耗越大。梯度累加则实现了batchsize的变相扩大,如果accumulation_steps为8,则batchsize ‘变相’ 扩大了8倍,是解决显存受限的一个不错的trick。

[使用梯度累加的batch训练函数]

2 经验法则是,如果mini-batch size加倍,那么学习率就加倍。
在这里插入图片描述

[神经网络中 warmup 策略为什么有效;有什么理论解释么? - 知乎]

在前面“如果mini-batch size加倍,那么学习率就加倍"中,我们的假设在什么时候可能不成立呢?两种情况:

1)在训练的开始阶段,模型权重迅速改变
2)mini-batch size较小,样本方差较大

第一种情况很好理解,可以认为,刚开始模型对数据的“分布”理解为零,或者是说“均匀分布”(当然这取决于你的初始化);在第一轮训练的时候,每个数据点对模型来说都是新的,模型会很快地进行数据分布修正,如果这时候学习率就很大,极有可能导致开始的时候就对该数据“过拟合”,后面要通过多轮训练才能拉回来,浪费时间。当训练了一段时间(比如两轮、三轮)后,模型已经对每个数据点看过几遍了,或者说对当前的batch而言有了一些正确的先验,较大的学习率就不那么容易会使模型学偏,所以可以适当调大学习率。这个过程就可以看做是warmup。那么为什么之后还要decay呢?当模型训到一定阶段后(比如十个epoch),模型的分布就已经比较固定了,或者说能学到的新东西就比较少了。如果还沿用较大的学习率,就会破坏这种稳定性,用我们通常的话说,就是已经接近loss的local optimal了,为了靠近这个point,我们就要慢慢来。

第二种情况其实和第一种情况是紧密联系的。在训练的过程中,如果有mini-batch内的数据分布方差特别大,这就会导致模型学习剧烈波动,使其学得的权重很不稳定,这在训练初期最为明显,最后期较为缓解(所以我们要对数据进行scale也是这个道理)。

说明,在上面两种情况下,我们并不能单纯地成倍增长lr η̂ =kη。要么改变学习率增长方法,要么设法解决上面两个问题。

[神经网络中 warmup 策略为什么有效;有什么理论解释么?]

所以就有了下面的warmup策略和学习率衰减方法:

学习率 warm-up 策略

训练神经网络时,在初始使用较大学习率而后期切换为较小学习率是一种广为使用的做法。而 warmup 策略则与上述 scheme 有些矛盾,warmup 需要在训练最初使用较小的学习率来启动,并很快切换到大学习率而后进行常见的 decay,那么最开始的这一步 warmup 为什么有效呢?

warmup_lr 的初始值是跟训练预料的大小成反比的,也就是说训练预料越大,那么warmup_lr 初值越小,随后增长到我们预设的超参 initial_learning_rate相同的量级,再接下来又通过 decay_rates 逐步下降。

这样做的原因前面已经说明了,还有什么好处?

1)这样可以使得学习率适应不同的训练集合size实验的时候经常需要先使用小的数据集训练验证模型,然后换大的数据集做生成环境模型训练。

2)即使不幸学习率设置得很大,那么也能通过warmup机制看到合适的学习率区间(即训练误差先降后升的关键位置附近),以便后续验证。

原文:https://blog.csdn.net/pipisorry/article/details/109192443


http://www.kler.cn/a/457260.html

相关文章:

  • BurstAttention:高效的分布式注意力计算框架
  • Java基础知识(五) -- 枚举、注解和异常
  • 30天开发操作系统 第 10 天 -- 叠加处理
  • 【Wi-Fi】802.11u、WPA、WPA2/WPA3-ENterprise、Hotspot 、IEEE802.11x的关系
  • rem em rpx px vw的区别
  • streamlit、shiny、gradio、fastapi四个web APP平台体验
  • MySQL并发问题区别-MVCC如何解决的
  • Linux 下 Mamba 环境安装踩坑问题汇总(重置版)
  • 【前端】Vue3 父传子 Dialog 显示问题:解决方案与最佳实践
  • 狼人杀.转载
  • 神经网络初学总结(一)
  • 国密算法SM3的GmSSL代码Android实现Demo
  • 【Leecode】Leecode刷题之路第93天之复原IP地址
  • 使用Python实现智能交通信号控制系统
  • 深度学习笔记(12)——深度学习概论
  • CDN如何抵御DDoS攻击
  • 如何在 Ubuntu 22.04 上使用 systemctl 管理 systemd 服务教程
  • Pytorch | 利用MIG针对CIFAR10上的ResNet分类器进行对抗攻击
  • python lambda函数用法
  • Android `android.graphics.drawable` 包深度解析:架构与设计模式
  • zentao ubuntu上安装
  • EMNLP'24 最佳论文解读 | 大语言模型的预训练数据检测:基于散度的校准方法
  • 探索鸿蒙的蓝牙A2DP与访问API:从学习到实现的开发之旅
  • 从零开始采用命令行创建uniapp vue3 ts springboot项目
  • 《PHP Switch》
  • DeepSeek-VL2部署指南