当前位置: 首页 > article >正文

深度学习|表示学习|Mini-Batch Normalization 具体计算举例|23

如是我闻: Batch Normalization(BN)是一种在 mini-batch 级别计算均值和方差的归一化方法,它能够加速训练、稳定梯度,并减少对权重初始化的敏感性。

在 BN 过程中,我们不会使用整个数据集计算均值和方差,而是在每个 mini-batch 内单独计算,然后对数据进行标准化。


1. 假设数据集

我们假设有一个简单的数据集,其中:

  • 20 个样本
  • 每个样本有 3 个特征
  • mini-batch 大小设为 5

整个数据集如下:

X = [ 1.0 2.0 3.0 2.0 3.0 4.0 3.0 4.0 5.0 4.0 5.0 6.0 5.0 6.0 7.0 6.0 7.0 8.0 7.0 8.0 9.0 8.0 9.0 10.0 9.0 10.0 11.0 10.0 11.0 12.0 11.0 12.0 13.0 . . . . . . . . . 20.0 21.0 22.0 ] X = \begin{bmatrix} 1.0 & 2.0 & 3.0 \\ 2.0 & 3.0 & 4.0 \\ 3.0 & 4.0 & 5.0 \\ 4.0 & 5.0 & 6.0 \\ 5.0 & 6.0 & 7.0 \\ \hline 6.0 & 7.0 & 8.0 \\ 7.0 & 8.0 & 9.0 \\ 8.0 & 9.0 & 10.0 \\ 9.0 & 10.0 & 11.0 \\ 10.0 & 11.0 & 12.0 \\ \hline 11.0 & 12.0 & 13.0 \\ ... & ... & ... \\ 20.0 & 21.0 & 22.0 \\ \end{bmatrix} X= 1.02.03.04.05.06.07.08.09.010.011.0...20.02.03.04.05.06.07.08.09.010.011.012.0...21.03.04.05.06.07.08.09.010.011.012.013.0...22.0

训练时,我们按 mini-batch 进行计算,每个 batch 处理 5 个样本。


2. Batch 1(样本 1-5)的 BN 计算

Step 1: 取出 mini-batch

第一个 mini-batch(样本 1-5):

X = [ 1.0 2.0 3.0 2.0 3.0 4.0 3.0 4.0 5.0 4.0 5.0 6.0 5.0 6.0 7.0 ] X = \begin{bmatrix} 1.0 & 2.0 & 3.0 \\ 2.0 & 3.0 & 4.0 \\ 3.0 & 4.0 & 5.0 \\ 4.0 & 5.0 & 6.0 \\ 5.0 & 6.0 & 7.0 \\ \end{bmatrix} X= 1.02.03.04.05.02.03.04.05.06.03.04.05.06.07.0

Step 2: 计算 batch 内每个特征的均值

μ j = 1 m ∑ i = 1 m x i , j \mu_j = \frac{1}{m} \sum_{i=1}^{m} x_{i,j} μj=m1i=1mxi,j

对于 3 个特征维度:
μ 1 = 1 + 2 + 3 + 4 + 5 5 = 3.0 \mu_1 = \frac{1+2+3+4+5}{5} = 3.0 μ1=51+2+3+4+5=3.0
μ 2 = 2 + 3 + 4 + 5 + 6 5 = 4.0 \mu_2 = \frac{2+3+4+5+6}{5} = 4.0 μ2=52+3+4+5+6=4.0
μ 3 = 3 + 4 + 5 + 6 + 7 5 = 5.0 \mu_3 = \frac{3+4+5+6+7}{5} = 5.0 μ3=53+4+5+6+7=5.0

即:
μ = [ 3.0 , 4.0 , 5.0 ] \mu = [3.0, 4.0, 5.0] μ=[3.0,4.0,5.0]

Step 3: 计算 batch 内方差

σ j 2 = 1 m ∑ i = 1 m ( x i , j − μ j ) 2 \sigma_j^2 = \frac{1}{m} \sum_{i=1}^{m} (x_{i,j} - \mu_j)^2 σj2=m1i=1m(xi,jμj)2

计算得到:
σ 2 = [ 2.0 , 2.0 , 2.0 ] \sigma^2 = [2.0, 2.0, 2.0] σ2=[2.0,2.0,2.0]

为了防止除零,我们加一个很小的数 ϵ \epsilon ϵ(通常是 1e-5):
σ ^ 2 = [ 2.0 + 1 e − 5 , 2.0 + 1 e − 5 , 2.0 + 1 e − 5 ] \hat{\sigma}^2 = [2.0 + 1e-5, 2.0 + 1e-5, 2.0 + 1e-5] σ^2=[2.0+1e5,2.0+1e5,2.0+1e5]

Step 4: 归一化

x ^ i , j = x i , j − μ j σ j 2 + ϵ \hat{x}_{i,j} = \frac{x_{i,j} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}} x^i,j=σj2+ϵ xi,jμj

计算所有样本的归一化值:

X ^ = [ − 1.414 − 1.414 − 1.414 − 0.707 − 0.707 − 0.707 0.0 0.0 0.0 0.707 0.707 0.707 1.414 1.414 1.414 ] \hat{X} =\begin{bmatrix} -1.414 & -1.414 & -1.414 \\ -0.707 & -0.707 & -0.707 \\ 0.0 & 0.0 & 0.0 \\ 0.707 & 0.707 & 0.707 \\ 1.414 & 1.414 & 1.414 \\ \end{bmatrix} X^= 1.4140.7070.00.7071.4141.4140.7070.00.7071.4141.4140.7070.00.7071.414

Step 5: 重新缩放和平移

y i , j = γ j x ^ i , j + η j y_{i,j} = \gamma_j \hat{x}_{i,j} +\eta_j yi,j=γjx^i,j+ηj

其中:

  • γ \gamma γ η \eta η 是可学习参数,在训练过程中不断调整。

3. Batch 2(样本 6-10)的 BN 计算

当我们处理 下一个 mini-batch(样本 6-10) 时,重新计算均值和方差,而不会使用上一个 batch 的统计量。

取出 Batch 2
X = [ 6.0 7.0 8.0 7.0 8.0 9.0 8.0 9.0 10.0 9.0 10.0 11.0 10.0 11.0 12.0 ] X = \begin{bmatrix} 6.0 & 7.0 & 8.0 \\ 7.0 & 8.0 & 9.0 \\ 8.0 & 9.0 & 10.0 \\ 9.0 & 10.0 & 11.0 \\ 10.0 & 11.0 & 12.0 \\ \end{bmatrix} X= 6.07.08.09.010.07.08.09.010.011.08.09.010.011.012.0

同样进行 BN 计算:

  • 均值 μ = [ 8.0 , 9.0 , 10.0 ] \mu = [8.0, 9.0, 10.0] μ=[8.0,9.0,10.0]
  • 方差 σ 2 = [ 2.0 , 2.0 , 2.0 ] \sigma^2 = [2.0, 2.0, 2.0] σ2=[2.0,2.0,2.0]
  • 归一化:使均值变 0,方差变 1
  • 重新缩放和平移:用 γ \gamma γ 和 eta$ 调整

这与 Batch 1 是完全独立的计算过程


4. 训练 vs 预测

训练阶段

  • 每个 mini-batch 计算自己的均值和方差
  • 使用当前 batch 的均值和方差进行归一化

推理(测试)阶段

在测试时,我们不再依赖 mini-batch,而是用整个训练过程中累积的全局均值和方差,通常用 指数移动平均 计算:
μ g l o b a l = 0.9 ⋅ μ g l o b a l + 0.1 ⋅ μ b a t c h \mu_{global} = 0.9 \cdot \mu_{global} + 0.1 \cdot \mu_{batch} μglobal=0.9μglobal+0.1μbatch
σ g l o b a l 2 = 0.9 ⋅ σ g l o b a l 2 + 0.1 ⋅ σ b a t c h 2 \sigma^2_{global} = 0.9 \cdot \sigma^2_{global} + 0.1 \cdot \sigma^2_{batch} σglobal2=0.9σglobal2+0.1σbatch2

这样,即使 batch size 变为 1,BN 也能正确工作。


5. 总的来说

  • Batch Normalization 在 mini-batch 级别计算均值和方差,不会跨 batch 计算。
  • 每个 batch 计算自己的均值和方差,互不影响。
  • 训练时用 mini-batch 统计量,推理时用全局统计量(移动平均计算)。
  • BN 主要作用:
    1. 让每层的输入分布稳定(均值 0,方差 1)。
    2. 让梯度更稳定,加速训练。
    3. 使网络更鲁棒,不依赖权重初始化。

以上


http://www.kler.cn/a/540813.html

相关文章:

  • Mac(m1)本地部署deepseek-R1模型
  • 防御综合实验
  • SQL自学,mysql从入门到精通 --- 第 14天,主键、外键的使用
  • ubuntu文件同步
  • 【redis】数据类型之list
  • postgresql 游标(cursor)的使用
  • Intellij IDEA调整栈内存空间大小详细教程,添加参数-Xss....
  • 【推荐】爽,在 IDE 中做 LeetCode 题目的插件
  • 基于 FFmpeg 和 OpenGLES 的 iOS 视频预览和录制技术方案设计
  • Spring容器初始化扩展点:ApplicationContextInitializer
  • MVVM设计模式
  • 大模型基础面试问题汇总
  • 1.2 环境搭建
  • 「vue3-element-admin」告别 vite-plugin-svg-icons!用 @unocss/preset-icons 加载本地 SVG 图标
  • 2.1 Mockito核心API详解
  • PriorityQueue优先级队列的使用和Top-k问题
  • 小白零基础学习深度学习之张量
  • 【C++语言】类和对象(下)
  • 备战蓝桥杯:二分算法详解以及模板题
  • Redis持久化机制详解
  • Proxy vs DefineProperty
  • 车载工具报错分析:CANoe、CANalyzer问题:Stuff Error
  • Java 大视界 -- Java 大数据在智能家居中的应用与场景构建(79)
  • Vue:Table合并行于列
  • 用Go实现 SSE 实时推送消息(消息通知)——思悟项目技术4
  • 绘制中国平安股价的交互式 K 线图