当前位置: 首页 > article >正文

Batch Normalization学习笔记

文章目录

  • 一、为何引入 Batch Normalization
  • 二、具体步骤
    • 1、训练阶段
    • 2、预测阶段
  • 三、关键代码实现
  • 四、补充
  • 五、参考文献

一、为何引入 Batch Normalization

  现在主流的卷积神经网络几乎都使用了批量归一化(Batch Normalization,BN)1,它是一种逐层归一化方法,可以对神经网络中任意的中间层进行归一化操作。我们可以从不同角度来理解为什么要引入 Batch Normalization:

① 训练时的误差表面(error surface) 可能会十分崎岖,使得做优化时容易陷入局部最优值或鞍点等。通常我们会使用各种算法如Adam等进行优化,那么能不能直接改误差表面的地貌,“把山铲平”,让它变得比较好训练呢?Batch Normalization 就是其中一个“把山铲平”的想法2。另外一个好处是,误差表面变得没那么崎岖后,我们在训练时便可以增大学习率,使得网络更快收敛。

② 对于典型的多层感知机或卷积神经网络,在训练时中间层中的变量可能具有更广的变化范围。也就是说,随着训练时间的推移,每一层的模型参数分布范围变化莫测(比如一个深层网络,反向传播更新参数时,顶层与最底层数据范围差异会比较大,因为最底层相当于通过链式法则乘了一堆偏导数,导致数据范围非常大或小):

在这里插入图片描述

变量分布中的不规则的偏移可能会阻碍网络的收敛,因此为了使各层拥有适当的数据范围,通过 Batch Normalization“强制性”地调整数据分布使其约束到更小的范围(标准正态分布),这样便可以使得训练更加稳定,且对于初始值的设置没那么敏感。调整之后示意图如下:

在这里插入图片描述

③ 深层的网络很复杂,容易过拟合。而 Batch Normalization可以作为一种隐形的正则化方法,减轻过拟合(因此有时候使用BN后,dropout显得没那么必要使用)。由于Batch Normalization是基于一个 mini batch的,因此在训练时,神经网络对一个样本的预测不仅和该样本自身相关,也和同一批次中的其他样本相关,这种选取批次的随机性,使得神经网络不会“过拟合”到某个特定样本,从而提高网络的泛化能力。

总而言之,Batch Normalization 的优点如下3

  • 不那么依赖初始值(对于初始值不用那么神经质)。
  • 可以使学习快速进行(可以增大学习率)。
  • 抑制过拟合(降低Dropout等的必要性)。


二、具体步骤

batch normalization本质是对不同样本的同一特征做标准化。

1、训练阶段

  在训练时,Batch Normalization会逐步对每个mini-batch进行归一化。具体步骤如下:

设一个mini-batch中有 m m m 个输入数据,记为集合 B = { x 1 , x 2 , ⋯   , x m } B=\{x_1,x_2,\cdots,x_m\} B={x1,x2,,xm},对该集合求均值 μ B \mu_B μB 和方差 σ B 2 \sigma_B^2 σB2
μ B ← 1 m ∑ i = 1 m x i \begin{aligned}\mu_B\leftarrow\frac{1}{m}\sum_{i=1}^mx_i\end{aligned} μBm1i=1mxi

σ B 2 ← 1 m ∑ i = 1 m ( x i − μ B ) 2 \begin{aligned}\sigma_B^2\leftarrow\frac{1}{m}\sum_{i=1}^m(x_i-\mu_B)^2\end{aligned} σB2m1i=1m(xiμB)2
接下来利用求得的均值和方差对输入数据进行归一化:
x ^ i ← x i − μ B σ B 2 + ε \hat{x}_i\leftarrow\frac{x_i-\mu_B}{\sqrt{\sigma_B^2+\varepsilon}} x^iσB2+ε xiμB
其中 ε \varepsilon ε 是一个微小值(如 10 e − 7 10e^{-7} 10e7 等),以防止出现除以0的情况。

于是便可以将输入数据转换为均值为0,方差为1的数据 { x ^ 1 , x ^ 2 , ⋯   , x ^ m } \left\{\hat{x}_1,\hat{x}_2,\cdots,\hat{x}_m\right\} {x^1,x^2,,x^m} 了。

  为了使得归一化不对网络的表示能力造成负面影响,再通过一个附加的缩放和平移变换改变新数据的取值区间(虽然归一化加快了训练速度和稳定性,但它改变了数据的原始分布。对于某些任务来说,直接使用归一化的数据可能会限制模型的表达能力,因此引入可以学习的超参数 γ \gamma γ β \beta β ,使得模型可以灵活地调整归一化后的数据分布,恢复其自由度):
y i ← γ x ^ i + β y_i\leftarrow\gamma\hat{x}_i+\beta yiγx^i+β

最后把上述所有处理插入到激活函数的前面即可(整个过程相当于一个BatchNorm层),示意图如下:

在这里插入图片描述


示意图二(其中 W W W 是全连接层, L ^ \widehat{\mathcal{L}} L 是损失函数)4

在这里插入图片描述


2、预测阶段

  在训练过程中,我们无法得知整个数据集来估计平均值和方差,所以只能根据每个小批次(mini-batch)的平均值和方差不断训练模型。 而在预测模式下,一般使用整个预测数据集的均值和方差(因为这时候已经经过完整的训练了,因此可以得知全局信息)。为了节省存储资源,实际中大多采用**移动平均(moving average)**的方式来计算全局的均值和方差。移动平均的计算过程如下式所示:
μ t o t a l = λ ∗ μ t o t a l + ( 1 − λ ) ∗ μ B σ t o t a l 2 = λ ∗ σ t o t a l 2 + ( 1 − λ ) ∗ σ B 2 \begin{aligned}\mu_{total}&=\lambda*\mu_{total}+(1-\lambda)*\mu_{\mathcal{B}}\\\sigma_{total}^2&=\lambda*\sigma_{total}^2+(1-\lambda)*\sigma_{\mathcal{B}}^2\end{aligned} μtotalσtotal2=λμtotal+(1λ)μB=λσtotal2+(1λ)σB2



三、关键代码实现

以动手学深度学习第二版5的代码为例(Pytorch):

import torch
from torch import nn
from d2l import torch as d2l


def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
    if not torch.is_grad_enabled():
        # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
            # 这里我们需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # 训练模式下,用当前的均值和方差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # 缩放和移位
    return Y, moving_mean.data, moving_var.data

解释几个可能的疑惑点:

为什么分为全连接层和卷积层两种情况?

  全连接层和卷积层比较典型,它们的批量规范化实现略有不同:当作用在全连接层时,实际上是作用在特征维;当作用在卷积层上时,实际上是作用在通道维(将通道维当成是卷积层的特征维)。

为什么作用在通道维?因为每个通道都有自己的拉伸参数偏移参数,并且都是标量。例如下图6所示:

在这里插入图片描述

上图各颜色通道中的像素值通常具有不同的分布和范围,这种不一致性可能会导致训练出错或网络不收敛等问题。因此需要通过Normalize操作,将每个通道的像素值标准化为均值为0、标准差为1的分布,使得所有通道的像素值范围和分布一致。(因此假如扩展到n维张量,你也只需对通道维求均值即可)



为什么全连接层设置 dim=0,而卷积层设置 dim=(0,2,3)

  全连接层是二维的,即(batch_size, feature) ,计算全连接层时,计算的是特征维的均值和方差,而每个行代表一个样本,每列代表一个特征。

dim=0dim=1 的含义

  • dim=0 表示沿着 “行” 的方向进行操作(也就是跨样本的操作),即对每个特征维的所有样本值进行聚合计算,比如求均值、方差等。
  • dim=1 表示沿着 “列” 的方向进行操作(也就是跨特征的操作),即对每个样本的所有特征值聚合计算。

下图重量/甜度/颜色评分为苹果的特征维,我们来计算特征维的均值:

苹果编号重量(克)甜度(°Bx)颜色评分(1 - 10)
苹果 1200127
苹果 2180106
苹果 3220148

dim=0代表行,dim=1代表列,既然我们要求特征维的均值,那么需要让 dim=0 ,也就是沿着行的方向“拍扁”。上图沿着行方向“拍扁”后得到的特征维的均值如下:

计算结果
重量:(200 + 180 + 220) / 3 = 200
甜度:(12 + 10 + 14) / 3 = 12
颜色评分: (7 + 6 + 8) / 3 = 7

那么卷积层 (batch_size, channels, height, width)设dim=(0,2,3)也很好理解了,我们需要得到通道维的均值,那么就得把其它几个维都“拍扁”。



为什么全连接层无需设置keepdim=True 而卷积层需设置keepdim=True

  由于pytorch的广播机制,只会从左边补1,换个说法即只会补齐最外层的维度,因此前者无需设置而后者需设置keepdim=True来保证广播机制的正常启动。

有点抽象,举例子说明:

# 构造一个形状为 (2, 3, 4, 5, 6) 的五维张量
A = torch.randn(2, 3, 4, 5, 6)

# 打印张量 A 的形状
print("张量 A 的形状:", A.shape)

# 构造一个形状为 (3, 4, 5, 6) 的四维张量
B = torch.randn(3, 4, 5, 6)
print("张量 B 的形状:", B.shape)

try:
    # 尝试执行 A + B
    A + B
    print("可以成功输出")
except Exception as e:
    # 如果发生异常,打印失败信息
    print("失败输出:", e)

输出结果为:

张量 A 的形状: torch.Size([2, 3, 4, 5, 6])
张量 B 的形状: torch.Size([3, 4, 5, 6])
可以成功输出

因为广播机制会让B的维度补齐成(1,3,4,5,6),也就是最左边补“1”,于是就可以执行 A+B操作了。

而如下情况,即仅仅稍微改变一下B的形状:

# 构造一个形状为 (2, 3, 4, 5, 6) 的五维张量
A = torch.randn(2, 3, 4, 5, 6)

# 打印张量 A 的形状
print("张量 A 的形状:", A.shape)

# 构造一个形状为 (2, 3, 4, 5) 的四维张量
B = torch.randn(2, 3, 4, 5)
print("张量 B 的形状:", B.shape)

try:
    # 尝试执行 A + B
    A + B
    print("可以成功输出")
except Exception as e:
    # 如果发生异常,打印失败信息
    print("失败输出:", e)

输出结果为:

张量 A 的形状: torch.Size([2, 3, 4, 5, 6])
张量 B 的形状: torch.Size([2, 3, 4, 5])
失败输出: The size of tensor a (6) must match the size of tensor b (5) at non-singleton dimension 4

因为广播机制只会往最左边补“1”,而这里B补“1”后形状变成(1,2,3,4,5),依旧和张量A的形状不一致,所以不能做相加操作。

回到 Batch-Normalization 的代码:

 if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
            # 这里我们需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)

我们知道, dim等于哪个维,就是将那个维进行“拍扁”。

对于全连接层(batch_size, feature),设置 dim=0时,相当于将第 0 维“拍扁”,拍扁了相当于那个维直接“消失”了,此时meanvar的形状为(feature)。于是直接可以通过广播机制,在最左边补“1”,变成(1, feature),便可以和变量 X 一起计算了【X的形状(batch_size, feature)】。

而卷积层 (batch_size, channels, height, width)dim=(0,2,3)时,相当于将第 0,2,3 维“拍扁”,此时meanvar的形状为(channels),而 X 的形状是 (batch_size, channels, height, width),你得将meanvar的形状扩展到和 X 一致才可以进行计算,而广播机制只能往最左边补“1”,因此(channels)无法扩展成和X一致的形状,顶多扩展成(1, channels),所以无法和 X 进行计算,程序报错。

因此需要对卷积层使用 keepdim=True这个参数,这样meanvar的形状就可以扩展成 (1, channels, 1, 1),与X一致,才能进行接下来的计算。



if not torch.is_grad_enabled() 为什么可以判断是训练还是预测模式?

  反向传播时会涉及梯度的计算,而只有训练时才会进行反向传播,因此可以通过是否进行梯度的计算来判断训练模式还是预测模式。



四、补充

  原论文中提出Batch-Normalization的优点是减少了内部协变量转移(internal covariate shift,简单来说就是变量值的分布在训练过程中会发生变化,但是这种解释在后续论文被证实比较不严谨,发现它并没有减少内部协变量的转移 [Santurkar et al.,2018]。



五、参考文献


  1. Ioffe S. Batch normalization: Accelerating deep network training by reducing internal covariate shift[J]. arXiv preprint arXiv:1502.03167, 2015. ↩︎

  2. 王琦, 杨毅远, 江季, 深度学习详解, 北京:人民邮电出版社, 2024 ↩︎

  3. (日)斋藤康毅著, 陆宇杰译, 深度学习入门基于Python的理论与实现, 北京:人民邮电出版社, 2018.07 ↩︎

  4. Santurkar S, Tsipras D, Ilyas A, et al. How does batch normalization help optimization?[J]. Advances in neural information processing systems, 2018, 31. ↩︎

  5. 阿斯顿·张(Aston Zhang), 李沐(Mu Li), [美] 扎卡里·C. 立顿(Zachary C. Lipton), 等. 动手学深度学习(PyTorch版)[M]. 第二版. 人民邮电出版社, 2023-2. ↩︎

  6. 【Batch Normalization】 https://www.bilibili.com/video/BV11s4y1c7pg/?share_source=copy_web&vd_source=199a3f4e3a9db6061e1523e94505165a ↩︎


http://www.kler.cn/a/519495.html

相关文章:

  • 一次端口监听正常,tcpdump无法监听到指定端口报文问题分析
  • cuda reductionreduce
  • 【科研建模】Pycaret自动机器学习框架使用流程及多分类项目实战案例详解
  • 【QT】-explicit关键字
  • 利用ML.NET精准提取人名
  • 【工程篇】01:GPU可用测试代码
  • 77,【1】.[CISCN2019 华东南赛区]Web4
  • Java数据结构 (链表反转(LinkedList----Leetcode206))
  • Qt网络通信(TCP/UDP)
  • 运维实战---多种方式在Linux中部署并初始化MySQL
  • DeepSeek_R1论文翻译稿
  • RV1126画面质量五:Profile和编码等级讲解
  • 【北京大学 凸优化】Lec1 凸优化问题定义
  • Linux Futex学习笔记
  • 第 10 课 Python 内置函数
  • 在 Ubuntu22.04 上安装 Splunk
  • 2025年1月22日(什么是扫频)
  • vue router路由复用及刷新问题研究
  • 从 VJ 拥塞控制到 BBR:ACK 自时钟和 pacing
  • 《Kotlin核心编程》上篇
  • 【动态规划】杨表
  • YOLOv11改进,YOLOv11检测头融合DSConv(动态蛇形卷积),并添加小目标检测层(四头检测),适合目标检测、分割等任务
  • SQL注入漏洞之SQL注入基础知识点 如何检测是否含有sql注入漏洞
  • 【leetcode100】二叉树的层序遍历
  • Elasticsearch中的度量聚合:深度解析与实战应用
  • mock可视化生成前端代码