【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE
在自然语言处理(NLP)领域,Transformer 模型已经成为主流。然而,Transformer 本身并不具备处理序列顺序的能力。为了让模型理解文本中词语的相对位置,我们需要引入位置编码(Positional Encoding)。本文将深入探讨 LLaMA 模型中使用的 Rotary Embedding(旋转式嵌入)位置编码方法,并对比传统的 Transformer 位置编码方案,分析其设计与实现的优势。
1. 传统 Transformer 的位置编码
1.1 正弦余弦编码
在原始的 Transformer 模型中,使用了基于正弦和余弦函数的位置编码。这种编码方式的公式如下:
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
其中:
pos
代表词语在序列中的位置。i
代表编码向量的维度索引。d_model
是模型的维度大小。
这种编码方式的主要特点是:
- 绝对位置编码: 为每个位置生成唯一的向量。
- 易于泛化到更长的序列: 可以外推到训练期间未见过的序列长度。
- 维度变化: 编码向量的每个维度上的频率都不同。
1.2 代码示例 (PyTorch)
import torch
import math
def positional_encoding(pos, d_model):
pe = torch.zeros(1, d_model)
for i in range(0, d_model, 2):
pe[0, i] = math.sin(pos / (10000 ** (i / d_model)))
pe[0, i + 1] = math.cos(pos / (10000 ** (i / d_model)))
return pe
# 示例
d_model = 512
max_len = 10
pos_encodings = torch.stack([positional_encoding(i, d_model) for i in range(max_len)])
print("Position Encodings Shape:", pos_encodings.shape) # 输出: torch.Size([10, 1, 512])
print("First 3 position encodings:\n", pos_encodings[:3])
1.3 缺点
传统的正弦余弦位置编码虽然有效,但也有其局限性:
- 缺乏相对位置信息: 尽管编码能提供绝对位置,但难以直接捕捉词语之间的相对距离关系。
- 位置编码与输入向量独立: 位置编码是直接加到输入词向量上的,没有与词向量进行交互,信息损失比较明显。
2. LLaMA 的 Rotary Embedding (RoPE)
LLaMA 模型采用了 Rotary Embedding(RoPE),一种相对位置编码方法,它通过旋转的方式将位置信息嵌入到词向量中。RoPE 的核心思想是将位置信息编码为旋转矩阵,然后将词向量进行旋转,从而引入位置信息。
2.1 RoPE 的核心公式
RoPE 的核心公式如下:
RoPE(q, k, pos) = rotate(q, pos, Θ)
其中:
q
和k
分别代表查询向量和键向量。pos
是两个向量之间的相对位置。Θ
是一个旋转矩阵,根据pos
和预定义的频率生成。rotate(q, pos, Θ)
表示将q
旋转Θ
角度后的结果。
更具体来说,对于维度为 d
的向量 q
,RoPE 将其分为 d/2
对 (q0, q1), (q2, q3) …, (qd-2, qd-1)。每个维度对应用不同的旋转角度。旋转矩阵 R
的定义是:
R(pos) = [[cos(pos * θ_0), -sin(pos * θ_0)],
[sin(pos * θ_0), cos(pos * θ_0)]]
[[cos(pos * θ_1), -sin(pos * θ_1)],
[sin(pos * θ_1), cos(pos * θ_1)]]
...
[[cos(pos * θ_d/2-1), -sin(pos * θ_d/2-1)],
[sin(pos * θ_d/2-1), cos(pos * θ_d/2-1)]]
其中 θ_i = 10000^(-2i/d) ,每个维度对的旋转角度不同。
将旋转矩阵应用于向量 q
,就是:
q_rotated = R(pos) * q
2.2 LLaMA 源码实现
下面是 LLaMA 中 RoPE 的核心代码(简化版,使用 PyTorch):
import torch
import math
def precompute_freqs(dim, end, theta=10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(end)
freqs = torch.outer(t, freqs)
return torch.cat((freqs, freqs), dim=1)
def apply_rotary_emb(xq, xk, freqs):
xq_complex = torch.complex(xq.float(), torch.roll(xq.float(), shifts=-xq.shape[-1]//2, dims=-1))
xk_complex = torch.complex(xk.float(), torch.roll(xk.float(), shifts=-xk.shape[-1]//2, dims=-1))
freqs_complex = torch.complex(torch.cos(freqs), torch.sin(freqs))
xq_rotated = xq_complex * freqs_complex
xk_rotated = xk_complex * freqs_complex
return xq_rotated.real.type_as(xq), xk_rotated.real.type_as(xk)
# 示例
batch_size = 2
seq_len = 5
d_model = 512
head_dim = d_model//8
xq = torch.randn(batch_size, seq_len, 8, head_dim) # 输入查询向量
xk = torch.randn(batch_size, seq_len, 8, head_dim) # 输入键向量
freqs = precompute_freqs(head_dim, seq_len)
xq_rotated, xk_rotated = apply_rotary_emb(xq, xk, freqs)
print("Rotated Query Shape:", xq_rotated.shape)
print("Rotated Key Shape:", xk_rotated.shape)
代码解释:
precompute_freqs(dim, end, theta)
:- 此函数用于预计算旋转矩阵中使用的频率。
dim
: 表示词向量维度。end
: 表示最大序列长度。- 返回包含所有位置的频率列表。
apply_rotary_emb(xq, xk, freqs)
:- 函数将旋转操作应用于查询向量
xq
和键向量xk
。 - 通过 complex 表示实数向量的旋转,并使用复数乘法完成旋转操作。
- 使用
torch.roll()
函数将 xq 分成实部和虚部,使用complex类型可以更快的完成旋转计算,避免了循环遍历,提高计算速度。 - 使用复数乘法完成旋转,通过
.real
属性取出旋转后的实部,并将类型转换回原始类型
- 函数将旋转操作应用于查询向量
2.3 RoPE 的优势
与传统的正弦余弦位置编码相比,RoPE 具有以下优势:
- 相对位置编码: RoPE 专注于编码词语之间的相对位置信息,而不仅仅是绝对位置。通过向量旋转,使得向量之间的相对位置信息更直观。
- 高效计算: 通过使用复数乘法,RoPE 可以在GPU上进行高效的并行计算。
- 良好的外推能力: RoPE 可以比较容易地推广到训练期间未见过的序列长度,并且性能保持稳定。
- 可解释性: RoPE 的旋转操作使其相对位置信息具有更强的可解释性,有助于理解模型的行为。
3. 总结
本文详细介绍了 LLaMA 模型中使用的 Rotary Embedding 位置编码方法。通过源码分析和对比传统的位置编码,我们了解了 RoPE 的核心原理和优势。RoPE 通过旋转操作高效地编码相对位置信息,为 LLaMA 模型的强大性能提供了重要的基础。希望本文能帮助你更深入地理解 Transformer 模型中的位置编码机制。
4. 参考资料
- RoFormer: Enhanced Transformer with Rotary Position Embedding
- Attention is All You Need