当前位置: 首页 > article >正文

随着Batch size增加,最佳learning rate如何选择?

最近读到《Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling》这篇论文,里面通过实验和理论证明了learning rate和batch size之间的关系,觉得很有意思,就简答写个blog记录下。

1. 简介

  • 在影响模型训练效果的所有参数中,batch size与learning rate是最为重要的。随着训练数据的增加,为了加快训练效率,batch size变的越来越大,那么如果选择最佳的learning rate就成为了难题。
  • 在SGD优化器中,一些研究表明,learning rate与batch size之间有平方根[1],线性关系[2]等。在比较大的batch size下,一些理论和实验表明:
    在这里插入图片描述
  • 在类Adam优化器中,batch size与learning rate之间并没有上面的关系。一开始learning rate随着batch size增大而增大,随后达到一个点后,会随着batch size增加而降低,同时,随着训练不断进行, B n o i s e B_{noise} Bnoise会不断后移。 具体关系可以参考下面的图:
    在这里插入图片描述
    根据empirical model[3],batch size与optimal learning rate之间有下述的关系:
    在这里插入图片描述
    这里的 B n o i s e B_{noise} Bnoise表示数据有效性与训练速度之间的trade-off,当Batch size等于 B n o i s e B_{noise} Bnoise时,optimal learning rate达到局部最大。对于SGD,当B<< B n o i s e B_{noise} Bnoise时,learning rate与batch size呈现线性关系。
    在这里插入图片描述
    对于Adam优化器,则呈现平方根关系。
    在这里插入图片描述

2. 理论

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3. 实验

在这里插入图片描述

  • 记录在每个batch size下,不同的实际处理的step数S与处理的训练样本数E在达到具体loss的时候对应的optimal learning rate,可以得到:
    在这里插入图片描述
    然后,用不同batch size下的optimal learning rate搜索结果估计出类Adam优化器的最大的optimal learning rate:
    在这里插入图片描述
    类SGD的最大optimal learning rate为:
    在这里插入图片描述

  • 训练参数如上表所示,总共训练了三个模型,分别为5层的CNN,ResNet18及DiltilGPT2.
    在这里插入图片描述
    从上图中可以看出,在达到 B n o i s e B_{noise} Bnoise前,optimal learning rate随着batch size增加而增加,当达到 B n o i s e B_{noise} Bnoise后,optimal learning rate会逐渐降低。同时,随着训练的不断进行, B n o i s e B_{noise} Bnoise会不断的向右偏移。
    在这里插入图片描述

4. 总结

  • 一开始learning rate随着batch size增大而增大,随后达到一个点后,会随着batch size增加而降低,同时,随着训练不断进行, B n o i s e B_{noise} Bnoise会不断后移。
  • 为了加速训练进程,可以设计自适应的learning rate和batch size。

5. 参考文献

  • [1] One weird trick for parallelizing convolutional neural networks
  • [2] Learning rates as a function of batch size: A random matrix theory approach to neural network training
  • [3] An empirical model of large-batch training

http://www.kler.cn/news/305842.html

相关文章:

  • 一个关于Excel的段子
  • 2860. 让所有学生保持开心的分组方法数
  • 模板替换引擎(支持富文本动态表格)
  • 物体识别之微特征识别任务综述
  • Linux文件系统(下)
  • 红黑树前语
  • 存储课程学习笔记5_iouring的练习(io_uring,rust_echo_bench,fio)
  • Unity2D游戏入门
  • [项目][WebServer][解析错误处理]详细讲解
  • JVM字节码
  • MySQL通过备份恢复的方式搭建主从/重建从库
  • 删除Cookie原理
  • 【Unity】在Unity 3D中使用Spine开发2D动画
  • Java | Leetcode Java题解之第404题左叶子之和
  • SQL中的外键约束
  • 获取某宝拍立淘API接口:深度学习图像实现匹配和检索
  • WebGL系列教程八(GLSL着色器基础语法)
  • Android 13 固定systemUI的状态栏为黑底白字,不能被系统应用或者三方应用修改
  • USB组合设备——鼠标+键盘(两个接口实现)
  • OPENAIGC开发者大赛企业组AI黑马奖 | AIGC数智传媒解决方案
  • iPhone 16即将推出的5项苹果智能功能
  • Computer Vision的学习路线
  • 坐牢第三十八天(Qt)
  • Android SDK和NDK的区别
  • SSH软链接后门从入门到应急响应
  • Redis的常见问题
  • 鸿蒙交互事件开发07——手势竞争问题
  • 速通GPT:《Improving Language Understanding by Generative Pre-Training》全文解读
  • 前端开发的观察者模式
  • K8s 之Pod的定义及详细资源调用案例