当前位置: 首页 > article >正文

Gemma2DecoderLayer 解析:Pre-FFW 和 Post-FFW LayerNorm 的作用

Gemma2DecoderLayer 解析:Pre-FFW 和 Post-FFW LayerNorm 的作用

1. 引言

在大规模 Transformer 模型(如 LLaMA、Gemma)中,层归一化(Layer Normalization,LN)是确保模型稳定性和训练收敛性的关键技术。Gemma2DecoderLayer 在前馈网络(Feedforward Network, FFW)部分引入了前归一化(Pre-FFW LayerNorm)后归一化(Post-FFW LayerNorm) 两种方式,以进一步优化训练和推理的稳定性。

本文将深入分析:

  • 什么是 Pre-FFW LayerNormPost-FFW LayerNorm
  • 为什么需要这两种归一化
  • 它们如何影响 Transformer 计算
  • Gemma2DecoderLayer 代码中的具体实现

2. Transformer 层结构

在标准 Transformer 解码器(Decoder Layer)中,每一层由自注意力(Self-Attention)和前馈网络(Feedforward, FFW)组成:

+------------------+
| Self-Attention  |
+------------------+
       ↓
+------------------+
| Feedforward (FFW) |
+------------------+

为了提高训练稳定性,通常会在自注意力和前馈网络的输入或输出处添加层归一化(LayerNorm),以确保分布稳定。

Gemma2DecoderLayer 中,FFW 归一化分为两种:

  1. Pre-FFW LayerNorm(前归一化)
  2. Post-FFW LayerNorm(后归一化)

3. 什么是 Pre-FFW LayerNorm 和 Post-FFW LayerNorm?

3.1 Pre-FFW LayerNorm(前归一化)

前归一化的思想是:在进入前馈网络(FFW)之前先进行层归一化,以稳定输入数据的分布。
Norm ( X ) = X − μ σ + ϵ \text{Norm}(X) = \frac{X - \mu}{\sigma + \epsilon} Norm(X)=σ+ϵXμ
FFW ( Norm ( X ) ) \text{FFW}(\text{Norm}(X)) FFW(Norm(X))
在代码中:

if self.pre_feedforward_layernorm is not None:
    hidden_states = self.pre_feedforward_layernorm(hidden_states)

作用:

  • 使 FFW 层的输入具有更稳定的分布,避免梯度爆炸或梯度消失。
  • Transformer 预归一化结构(Pre-Norm Transformer) 中常用,如 GPT-3 和 LLaMA。

3.2 Post-FFW LayerNorm(后归一化)

后归一化的思路是:在 FFW 计算完成后进行层归一化,确保前馈网络的输出分布稳定。
FFW ( X ) → Norm ( FFW ( X ) ) \text{FFW}(X) \rightarrow \text{Norm}(\text{FFW}(X)) FFW(X)Norm(FFW(X))
在代码中:

if self.post_feedforward_layernorm is not None:
    hidden_states = self.post_feedforward_layernorm(hidden_states)

作用:

  • 让前馈网络输出分布更稳定,使后续层的输入更具一致性。
  • 标准 Transformer(Post-Norm Transformer) 结构中常见,如原始的 BERT。

4. 为什么需要 Pre-FFW 和 Post-FFW?

Transformer 训练过程中,如果 LayerNorm 放置不当,可能会导致:

  • 梯度爆炸或梯度消失
  • 收敛速度变慢
  • 长文本任务中不稳定

4.1 为什么要使用 Pre-FFW LayerNorm?

在大规模 Transformer(如 GPT-3、LLaMA)中,Pre-LN(Pre-Norm Transformer) 比标准 Transformer 更稳定,因为:

  • 标准 Transformer(Post-LN) 在深度增加时,容易出现梯度消失问题,导致训练难以收敛。
  • Pre-LN 先归一化输入,使梯度更稳定,能更快收敛。

4.2 为什么还要 Post-FFW LayerNorm?

有些架构仍然保留 Post-FFW LayerNorm,原因:

  • Post-LN 可以让 FFW 的输出分布更稳定,避免梯度抖动。
  • 在推理阶段,Post-LN 可能有更好的表现,特别是在处理长文本时。

5. 代码解析

5.1 Gemma2DecoderLayer 的 LayerNorm 结构

self.pre_feedforward_layernorm = (
    RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    if config.use_pre_ffw_norm
    else None
)
self.post_feedforward_layernorm = (
    RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    if config.use_post_ffw_norm
    else None
)
  • 如果 use_pre_ffw_norm=True,则启用 Pre-FFW LayerNorm
  • 如果 use_post_ffw_norm=True,则启用 Post-FFW LayerNorm
  • 可以灵活配置是否使用 Pre-LN 或 Post-LN。

5.2 前馈网络部分

residual = hidden_states
if self.pre_feedforward_layernorm is not None:
    hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
if self.post_feedforward_layernorm is not None:
    hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
  • 先检查是否有 Pre-FFW LN,如果有,归一化 hidden_states
  • 进入 MLP 前馈网络
  • 检查是否有 Post-FFW LN,如果有,再归一化 hidden_states
  • 残差连接(Residual Connection) 保持信息流稳定。

6. Pre-LN 和 Post-LN 的对比

归一化方式计算公式优点缺点使用场景
Pre-FFW LayerNormNorm(X) -> FFW(X)稳定梯度,收敛快影响表达能力GPT-3, LLaMA
Post-FFW LayerNormFFW(X) -> Norm(X)保持分布一致性可能导致深度梯度消失BERT, T5

7. 总结

  • Transformer 计算中 LayerNorm 影响模型稳定性和训练收敛速度。
  • Pre-FFW LayerNorm 先归一化输入,适用于深层网络,避免梯度消失(Pre-Norm Transformer)。
  • Post-FFW LayerNorm 归一化输出,保持输出分布稳定,适用于推理任务
  • Gemma2DecoderLayer 结合 Pre-FFW 和 Post-FFW,提供更灵活的归一化方式,可以根据不同任务需求调整归一化策略。

🚀 理解 LayerNorm 的作用,对于优化 Transformer 训练至关重要!

附录

源代码:

class Gemma2DecoderLayer(nn.Module):
    def __init__(
        self,
        config: gemma_config.GemmaConfig,
        attn_type: gemma_config.AttentionType,
    ):
        super().__init__()
        self.self_attn = GemmaAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            attn_logit_softcapping=config.attn_logit_softcapping,
            query_pre_attn_scalar=config.query_pre_attn_scalar,
            head_dim=config.head_dim,
            quant=config.quant,
            attn_type=attn_type,
            sliding_window_size=config.sliding_window_size,
        )
        self.mlp = GemmaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant=config.quant,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.pre_feedforward_layernorm = (
            RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            if config.use_pre_ffw_norm
            else None
        )
        self.post_feedforward_layernorm = (
            RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            if config.use_post_ffw_norm
            else None
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_cache=kv_cache,
            mask=mask,
        )
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # MLP
        residual = hidden_states
        if self.pre_feedforward_layernorm is not None:
            hidden_states = self.pre_feedforward_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        if self.post_feedforward_layernorm is not None:
            hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states

后记

2025年2月24日16点26分于上海,在GPT 4o大模型辅助下完成。


http://www.kler.cn/a/561304.html

相关文章:

  • DDR3模块、HDMI、晶振的布局原则
  • 山东大学软件学院nosql实验二
  • 解决鼠标唤醒关屏状态下的笔记本
  • 开源嵌入式实时操作系统uC/OS-II介绍
  • 学习笔记--电磁兼容性EMC
  • DeepSeek 15天指导手册——从入门到精通 PDF(附下载)
  • Figure自研模型Helix发布,人形机器人迈向新纪元?
  • C++程序员内功修炼——Linux C/C++编程技术汇总
  • 如何实现使用DeepSeek的CV模型对管道内模糊、低光照或水渍干扰的图像进行去噪、超分辨率重建。...
  • 解锁Redis的深层能力:事务与消息队列的最佳实践
  • 华为数通 HCIP-Datacom H12-831 新题
  • 使用 DeepSeek + OmniParser v2 + UIAutomation 实现 GUI 应用自动化测试的探索
  • c++中sleep是什么意思(不是Sleep() )
  • Spark MLlib中的机器学习算法及其应用场景
  • 毕业项目推荐:基于yolov8/yolov5/yolo11的番茄成熟度检测识别系统(python+卷积神经网络)
  • sqlclchery面对复杂的sql语句怎么办
  • Windows 11 使用容器(Docker Podman)
  • AI到底能做些什么:详细产品功能对比
  • 力扣-贪心-376 摆动序列
  • 人工智能 阿里云算力服务器的使用