当前位置: 首页 > article >正文

【llm对话系统】大模型源码分析之llama模型的long context更长上下文支持

1. 引言

Llama模型的一个重要特性是支持长上下文处理。本文将深入分析Llama源码中实现长上下文的关键技术点,包括位置编码(position embedding)的外推方法、注意力机制的优化等。我们将通过详细的代码解析来理解其实现原理。

2. 位置编码的外推实现

2.1 旋转位置编码(RoPE)基础

Llama采用旋转位置编码(RoPE, Rotary Position Embedding)来编码token的位置信息。RoPE的实现包含几个关键步骤:

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, scale: float = 1.0):
    """
    预计算RoPE的频率
    Args:
        dim: 隐藏层维度
        end: 序列最大长度
        theta: RoPE的基频参数
        scale: 位置缩放因子
    Returns:
        freqs_cis: 复数形式的频率矩阵
    """
    # 生成维度序列 [0, 2, ..., dim-2]
    dims = torch.arange(0, dim, 2)[: (dim // 2)].float()
    
    # 计算频率基数 1/θ^(2i/d)
    freqs = 1.0 / (theta ** (dims / dim))
    
    # 生成位置序列并应用缩放
    t = torch.arange(end, device=freqs.device) * scale
    
    # 计算位置和频率的外积
    freqs = torch.outer(t, freqs)
    
    # 转换为复数形式 e^(iθ)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    应用旋转位置编码
    Args:
        xq: query张量 [batch_size, seq_len, num_heads, head_dim]
        xk: key张量 [batch_size, seq_len, num_heads, head_dim]
        freqs_cis: 预计算的频率 [seq_len, head_dim//2]
    """
    # 重塑张量以方便运算
    xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1)
    xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1)
    
    # 提取频率的实部和虚部
    freqs_cos = freqs_cis.real()
    freqs_sin = freqs_cis.imag()

    # 应用旋转变换
    # xq_out = xq * cos(θ) + rotate_half(xq) * sin(θ)
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # 重新组合实部和虚部
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(-2)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(-2)
    
    return xq_out.type_as(xq), xk_out.type_as(xk)

2.2 动态NTK外推方案

动态NTK缩放是实现长上下文的关键技术,它通过动态调整位置编码的缩放因子来改善模型在更长序列上的表现:

class LlamaConfig:
    def __init__(self):
        self.rope_scaling = {
            "type": "dynamic",  # 动态缩放类型
            "factor": 2.0,      # 基础缩放因子
            "original_max_position_embeddings": 2048  # 原始训练长度
        }

def compute_dynamic_ntk_scaling(
    ctx_len: int,
    orig_ctx_len: int = 2048,
    base_scale: float = 0.25,
    alpha: float = 1.0
) -> float:
    """
    计算动态NTK缩放因子
    Args:
        ctx_len: 当前上下文长度
        orig_ctx_len: 原始训练上下文长度
        base_scale: 基础缩放系数
        alpha: 缩放曲线的陡峭程度
    """
    # 使用对数曲线计算缩放因子
    return base_scale * math.log(ctx_len / orig_ctx_len) ** alpha

class LlamaAttention(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.rope_scaling = config.rope_scaling
        
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        注意力前向计算
        Args:
            hidden_states: 输入张量 [batch_size, seq_len, hidden_size]
            attention_mask: 注意力掩码
            position_ids: 位置索引
        """
        seq_len = hidden_states.shape[1]
        
        # 计算动态缩放因子
        if self.rope_scaling["type"] == "dynamic":
            rope_scale = compute_dynamic_ntk_scaling(
                seq_len,
                self.config.rope_scaling["original_max_position_embeddings"],
                base_scale=self.rope_scaling["factor"]
            )
        else:
            rope_scale = 1.0
            
        # 计算位置编码
        freqs_cis = precompute_freqs_cis(
            self.head_dim,
            seq_len,
            scale=rope_scale
        )
        
        # 应用旋转位置编码
        query_states, key_states = apply_rotary_emb(
            self.q_proj(hidden_states),
            self.k_proj(hidden_states),
            freqs_cis
        )

3. 注意力机制优化

3.1 分块注意力计算

为了高效处理长序列,Llama实现了分块注意力计算。以下是详细的实现代码:

class ChunkedAttention(nn.Module):
    def __init__(self, chunk_size: int = 1024):
        super().__init__()
        self.chunk_size = chunk_size
        
    def forward(
        self,
        query: torch.Tensor,      # [batch, num_heads, seq_len, head_dim]
        key: torch.Tensor,        # [batch, num_heads, seq_len, head_dim]
        value: torch.Tensor,      # [batch, num_heads, seq_len, head_dim]
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        分块计算注意力
        """
        batch_size, num_heads, seq_len, head_dim = query.shape
        
        # 计算需要的块数
        num_chunks = (seq_len + self.chunk_size - 1) // self.chunk_size
        
        # 存储每个块的输出
        chunked_outputs = []
        
        # 按块计算注意力
        for chunk_idx in range(num_chunks):
            # 计算当前块的起止位置
            chunk_start = chunk_idx * self.chunk_size
            chunk_end = min(chunk_start + self.chunk_size, seq_len)
            
            # 提取当前块的query
            chunk_query = query[:, :, chunk_start:chunk_end]
            
            # 计算注意力得分
            chunk_scores = torch.matmul(
                chunk_query,                    # [b, h, chunk_size, d]
                key.transpose(-2, -1)           # [b, h, d, seq_len]
            )   # 得到 [b, h, chunk_size, seq_len]
            
            # 缩放注意力得分
            chunk_scores = chunk_scores / math.sqrt(head_dim)
            
            # 应用attention mask
            if mask is not None:
                chunk_mask = mask[:, :, chunk_start:chunk_end, :]
                chunk_scores = chunk_scores + chunk_mask
            
            # 应用softmax
            chunk_attn = F.softmax(chunk_scores, dim=-1)
            
            # 计算输出
            chunk_output = torch.matmul(chunk_attn, value)
            chunked_outputs.append(chunk_output)
        
        # 拼接所有块的输出
        return torch.cat(chunked_outputs, dim=2)

3.2 优化的KV Cache实现

KV Cache的实现需要考虑内存效率和计算性能:

class KVCache:
    def __init__(
        self,
        max_batch_size: int,
        max_seq_length: int,
        num_heads: int,
        head_dim: int,
        dtype: torch.dtype = torch.float16
    ):
        """
        初始化KV缓存
        Args:
            max_batch_size: 最大批次大小
            max_seq_length: 最大序列长度
            num_heads: 注意力头数
            head_dim: 每个头的维度
            dtype: 数据类型
        """
        self.max_seq_length = max_seq_length
        
        # 初始化缓存张量
        self.k_cache = torch.zeros(
            max_batch_size,
            num_heads,
            max_seq_length,
            head_dim,
            dtype=dtype
        )
        self.v_cache = torch.zeros(
            max_batch_size,
            num_heads,
            max_seq_length,
            head_dim,
            dtype=dtype
        )
        
        # 记录当前序列长度
        self.current_length = 0
        
    def update(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        position: int
    ) -> None:
        """
        更新缓存
        Args:
            key: key状态 [batch_size, num_heads, seq_len, head_dim]
            value: value状态 [batch_size, num_heads, seq_len, head_dim]
            position: 起始位置
        """
        seq_len = key.shape[2]
        if position + seq_len > self.max_seq_length:
            raise ValueError(f"Position {position + seq_len} exceeds max_seq_length {self.max_seq_length}")
        
        # 更新缓存
        self.k_cache[:, :, position:position+seq_len] = key
        self.v_cache[:, :, position:position+seq_len] = value
        
        # 更新当前长度
        self.current_length = max(self.current_length, position + seq_len)
    
    def get_cached_kv(
        self,
        start_pos: int,
        end_pos: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        获取指定范围的缓存内容
        """
        return (
            self.k_cache[:, :, start_pos:end_pos],
            self.v_cache[:, :, start_pos:end_pos]
        )
    
    def clear(self) -> None:
        """清空缓存"""
        self.k_cache.zero_()
        self.v_cache.zero_()
        self.current_length = 0

4. 实际应用示例

让我们看一个完整的使用示例,展示如何处理长文本:

class LongContextProcessor:
    def __init__(
        self,
        model: LlamaModel,
        tokenizer,
        max_length: int = 16384,
        chunk_size: int = 1024
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.chunk_size = chunk_size
        
        # 初始化KV缓存
        self.kv_cache = KVCache(
            max_batch_size=1,
            max_seq_length=max_length,
            num_heads=model.config.num_attention_heads,
            head_dim=model.config.hidden_size // model.config.num_attention_heads
        )
    
    def process_long_text(self, text: str) -> torch.Tensor:
        """
        处理长文本输入
        Args:
            text: 输入文本
        Returns:
            处理后的隐藏状态
        """
        # 分词
        tokens = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=False
        ).input_ids
        
        # 清空KV缓存
        self.kv_cache.clear()
        
        # 分块处理
        all_hidden_states = []
        for i in range(0, tokens.size(1), self.chunk_size):
            # 获取当前块
            chunk = tokens[:, i:i+self.chunk_size]
            
            # 获取位置编码索引
            position_ids = torch.arange(
                i,
                i + chunk.size(1),
                dtype=torch.long,
                device=chunk.device
            ).unsqueeze(0)
            
            # 获取当前位置的缓存
            k_cache, v_cache = self.kv_cache.get_cached_kv(0, i)
            
            # 前向计算
            outputs = self.model(
                chunk,
                position_ids=position_ids,
                past_key_values=[(k_cache, v_cache)] * self.model.config.num_hidden_layers
            )
            
            # 更新缓存
            self.kv

http://www.kler.cn/a/524394.html

相关文章:

  • Vuex中的getter和mutation有什么区别
  • DeepSeek理解概率的能力
  • CMake常用命令指南(CMakeList.txt)
  • YOLOv10 介绍
  • 解读隐私保护工具 Fluidkey:如何畅游链上世界而不暴露地址?
  • 计算机网络__基础知识问答
  • 电路研究9.2.4——合宙Air780EP中MQTT 相关命令使用方法研究
  • 数仓ETL测试
  • 【华为OD-E卷 - 最长方连续方波信号 100分(python、java、c++、js、c)】
  • 【电工基础】2.低压带电作业定义,范围,工作要求,电工基本工具
  • CSS基础语法(全)
  • pytorch实现主成分分析 (PCA):用于数据降维和特征提取
  • 解决ImportError: cannot import name ‘notf‘
  • 虚幻基础10:isValid
  • go到底是什么意思:对go的猜测或断言
  • Clojure语言的系统运维
  • Deepseek的RL算法GRPO解读
  • PostgreSQL 数据备份与恢复:掌握 pg_dump 和 pg_restore 的最佳实践
  • 10.6.3 XML文件读写
  • Brave132 编译指南 Windows 篇:配置 Git(四)
  • 图论——最小生成树的扩展应用
  • 流浪动物救助微信小程序springboot+论文源码调试讲解
  • AI学习指南Ollama篇-Ollama性能优化与监控
  • JDK15主要特性
  • 算法-加油站问题
  • yolov11配置环境,实现OBB带方向目标检测