【llm对话系统】大模型源码分析之llama模型的long context更长上下文支持
1. 引言
Llama模型的一个重要特性是支持长上下文处理。本文将深入分析Llama源码中实现长上下文的关键技术点,包括位置编码(position embedding)的外推方法、注意力机制的优化等。我们将通过详细的代码解析来理解其实现原理。
2. 位置编码的外推实现
2.1 旋转位置编码(RoPE)基础
Llama采用旋转位置编码(RoPE, Rotary Position Embedding)来编码token的位置信息。RoPE的实现包含几个关键步骤:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, scale: float = 1.0):
"""
预计算RoPE的频率
Args:
dim: 隐藏层维度
end: 序列最大长度
theta: RoPE的基频参数
scale: 位置缩放因子
Returns:
freqs_cis: 复数形式的频率矩阵
"""
# 生成维度序列 [0, 2, ..., dim-2]
dims = torch.arange(0, dim, 2)[: (dim // 2)].float()
# 计算频率基数 1/θ^(2i/d)
freqs = 1.0 / (theta ** (dims / dim))
# 生成位置序列并应用缩放
t = torch.arange(end, device=freqs.device) * scale
# 计算位置和频率的外积
freqs = torch.outer(t, freqs)
# 转换为复数形式 e^(iθ)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
应用旋转位置编码
Args:
xq: query张量 [batch_size, seq_len, num_heads, head_dim]
xk: key张量 [batch_size, seq_len, num_heads, head_dim]
freqs_cis: 预计算的频率 [seq_len, head_dim//2]
"""
# 重塑张量以方便运算
xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)
xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1)
# 提取频率的实部和虚部
freqs_cos = freqs_cis.real()
freqs_sin = freqs_cis.imag()
# 应用旋转变换
# xq_out = xq * cos(θ) + rotate_half(xq) * sin(θ)
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
# 重新组合实部和虚部
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(-2)
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(-2)
return xq_out.type_as(xq), xk_out.type_as(xk)
2.2 动态NTK外推方案
动态NTK缩放是实现长上下文的关键技术,它通过动态调整位置编码的缩放因子来改善模型在更长序列上的表现:
class LlamaConfig:
def __init__(self):
self.rope_scaling = {
"type": "dynamic", # 动态缩放类型
"factor": 2.0, # 基础缩放因子
"original_max_position_embeddings": 2048 # 原始训练长度
}
def compute_dynamic_ntk_scaling(
ctx_len: int,
orig_ctx_len: int = 2048,
base_scale: float = 0.25,
alpha: float = 1.0
) -> float:
"""
计算动态NTK缩放因子
Args:
ctx_len: 当前上下文长度
orig_ctx_len: 原始训练上下文长度
base_scale: 基础缩放系数
alpha: 缩放曲线的陡峭程度
"""
# 使用对数曲线计算缩放因子
return base_scale * math.log(ctx_len / orig_ctx_len) ** alpha
class LlamaAttention(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.rope_scaling = config.rope_scaling
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
"""
注意力前向计算
Args:
hidden_states: 输入张量 [batch_size, seq_len, hidden_size]
attention_mask: 注意力掩码
position_ids: 位置索引
"""
seq_len = hidden_states.shape[1]
# 计算动态缩放因子
if self.rope_scaling["type"] == "dynamic":
rope_scale = compute_dynamic_ntk_scaling(
seq_len,
self.config.rope_scaling["original_max_position_embeddings"],
base_scale=self.rope_scaling["factor"]
)
else:
rope_scale = 1.0
# 计算位置编码
freqs_cis = precompute_freqs_cis(
self.head_dim,
seq_len,
scale=rope_scale
)
# 应用旋转位置编码
query_states, key_states = apply_rotary_emb(
self.q_proj(hidden_states),
self.k_proj(hidden_states),
freqs_cis
)
3. 注意力机制优化
3.1 分块注意力计算
为了高效处理长序列,Llama实现了分块注意力计算。以下是详细的实现代码:
class ChunkedAttention(nn.Module):
def __init__(self, chunk_size: int = 1024):
super().__init__()
self.chunk_size = chunk_size
def forward(
self,
query: torch.Tensor, # [batch, num_heads, seq_len, head_dim]
key: torch.Tensor, # [batch, num_heads, seq_len, head_dim]
value: torch.Tensor, # [batch, num_heads, seq_len, head_dim]
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
分块计算注意力
"""
batch_size, num_heads, seq_len, head_dim = query.shape
# 计算需要的块数
num_chunks = (seq_len + self.chunk_size - 1) // self.chunk_size
# 存储每个块的输出
chunked_outputs = []
# 按块计算注意力
for chunk_idx in range(num_chunks):
# 计算当前块的起止位置
chunk_start = chunk_idx * self.chunk_size
chunk_end = min(chunk_start + self.chunk_size, seq_len)
# 提取当前块的query
chunk_query = query[:, :, chunk_start:chunk_end]
# 计算注意力得分
chunk_scores = torch.matmul(
chunk_query, # [b, h, chunk_size, d]
key.transpose(-2, -1) # [b, h, d, seq_len]
) # 得到 [b, h, chunk_size, seq_len]
# 缩放注意力得分
chunk_scores = chunk_scores / math.sqrt(head_dim)
# 应用attention mask
if mask is not None:
chunk_mask = mask[:, :, chunk_start:chunk_end, :]
chunk_scores = chunk_scores + chunk_mask
# 应用softmax
chunk_attn = F.softmax(chunk_scores, dim=-1)
# 计算输出
chunk_output = torch.matmul(chunk_attn, value)
chunked_outputs.append(chunk_output)
# 拼接所有块的输出
return torch.cat(chunked_outputs, dim=2)
3.2 优化的KV Cache实现
KV Cache的实现需要考虑内存效率和计算性能:
class KVCache:
def __init__(
self,
max_batch_size: int,
max_seq_length: int,
num_heads: int,
head_dim: int,
dtype: torch.dtype = torch.float16
):
"""
初始化KV缓存
Args:
max_batch_size: 最大批次大小
max_seq_length: 最大序列长度
num_heads: 注意力头数
head_dim: 每个头的维度
dtype: 数据类型
"""
self.max_seq_length = max_seq_length
# 初始化缓存张量
self.k_cache = torch.zeros(
max_batch_size,
num_heads,
max_seq_length,
head_dim,
dtype=dtype
)
self.v_cache = torch.zeros(
max_batch_size,
num_heads,
max_seq_length,
head_dim,
dtype=dtype
)
# 记录当前序列长度
self.current_length = 0
def update(
self,
key: torch.Tensor,
value: torch.Tensor,
position: int
) -> None:
"""
更新缓存
Args:
key: key状态 [batch_size, num_heads, seq_len, head_dim]
value: value状态 [batch_size, num_heads, seq_len, head_dim]
position: 起始位置
"""
seq_len = key.shape[2]
if position + seq_len > self.max_seq_length:
raise ValueError(f"Position {position + seq_len} exceeds max_seq_length {self.max_seq_length}")
# 更新缓存
self.k_cache[:, :, position:position+seq_len] = key
self.v_cache[:, :, position:position+seq_len] = value
# 更新当前长度
self.current_length = max(self.current_length, position + seq_len)
def get_cached_kv(
self,
start_pos: int,
end_pos: int
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
获取指定范围的缓存内容
"""
return (
self.k_cache[:, :, start_pos:end_pos],
self.v_cache[:, :, start_pos:end_pos]
)
def clear(self) -> None:
"""清空缓存"""
self.k_cache.zero_()
self.v_cache.zero_()
self.current_length = 0
4. 实际应用示例
让我们看一个完整的使用示例,展示如何处理长文本:
class LongContextProcessor:
def __init__(
self,
model: LlamaModel,
tokenizer,
max_length: int = 16384,
chunk_size: int = 1024
):
self.model = model
self.tokenizer = tokenizer
self.chunk_size = chunk_size
# 初始化KV缓存
self.kv_cache = KVCache(
max_batch_size=1,
max_seq_length=max_length,
num_heads=model.config.num_attention_heads,
head_dim=model.config.hidden_size // model.config.num_attention_heads
)
def process_long_text(self, text: str) -> torch.Tensor:
"""
处理长文本输入
Args:
text: 输入文本
Returns:
处理后的隐藏状态
"""
# 分词
tokens = self.tokenizer(
text,
return_tensors="pt",
truncation=False
).input_ids
# 清空KV缓存
self.kv_cache.clear()
# 分块处理
all_hidden_states = []
for i in range(0, tokens.size(1), self.chunk_size):
# 获取当前块
chunk = tokens[:, i:i+self.chunk_size]
# 获取位置编码索引
position_ids = torch.arange(
i,
i + chunk.size(1),
dtype=torch.long,
device=chunk.device
).unsqueeze(0)
# 获取当前位置的缓存
k_cache, v_cache = self.kv_cache.get_cached_kv(0, i)
# 前向计算
outputs = self.model(
chunk,
position_ids=position_ids,
past_key_values=[(k_cache, v_cache)] * self.model.config.num_hidden_layers
)
# 更新缓存
self.kv