【llm对话系统】大模型 Llama 源码分析之归一化方法 RMS Norm
1. 引言
在深度学习中,归一化 (Normalization) 是一种常用的技术,它可以加速模型的训练并提高模型的性能。常见的归一化方法包括 Batch Normalization (BatchNorm)、Layer Normalization (LayerNorm) 等。Llama 模型采用了一种称为 RMS Norm 的归一化方法,它是一种对 LayerNorm 的简化和改进。
本文将深入 Llama 源码,分析 RMS Norm 的实现逻辑,并探讨其相比于其他归一化方法的优势。
2. 归一化方法回顾
2.1 Batch Normalization (BatchNorm)
BatchNorm 对每个 mini-batch 的数据进行归一化,使其均值为 0,方差为 1。它引入了两个可学习的参数:缩放因子 (scale) 和偏移因子 (shift)。
公式:
y = (x - mean(x)) / sqrt(variance(x) + epsilon) * scale + shift
优点:
- 加速训练。
- 具有一定的正则化效果。
缺点:
- 依赖于 batch size,当 batch size 较小时,效果较差。
- 不适用于 RNN 等序列模型。
2.2 Layer Normalization (LayerNorm)
LayerNorm 对每个样本的特征进行归一化,使其均值为 0,方差为 1。它也引入了两个可学习的参数:缩放因子 (scale) 和偏移因子 (shift)。
公式:
y = (x - mean(x)) / sqrt(variance(x) + epsilon) * scale + shift
优点:
- 不依赖于 batch size。
- 适用于 RNN 等序列模型。
缺点:
- 计算量比 BatchNorm 略大。
3. RMS Norm 原理
RMS Norm (Root Mean Square Normalization) 可以看作是 LayerNorm 的一个特例。它只对输入进行 均方根 (Root Mean Square) 归一化,并保留了可学习的缩放因子,但 去除了偏移因子。
公式:
y = x / sqrt(mean(x^2) + epsilon) * scale
其中:
x
是输入向量。mean(x^2)
是x
各元素的平方的平均值。epsilon
是一个很小的常数,用于防止除零错误。scale
是可学习的缩放因子,通常初始化为 1。
与 LayerNorm 的比较:
- RMS Norm 没有减去均值 (即没有中心化)。
- RMS Norm 没有偏移因子。
4. Llama 中 RMS Norm 的实现
Llama 源码中 RMS Norm 的实现位于 llama/model.py
文件中,定义在 RMSNorm
类中:
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
"""
初始化 RMSNorm.
Args:
dim: 输入的维度
eps: 用于数值稳定的小常数
"""
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
"""
执行 RMS 归一化.
Args:
x: 输入张量 (..., dim)
Returns:
归一化后的张量 (..., dim)
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
前向传播.
Args:
x: 输入张量 (..., dim)
Returns:
归一化并缩放后的张量 (..., dim)
"""
output = self._norm(x.float()).type_as(x)
return output * self.weight
代码解释:
-
__init__
函数:dim
:输入的维度。eps
:用于数值稳定的小常数,默认为1e-6
。weight
:可学习的缩放因子,初始化为全 1 的张量。
-
_norm
函数:- 计算输入
x
的均方根的倒数:torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
。x.pow(2)
:计算x
每个元素的平方。.mean(-1, keepdim=True)
:沿着最后一个维度计算平均值,并保持维度不变。torch.rsqrt()
:计算平方根的倒数。
- 将
x
与均方根的倒数相乘,实现归一化。
- 计算输入
-
forward
函数:- 调用
_norm
函数进行归一化。 - 将归一化后的结果与可学习的
weight
相乘,进行缩放。 .type_as(x)
:将结果转换为与输入x
相同的类型。
- 调用
使用示例:
# 假设输入维度为 512
dim = 512
rms_norm = RMSNorm(dim)
# 模拟一个输入张量
x = torch.randn(1, 10, dim)
# 进行 RMS Norm 归一化
y = rms_norm(x)
print(y.shape) # 输出: torch.Size([1, 10, 512])
5. RMS Norm 的优势
- 计算效率高:RMS Norm 比 LayerNorm 少了均值计算和偏移操作,计算速度更快。
- 性能相当:实验表明,RMS Norm 的性能与 LayerNorm 相当,甚至在某些任务上略有提升。
- 更稳定:RMS Norm 对输入的缩放更加鲁棒,因为它只依赖于输入的平方的平均值,而不依赖于输入的均值。
为什么 RMS Norm 可以去掉偏移因子?
在 Transformer 架构中,通常在 RMS Norm 之后会跟一个线性层 (例如,多头注意力机制中的 Q, K, V 投影)。这个线性层可以学习到偏移的效果。因此,RMS Norm 中的偏移因子就显得多余了。
6. 总结
RMS Norm 是一种高效且有效的归一化方法,它通过对 LayerNorm 进行简化,去除了均值计算和偏移因子,提高了计算效率并保持了良好的性能。