当前位置: 首页 > article >正文

批量归一化(Batch Normalization)

批量归一化(Batch Normalization) 是一种用于加速深度神经网络训练并提高模型稳定性的技术,通常简称为 BatchNorm。它通过在每一层网络的激活输出上应用归一化操作来减少内部协变量偏移(Internal Covariate Shift),即减小网络在训练过程中因参数变化导致的分布漂移。批量归一化可以使网络更快地收敛,并帮助模型在训练时更稳定。

批量归一化的工作原理:

对于每一层网络的输出(通常是一个 mini-batch),批量归一化会将其均值和方差调整到标准正态分布(均值为0,方差为1)。具体步骤如下:

  1. 计算均值和方差:对 mini-batch 中的每个样本,计算其均值和方差。
  2. 归一化:将每个样本减去均值,再除以标准差,从而使数据的分布在该 mini-batch 中变成均值为0,方差为1。
  3. 缩放和平移:引入两个可学习的参数,缩放系数(gamma)和偏移系数(beta),用于恢复模型的表示能力。这一步允许网络在需要时重新调整归一化后的分布,以便更好地适应任务需求。

批量归一化的优势:

  • 提高训练速度:通过减少内部协变量偏移,网络在训练过程中收敛得更快。
  • 稳定训练过程:批量归一化有助于减小梯度消失或爆炸的风险,使网络在较大的学习率下也能稳定训练。
  • 一定程度的正则化效果:批量归一化对每一批数据应用不同的归一化,因此引入了随机性,具有一定的正则化效果,从而减少过拟合的风险。

批量归一化在卷积神经网络(CNN)和循环神经网络(RNN)中应用广泛。通过在网络层的激活输出上进行批量归一化,深度学习模型能够在更少的训练时间内达到更好的效果。

其实怎么理解呢,如果每一层的输出都是一个不同的分布,那么训练的时候就很难收敛,我们需要将每次的输出都整理为相似的输出,有助于收敛和训练。

假设一个批量数据集中的样本为 ( x_1, x_2, \ldots, x_n ),其均值和方差分别为:

  • 均值:
    μ = 1 n ∑ i = 1 n x i \mu = \frac{1}{n} \sum_{i=1}^{n} x_i μ=n1i=1nxi
  • 方差:
    σ 2 = 1 n ∑ i = 1 n ( x i − μ ) 2 \sigma^2 = \frac{1}{n} \sum_{i=1}^{n} (x_i - \mu)^2 σ2=n1i=1n(xiμ)2
步骤 1:数据去均值

将每个样本减去均值 ( \mu ),得到新的样本 ( x_i’ = x_i - \mu ):

  • 新数据的均值为0,因为:
    均值 ( x i ′ ) = 1 n ∑ i = 1 n ( x i − μ ) = 1 n ( ∑ i = 1 n x i − n μ ) = 0 \text{均值}(x_i') = \frac{1}{n} \sum_{i=1}^{n} (x_i - \mu) = \frac{1}{n} \left(\sum_{i=1}^{n} x_i - n \mu\right) = 0 均值(xi)=n1i=1n(xiμ)=n1(i=1nxinμ)=0
步骤 2:数据除以标准差

再将去均值后的每个样本 ( x_i’ ) 除以标准差 ( \sigma ),得到 ( x_i’’ = \frac{x_i’}{\sigma} = \frac{x_i - \mu}{\sigma} ):

  • 新数据的方差变为1,因为:
    方差 ( x i ′ ′ ) = 1 n ∑ i = 1 n ( x i − μ σ ) 2 = 1 n ∑ i = 1 n ( x i − μ ) 2 σ 2 = σ 2 σ 2 = 1 \text{方差}(x_i'') = \frac{1}{n} \sum_{i=1}^{n} \left(\frac{x_i - \mu}{\sigma}\right)^2 = \frac{1}{n} \sum_{i=1}^{n} \frac{(x_i - \mu)^2}{\sigma^2} = \frac{\sigma^2}{\sigma^2} = 1 方差(xi′′)=n1i=1n(σxiμ)2=n1i=1nσ2(xiμ)2=σ2σ2=1

http://www.kler.cn/news/365147.html

相关文章:

  • Java 多线程(八)—— 锁策略,synchronized 的优化,JVM 与编译器的锁优化,ReentrantLock,CAS
  • 深度学习系列——RNN/LSTM/GRU,seq2seq/attention机制
  • php命令执行的一些执行函数----以ctfshow靶场为解题思路
  • 「C/C++」C++17 之 std::variant 安全的联合体(变体)
  • Ruby 从入门到精通:学习之旅与资源推荐
  • 使用docker-compose搭建redis7集群-3主3从
  • 混个1024勋章
  • [笔记] 关于CreateProcessWithLogonW函数创建进程
  • Linux系统
  • 鸿蒙到底是不是纯血?到底能不能走向世界?
  • 蓝桥杯题目理解
  • Python爬虫:urllib_ajax的get请求豆瓣电影前十页(08)
  • 【C++】用哈希桶模拟实现unordered_set和unordered_map
  • 网络安全中的日志审计:为何至关重要?
  • 35.第二阶段x86游戏实战2-C++遍历技能
  • CPRI与eCPRI的区别
  • 每天5分钟玩转C#/.NET之C#语言详细介绍
  • python-PyQt项目实战案例:制作一个视频播放器
  • 双十一送你一份购物攻略,绿联NAS DXP2800评测
  • 借老系统重构我给jpa写了个mybatis风格的查询模块
  • 【笔记】apt源设置为阿里云源
  • 19.面试算法-树的深度优先遍历(一)
  • Nginx15-Lua扩展模块
  • Zookeeper面试整理-Zookeeper集群管理
  • 简单走近ChatGPT
  • 信息安全工程师(55)网络安全漏洞概述