PyTorch Instance Normalization介绍
Instance Normalization(实例归一化) 是一种标准化技术,与 Batch Normalization 类似,但它对每个样本独立地对每个通道进行归一化,而不依赖于小批量数据的统计信息。这使得它非常适合小批量训练任务以及图像生成任务(如风格迁移)。
Instance Normalization 的原理
对每个样本 xx 的每个通道 cc 独立进行标准化。
对于输入数据 (以二维输入为例):
1. 计算每个通道的均值和方差:
-
- n:样本索引。
- c:通道索引。
- H,W:输入的高度和宽度。
2. 归一化:
-
- ϵ 是一个小值,用于防止除零。
3. 缩放和平移:
- γc 和 βc 是可学习参数,用于恢复表达能力。
Instance Normalization 的特点
- 独立于批量大小:每个样本独立进行归一化,解决了小批量训练中均值和方差不稳定的问题。
- 适用于风格迁移任务:在风格迁移中,Instance Normalization 能更好地捕捉图像的风格特征。
- 不适合特征间强相关任务:Instance Normalization破坏了样本之间的特征相关性,因此不适用于依赖特征间关系的任务(如分类)。
PyTorch 中的 InstanceNorm 类
PyTorch 提供了以下三种适用于不同输入维度的 Instance Normalization 类:
torch.nn.InstanceNorm1d
:适用于一维数据(如序列或嵌入向量)。torch.nn.InstanceNorm2d
:适用于二维数据(如图像)。torch.nn.InstanceNorm3d
:适用于三维数据(如视频或体数据)。
1. torch.nn.InstanceNorm1d
参数:
num_features
:输入的通道数。eps
: