深度学习|表示学习|Mini-Batch Normalization 具体计算举例|23
如是我闻: Batch Normalization(BN)是一种在 mini-batch 级别计算均值和方差的归一化方法,它能够加速训练、稳定梯度,并减少对权重初始化的敏感性。
在 BN 过程中,我们不会使用整个数据集计算均值和方差,而是在每个 mini-batch 内单独计算,然后对数据进行标准化。
1. 假设数据集
我们假设有一个简单的数据集,其中:
- 20 个样本
- 每个样本有 3 个特征
- mini-batch 大小设为 5
整个数据集如下:
X = [ 1.0 2.0 3.0 2.0 3.0 4.0 3.0 4.0 5.0 4.0 5.0 6.0 5.0 6.0 7.0 6.0 7.0 8.0 7.0 8.0 9.0 8.0 9.0 10.0 9.0 10.0 11.0 10.0 11.0 12.0 11.0 12.0 13.0 . . . . . . . . . 20.0 21.0 22.0 ] X = \begin{bmatrix} 1.0 & 2.0 & 3.0 \\ 2.0 & 3.0 & 4.0 \\ 3.0 & 4.0 & 5.0 \\ 4.0 & 5.0 & 6.0 \\ 5.0 & 6.0 & 7.0 \\ \hline 6.0 & 7.0 & 8.0 \\ 7.0 & 8.0 & 9.0 \\ 8.0 & 9.0 & 10.0 \\ 9.0 & 10.0 & 11.0 \\ 10.0 & 11.0 & 12.0 \\ \hline 11.0 & 12.0 & 13.0 \\ ... & ... & ... \\ 20.0 & 21.0 & 22.0 \\ \end{bmatrix} X= 1.02.03.04.05.06.07.08.09.010.011.0...20.02.03.04.05.06.07.08.09.010.011.012.0...21.03.04.05.06.07.08.09.010.011.012.013.0...22.0
训练时,我们按 mini-batch 进行计算,每个 batch 处理 5 个样本。
2. Batch 1(样本 1-5)的 BN 计算
Step 1: 取出 mini-batch
第一个 mini-batch(样本 1-5):
X = [ 1.0 2.0 3.0 2.0 3.0 4.0 3.0 4.0 5.0 4.0 5.0 6.0 5.0 6.0 7.0 ] X = \begin{bmatrix} 1.0 & 2.0 & 3.0 \\ 2.0 & 3.0 & 4.0 \\ 3.0 & 4.0 & 5.0 \\ 4.0 & 5.0 & 6.0 \\ 5.0 & 6.0 & 7.0 \\ \end{bmatrix} X= 1.02.03.04.05.02.03.04.05.06.03.04.05.06.07.0
Step 2: 计算 batch 内每个特征的均值
μ j = 1 m ∑ i = 1 m x i , j \mu_j = \frac{1}{m} \sum_{i=1}^{m} x_{i,j} μj=m1i=1∑mxi,j
对于 3 个特征维度:
μ
1
=
1
+
2
+
3
+
4
+
5
5
=
3.0
\mu_1 = \frac{1+2+3+4+5}{5} = 3.0
μ1=51+2+3+4+5=3.0
μ
2
=
2
+
3
+
4
+
5
+
6
5
=
4.0
\mu_2 = \frac{2+3+4+5+6}{5} = 4.0
μ2=52+3+4+5+6=4.0
μ
3
=
3
+
4
+
5
+
6
+
7
5
=
5.0
\mu_3 = \frac{3+4+5+6+7}{5} = 5.0
μ3=53+4+5+6+7=5.0
即:
μ
=
[
3.0
,
4.0
,
5.0
]
\mu = [3.0, 4.0, 5.0]
μ=[3.0,4.0,5.0]
Step 3: 计算 batch 内方差
σ j 2 = 1 m ∑ i = 1 m ( x i , j − μ j ) 2 \sigma_j^2 = \frac{1}{m} \sum_{i=1}^{m} (x_{i,j} - \mu_j)^2 σj2=m1i=1∑m(xi,j−μj)2
计算得到:
σ
2
=
[
2.0
,
2.0
,
2.0
]
\sigma^2 = [2.0, 2.0, 2.0]
σ2=[2.0,2.0,2.0]
为了防止除零,我们加一个很小的数
ϵ
\epsilon
ϵ(通常是 1e-5):
σ
^
2
=
[
2.0
+
1
e
−
5
,
2.0
+
1
e
−
5
,
2.0
+
1
e
−
5
]
\hat{\sigma}^2 = [2.0 + 1e-5, 2.0 + 1e-5, 2.0 + 1e-5]
σ^2=[2.0+1e−5,2.0+1e−5,2.0+1e−5]
Step 4: 归一化
x ^ i , j = x i , j − μ j σ j 2 + ϵ \hat{x}_{i,j} = \frac{x_{i,j} - \mu_j}{\sqrt{\sigma_j^2 + \epsilon}} x^i,j=σj2+ϵxi,j−μj
计算所有样本的归一化值:
X ^ = [ − 1.414 − 1.414 − 1.414 − 0.707 − 0.707 − 0.707 0.0 0.0 0.0 0.707 0.707 0.707 1.414 1.414 1.414 ] \hat{X} =\begin{bmatrix} -1.414 & -1.414 & -1.414 \\ -0.707 & -0.707 & -0.707 \\ 0.0 & 0.0 & 0.0 \\ 0.707 & 0.707 & 0.707 \\ 1.414 & 1.414 & 1.414 \\ \end{bmatrix} X^= −1.414−0.7070.00.7071.414−1.414−0.7070.00.7071.414−1.414−0.7070.00.7071.414
Step 5: 重新缩放和平移
y i , j = γ j x ^ i , j + η j y_{i,j} = \gamma_j \hat{x}_{i,j} +\eta_j yi,j=γjx^i,j+ηj
其中:
- γ \gamma γ 和 η \eta η 是可学习参数,在训练过程中不断调整。
3. Batch 2(样本 6-10)的 BN 计算
当我们处理 下一个 mini-batch(样本 6-10) 时,重新计算均值和方差,而不会使用上一个 batch 的统计量。
取出 Batch 2:
X
=
[
6.0
7.0
8.0
7.0
8.0
9.0
8.0
9.0
10.0
9.0
10.0
11.0
10.0
11.0
12.0
]
X = \begin{bmatrix} 6.0 & 7.0 & 8.0 \\ 7.0 & 8.0 & 9.0 \\ 8.0 & 9.0 & 10.0 \\ 9.0 & 10.0 & 11.0 \\ 10.0 & 11.0 & 12.0 \\ \end{bmatrix}
X=
6.07.08.09.010.07.08.09.010.011.08.09.010.011.012.0
同样进行 BN 计算:
- 均值: μ = [ 8.0 , 9.0 , 10.0 ] \mu = [8.0, 9.0, 10.0] μ=[8.0,9.0,10.0]
- 方差: σ 2 = [ 2.0 , 2.0 , 2.0 ] \sigma^2 = [2.0, 2.0, 2.0] σ2=[2.0,2.0,2.0]
- 归一化:使均值变 0,方差变 1
- 重新缩放和平移:用 γ \gamma γ 和 eta$ 调整
这与 Batch 1 是完全独立的计算过程。
4. 训练 vs 预测
训练阶段
- 每个 mini-batch 计算自己的均值和方差
- 使用当前 batch 的均值和方差进行归一化
推理(测试)阶段
在测试时,我们不再依赖 mini-batch,而是用整个训练过程中累积的全局均值和方差,通常用 指数移动平均 计算:
μ
g
l
o
b
a
l
=
0.9
⋅
μ
g
l
o
b
a
l
+
0.1
⋅
μ
b
a
t
c
h
\mu_{global} = 0.9 \cdot \mu_{global} + 0.1 \cdot \mu_{batch}
μglobal=0.9⋅μglobal+0.1⋅μbatch
σ
g
l
o
b
a
l
2
=
0.9
⋅
σ
g
l
o
b
a
l
2
+
0.1
⋅
σ
b
a
t
c
h
2
\sigma^2_{global} = 0.9 \cdot \sigma^2_{global} + 0.1 \cdot \sigma^2_{batch}
σglobal2=0.9⋅σglobal2+0.1⋅σbatch2
这样,即使 batch size 变为 1,BN 也能正确工作。
5. 总的来说
- Batch Normalization 在 mini-batch 级别计算均值和方差,不会跨 batch 计算。
- 每个 batch 计算自己的均值和方差,互不影响。
- 训练时用 mini-batch 统计量,推理时用全局统计量(移动平均计算)。
- BN 主要作用:
- 让每层的输入分布稳定(均值 0,方差 1)。
- 让梯度更稳定,加速训练。
- 使网络更鲁棒,不依赖权重初始化。
以上