当前位置: 首页 > article >正文

【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE

在自然语言处理(NLP)领域,Transformer 模型已经成为主流。然而,Transformer 本身并不具备处理序列顺序的能力。为了让模型理解文本中词语的相对位置,我们需要引入位置编码(Positional Encoding)。本文将深入探讨 LLaMA 模型中使用的 Rotary Embedding(旋转式嵌入)位置编码方法,并对比传统的 Transformer 位置编码方案,分析其设计与实现的优势。

1. 传统 Transformer 的位置编码

1.1 正弦余弦编码

在原始的 Transformer 模型中,使用了基于正弦和余弦函数的位置编码。这种编码方式的公式如下:

PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

其中:

  • pos 代表词语在序列中的位置。
  • i 代表编码向量的维度索引。
  • d_model 是模型的维度大小。

这种编码方式的主要特点是:

  • 绝对位置编码: 为每个位置生成唯一的向量。
  • 易于泛化到更长的序列: 可以外推到训练期间未见过的序列长度。
  • 维度变化: 编码向量的每个维度上的频率都不同。

1.2 代码示例 (PyTorch)

import torch
import math

def positional_encoding(pos, d_model):
    pe = torch.zeros(1, d_model)
    for i in range(0, d_model, 2):
        pe[0, i] = math.sin(pos / (10000 ** (i / d_model)))
        pe[0, i + 1] = math.cos(pos / (10000 ** (i / d_model)))
    return pe

# 示例
d_model = 512
max_len = 10
pos_encodings = torch.stack([positional_encoding(i, d_model) for i in range(max_len)])

print("Position Encodings Shape:", pos_encodings.shape) # 输出: torch.Size([10, 1, 512])
print("First 3 position encodings:\n", pos_encodings[:3])

1.3 缺点

传统的正弦余弦位置编码虽然有效,但也有其局限性:

  • 缺乏相对位置信息: 尽管编码能提供绝对位置,但难以直接捕捉词语之间的相对距离关系。
  • 位置编码与输入向量独立: 位置编码是直接加到输入词向量上的,没有与词向量进行交互,信息损失比较明显。

2. LLaMA 的 Rotary Embedding (RoPE)

LLaMA 模型采用了 Rotary Embedding(RoPE),一种相对位置编码方法,它通过旋转的方式将位置信息嵌入到词向量中。RoPE 的核心思想是将位置信息编码为旋转矩阵,然后将词向量进行旋转,从而引入位置信息。

2.1 RoPE 的核心公式

RoPE 的核心公式如下:

RoPE(q, k, pos) = rotate(q, pos, Θ)

其中:

  • qk 分别代表查询向量和键向量。
  • pos 是两个向量之间的相对位置。
  • Θ 是一个旋转矩阵,根据 pos 和预定义的频率生成。
  • rotate(q, pos, Θ) 表示将 q 旋转 Θ 角度后的结果。

更具体来说,对于维度为 d 的向量 q,RoPE 将其分为 d/2 对 (q0, q1), (q2, q3) …, (qd-2, qd-1)。每个维度对应用不同的旋转角度。旋转矩阵 R 的定义是:

R(pos) =  [[cos(pos * θ_0), -sin(pos * θ_0)],
          [sin(pos * θ_0),  cos(pos * θ_0)]]  
          [[cos(pos * θ_1), -sin(pos * θ_1)],
          [sin(pos * θ_1),  cos(pos * θ_1)]]
          ...
          [[cos(pos * θ_d/2-1), -sin(pos * θ_d/2-1)],
          [sin(pos * θ_d/2-1),  cos(pos * θ_d/2-1)]]

其中 θ_i = 10000^(-2i/d) ,每个维度对的旋转角度不同。

将旋转矩阵应用于向量 q ,就是:
q_rotated = R(pos) * q

2.2 LLaMA 源码实现

下面是 LLaMA 中 RoPE 的核心代码(简化版,使用 PyTorch):

import torch
import math

def precompute_freqs(dim, end, theta=10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end)
    freqs = torch.outer(t, freqs)
    return torch.cat((freqs, freqs), dim=1)
    
def apply_rotary_emb(xq, xk, freqs):
    xq_complex = torch.complex(xq.float(), torch.roll(xq.float(), shifts=-xq.shape[-1]//2, dims=-1))
    xk_complex = torch.complex(xk.float(), torch.roll(xk.float(), shifts=-xk.shape[-1]//2, dims=-1))
    
    freqs_complex = torch.complex(torch.cos(freqs), torch.sin(freqs))
    
    xq_rotated = xq_complex * freqs_complex
    xk_rotated = xk_complex * freqs_complex

    return xq_rotated.real.type_as(xq), xk_rotated.real.type_as(xk)

# 示例
batch_size = 2
seq_len = 5
d_model = 512
head_dim = d_model//8
xq = torch.randn(batch_size, seq_len, 8, head_dim) # 输入查询向量
xk = torch.randn(batch_size, seq_len, 8, head_dim) # 输入键向量

freqs = precompute_freqs(head_dim, seq_len)
xq_rotated, xk_rotated  = apply_rotary_emb(xq, xk, freqs)
print("Rotated Query Shape:", xq_rotated.shape)
print("Rotated Key Shape:", xk_rotated.shape)

代码解释

  1. precompute_freqs(dim, end, theta):
    • 此函数用于预计算旋转矩阵中使用的频率。
    • dim: 表示词向量维度。
    • end: 表示最大序列长度。
    • 返回包含所有位置的频率列表。
  2. apply_rotary_emb(xq, xk, freqs):
    • 函数将旋转操作应用于查询向量 xq 和键向量 xk
    • 通过 complex 表示实数向量的旋转,并使用复数乘法完成旋转操作。
    • 使用 torch.roll() 函数将 xq 分成实部和虚部,使用complex类型可以更快的完成旋转计算,避免了循环遍历,提高计算速度。
    • 使用复数乘法完成旋转,通过 .real 属性取出旋转后的实部,并将类型转换回原始类型

2.3 RoPE 的优势

与传统的正弦余弦位置编码相比,RoPE 具有以下优势:

  1. 相对位置编码: RoPE 专注于编码词语之间的相对位置信息,而不仅仅是绝对位置。通过向量旋转,使得向量之间的相对位置信息更直观。
  2. 高效计算: 通过使用复数乘法,RoPE 可以在GPU上进行高效的并行计算。
  3. 良好的外推能力: RoPE 可以比较容易地推广到训练期间未见过的序列长度,并且性能保持稳定。
  4. 可解释性: RoPE 的旋转操作使其相对位置信息具有更强的可解释性,有助于理解模型的行为。

3. 总结

本文详细介绍了 LLaMA 模型中使用的 Rotary Embedding 位置编码方法。通过源码分析和对比传统的位置编码,我们了解了 RoPE 的核心原理和优势。RoPE 通过旋转操作高效地编码相对位置信息,为 LLaMA 模型的强大性能提供了重要的基础。希望本文能帮助你更深入地理解 Transformer 模型中的位置编码机制。

4. 参考资料

  • RoFormer: Enhanced Transformer with Rotary Position Embedding
  • Attention is All You Need

http://www.kler.cn/a/524675.html

相关文章:

  • 用HTML、CSS和JavaScript实现庆祝2025蛇年大吉(附源码)
  • 分享| RL-GPT 框架通过慢agent和快agent结合提高AI解决复杂任务的能力-Arxiv
  • C++ unordered_map和unordered_set的使用,哈希表的实现
  • SQL教程-基础语法
  • Time Constant | RC、RL 和 RLC 电路中的时间常数
  • postgres基准测试工具pgbench如何使用自定义的表结构和自定义sql
  • Origami Agents:AI驱动的销售研究工具,助力B2B销售团队高效增长
  • 火出圈的DeepSeeK R1详解
  • AI大模型开发原理篇-2:语言模型雏形之词袋模型
  • Baklib在知识管理创新中的价值体现与其他产品的优势比较分析
  • 0小明的数组游戏
  • Java基础面试题总结(题目来源JavaGuide)
  • 曲线救国——uniapp封装toast消息提示组件(js)
  • 什么是长短期记忆网络?
  • JVM_类的加载、链接、初始化、卸载、主动使用、被动使用
  • STM32标准库移植RT-Thread nano
  • OceanBase 读写分离探讨
  • WPS数据分析000008
  • Linux---架构概览
  • 27.useFetch
  • unity学习22:Application类其他功能
  • rust操作pgsql、mysql和sqlite
  • ResNeSt-2020笔记
  • 【愚公系列】《循序渐进Vue.js 3.x前端开发实践》033-响应式编程的原理及在Vue中的应用
  • P10638 BZOJ4355 Play with sequence Solution
  • 前端实战:小程序搭建商品购物全流程