Gemma2DecoderLayer 解析:Pre-FFW 和 Post-FFW LayerNorm 的作用
Gemma2DecoderLayer 解析:Pre-FFW 和 Post-FFW LayerNorm 的作用
1. 引言
在大规模 Transformer 模型(如 LLaMA、Gemma)中,层归一化(Layer Normalization,LN)是确保模型稳定性和训练收敛性的关键技术。Gemma2DecoderLayer 在前馈网络(Feedforward Network, FFW)部分引入了前归一化(Pre-FFW LayerNorm) 和 后归一化(Post-FFW LayerNorm) 两种方式,以进一步优化训练和推理的稳定性。
本文将深入分析:
- 什么是 Pre-FFW LayerNorm 和 Post-FFW LayerNorm
- 为什么需要这两种归一化
- 它们如何影响 Transformer 计算
- 在
Gemma2DecoderLayer
代码中的具体实现
2. Transformer 层结构
在标准 Transformer 解码器(Decoder Layer)中,每一层由自注意力(Self-Attention)和前馈网络(Feedforward, FFW)组成:
+------------------+
| Self-Attention |
+------------------+
↓
+------------------+
| Feedforward (FFW) |
+------------------+
为了提高训练稳定性,通常会在自注意力和前馈网络的输入或输出处添加层归一化(LayerNorm),以确保分布稳定。
在 Gemma2DecoderLayer 中,FFW 归一化分为两种:
- Pre-FFW LayerNorm(前归一化)
- Post-FFW LayerNorm(后归一化)
3. 什么是 Pre-FFW LayerNorm 和 Post-FFW LayerNorm?
3.1 Pre-FFW LayerNorm(前归一化)
前归一化的思想是:在进入前馈网络(FFW)之前先进行层归一化,以稳定输入数据的分布。
Norm
(
X
)
=
X
−
μ
σ
+
ϵ
\text{Norm}(X) = \frac{X - \mu}{\sigma + \epsilon}
Norm(X)=σ+ϵX−μ
FFW
(
Norm
(
X
)
)
\text{FFW}(\text{Norm}(X))
FFW(Norm(X))
在代码中:
if self.pre_feedforward_layernorm is not None:
hidden_states = self.pre_feedforward_layernorm(hidden_states)
作用:
- 使 FFW 层的输入具有更稳定的分布,避免梯度爆炸或梯度消失。
- 在 Transformer 预归一化结构(Pre-Norm Transformer) 中常用,如 GPT-3 和 LLaMA。
3.2 Post-FFW LayerNorm(后归一化)
后归一化的思路是:在 FFW 计算完成后进行层归一化,确保前馈网络的输出分布稳定。
FFW
(
X
)
→
Norm
(
FFW
(
X
)
)
\text{FFW}(X) \rightarrow \text{Norm}(\text{FFW}(X))
FFW(X)→Norm(FFW(X))
在代码中:
if self.post_feedforward_layernorm is not None:
hidden_states = self.post_feedforward_layernorm(hidden_states)
作用:
- 让前馈网络输出分布更稳定,使后续层的输入更具一致性。
- 在 标准 Transformer(Post-Norm Transformer) 结构中常见,如原始的 BERT。
4. 为什么需要 Pre-FFW 和 Post-FFW?
Transformer 训练过程中,如果 LayerNorm 放置不当,可能会导致:
- 梯度爆炸或梯度消失
- 收敛速度变慢
- 长文本任务中不稳定
4.1 为什么要使用 Pre-FFW LayerNorm?
在大规模 Transformer(如 GPT-3、LLaMA)中,Pre-LN(Pre-Norm Transformer) 比标准 Transformer 更稳定,因为:
- 标准 Transformer(Post-LN) 在深度增加时,容易出现梯度消失问题,导致训练难以收敛。
- Pre-LN 先归一化输入,使梯度更稳定,能更快收敛。
4.2 为什么还要 Post-FFW LayerNorm?
有些架构仍然保留 Post-FFW LayerNorm,原因:
- Post-LN 可以让 FFW 的输出分布更稳定,避免梯度抖动。
- 在推理阶段,Post-LN 可能有更好的表现,特别是在处理长文本时。
5. 代码解析
5.1 Gemma2DecoderLayer
的 LayerNorm 结构
self.pre_feedforward_layernorm = (
RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if config.use_pre_ffw_norm
else None
)
self.post_feedforward_layernorm = (
RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if config.use_post_ffw_norm
else None
)
- 如果
use_pre_ffw_norm=True
,则启用 Pre-FFW LayerNorm - 如果
use_post_ffw_norm=True
,则启用 Post-FFW LayerNorm - 可以灵活配置是否使用 Pre-LN 或 Post-LN。
5.2 前馈网络部分
residual = hidden_states
if self.pre_feedforward_layernorm is not None:
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if self.post_feedforward_layernorm is not None:
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
- 先检查是否有 Pre-FFW LN,如果有,归一化
hidden_states
- 进入 MLP 前馈网络
- 检查是否有 Post-FFW LN,如果有,再归一化
hidden_states
- 残差连接(Residual Connection) 保持信息流稳定。
6. Pre-LN 和 Post-LN 的对比
归一化方式 | 计算公式 | 优点 | 缺点 | 使用场景 |
---|---|---|---|---|
Pre-FFW LayerNorm | Norm(X) -> FFW(X) | 稳定梯度,收敛快 | 影响表达能力 | GPT-3, LLaMA |
Post-FFW LayerNorm | FFW(X) -> Norm(X) | 保持分布一致性 | 可能导致深度梯度消失 | BERT, T5 |
7. 总结
- Transformer 计算中 LayerNorm 影响模型稳定性和训练收敛速度。
- Pre-FFW LayerNorm 先归一化输入,适用于深层网络,避免梯度消失(Pre-Norm Transformer)。
- Post-FFW LayerNorm 归一化输出,保持输出分布稳定,适用于推理任务。
- Gemma2DecoderLayer 结合 Pre-FFW 和 Post-FFW,提供更灵活的归一化方式,可以根据不同任务需求调整归一化策略。
🚀 理解 LayerNorm 的作用,对于优化 Transformer 训练至关重要!
附录
源代码:
class Gemma2DecoderLayer(nn.Module):
def __init__(
self,
config: gemma_config.GemmaConfig,
attn_type: gemma_config.AttentionType,
):
super().__init__()
self.self_attn = GemmaAttention(
hidden_size=config.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
attn_logit_softcapping=config.attn_logit_softcapping,
query_pre_attn_scalar=config.query_pre_attn_scalar,
head_dim=config.head_dim,
quant=config.quant,
attn_type=attn_type,
sliding_window_size=config.sliding_window_size,
)
self.mlp = GemmaMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant=config.quant,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_feedforward_layernorm = (
RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if config.use_pre_ffw_norm
else None
)
self.post_feedforward_layernorm = (
RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if config.use_post_ffw_norm
else None
)
def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
mask: torch.Tensor,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
kv_write_indices=kv_write_indices,
kv_cache=kv_cache,
mask=mask,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
# MLP
residual = hidden_states
if self.pre_feedforward_layernorm is not None:
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if self.post_feedforward_layernorm is not None:
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
后记
2025年2月24日16点26分于上海,在GPT 4o大模型辅助下完成。