颠覆大模型归一化!Meta | 提出动态Tanh:DyT,无归一化的 Transformer 性能更强
源自: AINLPer(每日干货分享!!)
编辑: ShuYini
校稿: ShuYini
时间: 2025-3-16
更多:>>>>大模型/AIGC、学术前沿的知识分享!
引言
归一化层在现代神经网络中无处不在,常看到的就是Post-Normalization、Pre-Normalization,其中,**前向归一化(Pre-Normalization)**是在操作(如线性变换、自注意力等)之前,先对输入数据进行归一化处理;**后向归一化(Post-Normalization)**则是放在每个子模块的输出之后。随着大模型的发展,模型的归一化层也逐渐从后归一化改到了前规划一化,且当前主流的大模型基本都采用了后归一化。
今天Meta的这篇文章突破归一化层不可或缺传统观点,具体来说:提出了一种名为动态Tanh(DyT)的简单技术,用于替代Transformer中的归一化层,实验表明,使用DyT的无归一化Transformer在多种任务和领域中均能达到或超过传统归一化模型的性能,且大多无需超参数调整。(前向归一化和后向归一化的区别,这道八股面试题是不是要没有了~)
论文:https://arxiv.org/pdf/2503.10622
背景介绍
归一化最早始于2015年的批量归一化发明,它帮助视觉模型实现了更快更好地收敛。受其启发,十年间也出现了多种归一化层变体。直到今天,层归一化层仍然是现代神经网络的基本组件,特别是在Transformer架构中广泛应用的层归一化(Layer Norm)。大多数归一化层共享一个通用公式。对于形状为
(
B
,
T
,
C
)
(B, T, C)
(B,T,C) 的输入
x
x
x,其中
B
B
B 是批量大小,
T
T
T 是Token数量,
C
C
C 是每个Token的嵌入维度,归一化后的输出通常计算如下:
NormaLization
(
x
)
=
γ
∗
(
x
−
μ
σ
2
+
ϵ
)
+
β
\text{NormaLization}(x) = \gamma * \left(\frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}\right) + \beta
NormaLization(x)=γ∗(σ2+ϵx−μ)+β
其中, ϵ \epsilon ϵ 是一个小常数, γ \gamma γ 和 β \beta β 是形状为 ( C , ) (C,) (C,) 的可学习向量参数。它们是“缩放”和“偏移”仿射参数,使输出可以在任意范围内变化。 μ \mu μ 和 σ 2 \sigma^2 σ2 代表输入的均值和方差。不同方法主要区别在于如何计算这些统计量,这导致 μ \mu μ 和 σ 2 \sigma^2 σ2 具有不同的维度,并在计算过程中进行广播。
目前,归一化主要分为批归一化(Batch Normalization, BN)、层归一化(Layer Normalization, LN)。其中:
批归一化BN主要用于卷积网络(ConvNet)模型,BN 在批量和Token维度上计算均值和方差,具体计算如下:
μ
k
=
1
B
T
∑
i
,
j
x
i
j
k
,
σ
k
2
=
1
B
T
∑
i
,
j
(
x
i
j
k
−
μ
k
)
2
\mu_k = \frac{1}{BT} \sum_{i,j} x_{ijk}, \quad \sigma^2_k = \frac{1}{BT} \sum_{i,j} (x_{ijk} - \mu_k)^2
μk=BT1i,j∑xijk,σk2=BT1i,j∑(xijk−μk)2
其他在卷积网络中流行的归一化层包括组归一化(Group Normalization, GN)和实例归一化(Instance Normalization, IN)。这些方法最初是为目标检测和图像风格化等专门任务提出的。它们在总体公式上与 BN 相似,但在计算统计量的维度和范围上有所不同。
层归一化BN在每个样本的每一层独立计算统计量,其中计算公式如下:
μ
i
j
=
1
C
∑
k
x
i
j
k
,
σ
i
j
2
=
1
C
∑
k
(
x
i
j
k
−
μ
i
j
)
2
\mu_{ij} = \frac{1}{C} \sum_k x_{ijk}, \quad \sigma^2_{ij} = \frac{1}{C} \sum_k (x_{ijk} - \mu_{ij})^2
μij=C1k∑xijk,σij2=C1k∑(xijk−μij)2
但目前常见的大模型常用的归一化为:RMSNorm,该方法去掉了均值归一化步骤,并将均值设为 0 0 0、方差设为 1 C ∑ k x i j k 2 \frac{1}{C} \sum_k x_{ijk}^2 C1∑kxijk2 来简化 LN,其主要因为其简单性和通用性。
本来岁月静好,本文作者却打破了这一传统观念,提出了Transformer中归一化层的简单替代方案:Tanh。作者研究发现Layer Norm层产生类似tanh的S形曲线输出,既缩放输入又压缩极端值。因此,提出了动态双曲正切(Dynamic Tanh,DyT),定义为:DyT(x) = tanh(αx),其中α是可学习参数。DyT不需要计算统计量就能实现与归一化层相似的效果。
应用DyT很简单:直接用它替换现有的归一化层。实验表明,使用DyT的模型能稳定训练并达到高性能,通常不需要调整原有的训练参数。我们的研究挑战了归一化层不可或缺的观念,并提供了新的见解。初步测量还显示DyT提高了训练和推理速度,使其成为高效网络设计的候选方案。
动态双曲正切:DyT
动态双曲正切函数 (Dynamic Tanh, DyT) 作为神经网络中规范化层的直接替代方案,灵感来源于规范化层形状与缩放双曲正切函数之间的相似性。DyT层定义为:
D
y
T
(
x
)
=
γ
∗
t
a
n
h
(
α
x
)
+
β
DyT(x) = γ * tanh(αx) + β
DyT(x)=γ∗tanh(αx)+β
其中
α
α
α是可学习的标量参数,允许根据输入范围进行不同的缩放;
γ
γ
γ和
β
β
β则是可学习的、按通道的向量参数,与规范化层中使用的参数类似,使输出能够缩放回适当的范围。
DyT的实现与集成相当直接,可以替代现有架构中的规范化层,包括注意力模块、前馈神经网络模块和最终规范化层内的规范化层,同时不改变激活函数或网络的其他组件,且几乎不需要调整超参数就能表现良好。在参数初始化方面, γ γ γ初始化为全1向量, β β β初始化为全0向量,而 α α α通常初始化为0.5(LLM训练除外)。
值得注意的是,DyT并非一种新型的规范化层,因为它独立地处理输入张量中的每个元素,不计算统计量或执行聚合操作。然而,它通过非线性方式压缩极端值并几乎线性地转换输入的中心部分,保留了规范化的核心效果。伪代码实现展示了DyT如何通过简洁的类定义被整合到类似PyTorch的框架中,处理初始化和前向传播功能(如下图所示)。
为什么DyT可以替代归一化?
作者训练了ViT-B、wav2vec 2.0 Large Transformer、DiT-XL三个模型,然后从三个模型网络中抽取一个小批量样本,并执行前向传播。测量归一化层前后的张量,即在执行归一化操作前后的张量由于 LN 保持输入张量的维度不变,因此可以直接可视化输入和输出张量元素之间的一一对应关系,并绘制其映射关系(如下图所示)。
可以发现在较浅的 LN 层(第一列),输入-输出关系大多是线性的,在图中类似于一条直线。然而,在较深的 LN 层,大多数曲线的形状类似完整或部分的 S 形曲线,与 tanh 函数的形状高度相似。通常情况下,我们会期望 LN 只是线性变换输入张量,例如减去均值并除以标准差,这些都是线性操作。然而,LN 以每个Token为单位进行归一化,仅线性变换每个Token的激活值。由于不同Token的均值和标准差不同,因此线性并不适用于输入张量的所有激活。因此,作者惊讶地发现,LN 的非线性变换高度类似于缩放后的 tanh 函数。
LN 主要作用是 压缩(squash)极端值,使其更接近大多数点的范围,而这无法通过简单的仿射变换层来实现。因此作者推测,这种非线性和非均匀的压缩效应是归一化层不可或缺的重要特性。
LN 层如何对每个Token执行线性变换的同时,还能以非线性方式压缩极端值?为了解释这一点,将点按 Token(tokens) 和 通道(channels) 分组,并可视化它们的分布情况(如下图所示)。
按Token分组可以看到每个Token的所有点形成一条直线,由于每个Token的方差不同,斜率也不同。输入值范围较小的Token,其方差较小,因此归一化层使用较小的标准差进行归一化,使得图中斜率变大。这些曲线整体形成 S 形,与 tanh 函数类似。
按通道分组 可以看到发现不同通道的激活值斜率变化极大,其中,少数通道(如红色、绿色、粉色)斜率远大于其他通道,表明这些通道受到 LN 影响最大。
这些结果表明,LN 在每个Token级别上是线性的,但在整个模型中呈现出非线性行为。这种非均匀的归一化作用,使得 LN 能够有效地限制极端值,提高模型的稳定性和表达能力。正是基于以上分析,作者才想出使用 tanh 函数来替代归一化层。
实验结果
ViT-B 和 ConvNeXt-B 模型的训练损失曲线。两种模型类型的损失曲线在 LN 和 DyT 之间表现出相似的模式,这表明 LN 和 DyT 可能具有相似的学习动态。
下图展示了两种方式在LLaMA模型预训练损失标展。DyT 和 RMNSNorm 模型的损失曲线在不同模型大小之间紧密相关。
使用 RMSNorm、DyT 分别对 LLaMA 7B 模型进行基准测试,通过使用单个 4096 个 token 序列测量 100 次前向传递(推理)和 100 次前向-后向传递(训练)所需的总时间,DyT 层显著减少了计算时间。
作者同时也指出了DyT的局限性,DyT 很难在 ResNets 等经典网络中直接取代 BN。
更多:>>>>大模型/AIGC、学术前沿的知识分享!
更多:>>>>大模型/AIGC、学术前沿的知识分享!