LLaMA详解
LLaMA 进化史
大规模语言模型(Large Language Model, LLM)的快速发展正在以前所未有的速度推动人工智能(AI)技术的进步。
作为这一领域的先行者, Meta在其LLaMA(Large Language Model Meta AI)系列模型上取得了一系列重大突破。
近日, Meta官方正式宣布推出LLaMA-3, 作为继LLaMA-1、LLaMA-2和Code-LLaMA之后的第三代旗舰模型。
LLaMA-3在多个权威基准测试中均取得了全面领先的成绩, 超越了业界同类最先进模型, 再次刷新了LLM的性能上限。
LLaMA-1 系列
LLaMA-1是Meta公司在2023年2月发布的首款大规模语言模型,它提供了7B、13B、30B和65B四种不同的参数规模版本。
这些版本均在超过1T(万亿)个tokens的广泛语料库上进行了预训练,有效利用了海量数据的优势。
在这些版本中,参数量最大的65B版本在2048张A100 80G GPU上训练了21天,表现出色,在多数基准测试中超越了具有175B参数的GPT-3。
除了卓越的性能,LLaMA-1还采取了开源策略,极大地方便了研究者和开发者的获取与使用,迅速成为开源社区中极受欢迎的大规模语言模型之一。
围绕LLaMA-1,形成了一个活跃的生态系统,涌现了许多基于该模型的下游应用、微调模型和衍生变体。
LLaMA-2 系列
LLaMA-2 系列,Meta在2023年7月推出的第二代大规模语言模型,是对LLaMA-1系列功能的扩展和改进。
此系列包括7B、13B和70B三个不同规模的版本,并特别推出了经过指令微调的对话模型LLaMA-2-Chat,旨在增强模型在对话和任务执行方面的能力。
主要特点包括:
- 增强的模型性能: 相较于LLaMA-1,LLaMA-2系列在更广泛的语料上进行了预训练,这大幅增强了其文本生成和理解能力,确保模型在各类NLP任务中的领先性能。
- 指令微调技术: LLaMA-2-Chat通过对大量的高质量指令-回答对进行微调,优化了模型对复杂指令的响应能力,使其在对话和具体任务执行中表现更为精确和自然。
- 开源政策: 继承LLaMA-1的开放精神,LLaMA-2以开源形式发布,不仅丰富了LLaMA的生态系统,也激励了基于这一平台的创新性研究和应用开发。
通过这些改进,LLaMA-2系列不仅在技术层面上提供了显著的升级,也在开放源代码和社区参与方面持续发挥着积极的推动作用。此举确保了LLaMA系列在全球语言模型领域的持续领先和影响力扩展。
LLaMA-3 系列
2024年4月, Meta 发布了最新一代的大规模语言模型系列 LLaMA-3。
作为 LLaMA 系列的第三代产品, LLaMA-3 在 LLaMA-1、LLaMA-2 和 Code-LLaMA 的基础上实现了全面升级和优化。
在多个权威基准测试中, LLaMA-3 的表现全面领先于业界同类最先进模型,再次刷新了大模型的性能上限。
LLaMA-3 在长序列建模能力上取得了重大突破。
相比 LLaMA-2 只能处理最长 2048 个 token 的文本, LLaMA-3 将这一上限提高到了 8192 个 token, 使其可以轻松应对超长文档、多轮对话等复杂场景。
同时, LLaMA-3 还采用了全新的 tokenizer, 将分词器更换为tiktoken,与GPT4保持一致。
其词表大小达到了 128K, 远超 LLaMA-2 的 32K。更大的词表意味着更精细的文本表示和更强的语言泛化能力。
注:
- Llama-1模型采用了BPE算法进行分词,使用sentencepiece实现。其词表大小为32k。
- Llama-2维持了与Llama-1相同的架构和分词器,但将上下文长度扩展到了4k。
数据规模是大模型取得突破性进展的另一个关键因素。
LLaMA-3 的预训练语料规模超过了 15 T (150 万亿) token, 是 LLaMA-2 的 7 倍多。
海量高质量的预训练数据, 为 LLaMA-3 提供了更全面、更深入的世界知识和语言理解能力。
同时, Meta 还对预训练数据进行了更细粒度的筛选和清洗, 进一步提高了数据质量。
凭借优秀的模型架构和海量的预训练数据, LLaMA-3 在下游任务上实现了全面领先。
在自然语言推理、机器阅读理解、常识问答等多个基准测试中, LLaMA-3 的表现都超越了业界最先进的同规模模型。
此外, LLaMA-3 在推理能力、代码生成、指令跟随、多语言支持等方面也有长足进步, 使其成为更加全能和可控的通用人工智能助手。
随着 LLaMA-3 的发布, Meta 再次引领了大模型技术的发展潮流。
作为当前最先进的开源大模型, LLaMA-3 必将掀起新一轮的研究和应用热潮, 为人工智能的进步注入强大动力。
LLaMA-3 的基础版模型和指令微调模型
Meta发布了LLaMA 3模型的几个变体:
-
meta-llama/Meta-LLama-3-8B: 这是基础版的LLaMA 3模型, 有80亿个参数。
-
meta-llama/Meta-LLama-3-8B-Instruct: 这是一个经过指令微调的8B参数LLaMA 3模型变体。
-
meta-llama/Meta-LLama-3-70B-Instruct: 这是一个更大的经过指令微调的LLaMA 3模型, 有700亿个参数。
-
meta-llama/Meta-LLama-3-70B: 这是没有经过指令微调的基础版70B参数LLaMA 3模型。
总的来说, Meta发布了8B和70B两种参数规模的LLaMA 3模型, 每种规模都有提供经过指令微调的变体。
指令微调(Instruction Tuning)的变体和基础版模型之间有以下几个主要区别:
- 训练目标不同: 基础版模型通常使用语言建模(Language Modeling)作为训练目标, 即根据上文预测下一个单词。而指令微调的变体则引入了额外的指令数据, 通过监督学习让模型学会理解和执行自然语言指令。
- 训练数据不同: 基础版模型主要在大规模无标签文本语料上进行预训练。而指令微调的变体需要构建专门的指令数据集, 其中包含大量的自然语言指令及其对应的执行结果。
- 应用场景不同: 基础版模型可以用于各种NLP任务, 但在执行具体指令时可能表现欠佳。指令微调的变体则专门针对指令理解和执行进行了优化, 在问答、对话、任务完成等场景中表现更加出色。
- 交互方式不同: 使用基础版模型时, 用户需要根据具体任务设计prompts。而指令微调的变体允许用户使用自然语言直接下达指令, 交互更加直观和方便。
- 可控性不同: 基础版模型生成的内容可能不够可控, 容易出现幻觉或不合适的言论。指令微调引入了人类反馈, 可以更好地引导模型生成安全、可靠的内容。
- 推理效率不同: 指令微调通常会引入一些控制机制, 如提示工程、示例学习等, 可能降低推理速度。但一些高效的微调方法如LoRA、Prefix Tuning等可以在保证性能的同时加快推理。
总的来说, 指令微调赋予了语言模型更强的指令理解和执行能力, 使其在实际应用中更加智能、高效和可控。
LLaMA衍生模型和生态
总览
LLaMA系列模型的开源性和可访问性, 使其成为了LLM领域的重要研究和应用平台。
围绕LLaMA, 涌现出了一个繁荣的开源生态, 催生了如 Alpaca、Vicuna、LLaVA 等一系列优秀的衍生模型。
这些模型在LLaMA的基础上,结合领域数据和下游任务进行针对性优化, 在对话、多模态、开放域问答等方面取得了瞩目成绩。
LLaMA的开源, 极大地降低了LLM研究的门槛, 推动了AI技术的普惠进程。
图中展示的是LLaMA模型及其衍生模型的概览,包括了不同的训练方法、数据类型、以及一些特定的应用场景。
以下是对图中内容的描述和解释:
-
**继续预训练(Continue pre-training):**指的是在现有模型基础上,使用更多数据继续训练以提升模型性能的过程。
-
**LLaMA:**是Meta公司开发的一个大型语言模型,图中显示了基于LLaMA模型的多种扩展和优化路径。
-
**参数高效微调(Parameter-efficient fine-tuning):**这是一种微调技术,通过调整模型中较少的参数来适应新的任务,以提高效率。
-
**模型继承(Model inheritance):**指新模型继承或基于旧模型的架构和参数进行开发。
-
**指令微调(Instruction tuning):**通过指令来指导模型微调,使其更好地遵循给定的任务指令。
-
**全参数微调(Full parameter fine-tuning):**指在微调过程中调整模型的所有参数。
-
**数据继承(Data inheritance):**新模型使用旧模型训练过的数据集作为训练数据的一部分。
-
**中文数据(Chinese data):**指使用了中文语言数据进行训练的模型。
-
**Open-Chinese-LLaMA:**可能是一个针对中文优化的LLaMA模型。
-
**合成数据(synthetic data):**指通过技术手段生成的非真实世界数据,用于训练。
-
**Vicuna, Panda, Alpaca, Goat:**基于LLaMA模型的不同变种或特定用途的模型名称。
-
**RLHF:**指的是通过人类反馈进行强化学习的微调方法(Reinforcement Learning from Human Feedback)。
-
**Yulan-Chat:**专门为中文对话优化的模型。
-
**PKU-Beaver, BiLLa:**指北京大学开发的模型或者是双语语言模型。
-
**Cornucopia:**指包含多种数据类型的综合数据集。
-
**Lawyer, LLaVA, [BELLE]:**针对特定领域(如法律)或具有特定功能(如视觉语言模型)的模型。
-
**MiniGPT-4, Ziya, QiZhenGPT, Baize:**不同大小或针对特定任务优化的模型。
-
**Guanaco, Chatbridge, Koala:**特定的多模态模型或其他应用领域的模型。
-
**VisionLLM, TaoLi, InstructBLIP:**结合了视觉和语言任务的模型。
-
**ChatMed, Adapter, PandaGPT, LAWGPT, BenTsao:**针对医疗、适配器技术、法律等特定领域的模型。
-
**多模态模型(Multimodal models):**能够处理并整合来自多种感官模式(如视觉、听觉、文本)的模型。
-
**数学(Math)、金融(Finance)、医学(Medicine)、法律(Law)、双语(Bilingualism)、教育(Education):**模型应用的领域。
这张图展示了以 LLaMA 系列模型为核心的大语言模型生态系统。
- 继续预训练(Continue pre-training): 一些模型如 Chinese LLaMA、Chinese Alpaca 等在 LLaMA 的基础上加入了中文数据继续预训练,以提高中文任务的表现。
- 指令微调(Instruction tuning): 通过在指令数据集上微调 LLaMA, 衍生出了 Alpaca、Vicuna 等模型, 使其能够执行指令跟随和问答对话等任务。
- 参数高效微调(Parameter-efficient fine-tuning): 使用 LoRA、Prefix Tuning 等参数高效微调方法在下游任务数据上微调 LLaMA,得到 Alpaca Lora 等模型。
- 全参数微调(Full parameter fine-tuning): 在特定垂直领域数据上对 LLaMA 进行全参数微调, 如在医疗对话数据上微调得到的 BianTsao 模型。
- 多模态模型(Multimodal models): 将 LLaMA 扩展到多模态, 如支持图像输入的 LLaVA、MinGPT-4, 语音交互的 InstructBLIP 等。
总的来说, 该图全面地展示了以 LLaMA 为基础衍生出的丰富多样的大模型生态,涵盖了主要的优化训练范式、任务类型和具体模型。
值得注意的是, 中文模型在该生态中占据了重要地位。
Alpaca
Alpaca 是由斯坦福大学计算机科学系博士生 Eric Wang 等人开发的一个基于 LLaMA-7B 模型的衍生大模型。
其主要特点包括:
- 指令精调: Alpaca 在 LLaMA-7B 的基础上, 使用了一个包含 5.2 万条指令数据的数据集进行了监督微调(Supervised Fine-tuning)。这使得 Alpaca 能够很好地理解和执行自然语言指令, 具备类似ChatGPT 的对话交互能力。
- 开源共享: Alpaca 项目的代码和训练数据都已经在 GitHub 上完全开源, 允许研究者和开发者基于此进行二次开发。这极大地降低了构建指令跟随型对话系统的门槛。
- 性能优异: 在标准的指令跟随任务基准如 MMLU 上, Alpaca 的表现已经接近 ChatGPT 等封闭模型, 而参数量和计算开销却小很多。这说明了在 LLaMA 基础上进行指令精调的有效性。
- 多语言支持: 得益于 LLaMA 模型本身强大的多语言能力, Alpaca 也具备了一定的多语言处理能力, 尽管主要还是针对英文进行了优化。
- 可控性强: 由于 Alpaca 的训练数据是人工标注的高质量指令数据, 因此其生成的内容更加可控, 在事实性、安全性方面表现出色。
- 开源生态: Alpaca 的开源进一步推动了 LLaMA 周边生态的繁荣, 催生了一系列基于 Alpaca 的衍生模型和应用。
总的来说, Alpaca 是 LLaMA 家族中一个代表性的指令精调模型, 其开源性、可访问性和优异的性能, 使其成为了开源界 ChatGPT 的有力竞争者。
Alpaca 的成功证明了在一个强大的基础模型上, 利用高质量的指令数据进行针对性微调, 可以显著提升模型在对话交互任务上的表现, 同时还能保持较强的可控性。
Vicuna
Vicuna 是由加州大学伯克利分校、卡内基梅隆大学、斯坦福大学等机构的研究者联合开发的一个基于LLaMA 的开源对话语言模型。
其主要特点包括:
- 大规模指令精调: Vicuna 使用了一个包含 7 万多条对话数据的指令数据集对 LLaMA-13B 进行了微调。这些数据主要来自于 ShareGPT 收集的真实人类对话,质量相当高。相比 Alpaca 的 5 万条指令数据,Vicuna 的训练语料更加丰富和多样化。
- 多轮对话能力: 得益于大规模高质量对话数据的训练, Vicuna 具备了出色的多轮对话能力。它能够很好地理解对话的上下文, 根据之前的对话内容生成连贯且相关的回复。在这一点上, Vicuna 比Alpaca 表现得更为出色。
- 训练计算效率: Vicuna 在训练过程中采用了一系列优化手段, 如混合精度训练、梯度累积、DeepSpeed 等, 使得在有限的计算资源下也能高效地完成大模型的训练。这为开源社区提供了一个很好的模型训练范例。
- 开源共享: 与 Alpaca 类似, Vicuna 的代码和训练数据也已经完全开源, 供社区使用和研究。同时, 研究团队还发布了经过训练的检查点(checkpoint), 可以直接进行推理和微调。
- 人类对齐度高: 通过在大规模人类对话数据上的训练, Vicuna 学会了更加自然、人性化的交互方式。其生成的回复在流畅度、连贯性、同理心等方面都有更好的表现,给人一种更加亲切、自然的感受。
- 商业化应用: 进一步扩大了 Vicuna 的影响力和应用范围。
总的来说, Vicuna 代表了开源对话大模型的最新进展。其充分利用了大规模高质量的对话数据,在 LLaMA 的基础上实现了全面的性能提升,尤其是在多轮对话和人类对齐度方面表现突出。
Vicuna 的开源和商业应用,为构建更加智能、自然的对话系统提供了重要参考。
LLaVA
LLaVA,全称为 Large Language and Vision Assistant(大型语言和视觉助手),是一种新型的大型多模态模型。
它的目标是开发一种通用视觉助手,能够遵循语言和图像指令来完成各种现实世界任务。
LLaVA 结合了自然语言处理(NLP)和计算机视觉(CV)的能力,通过理解视觉内容并根据语言指令进行操作,从而实现对图像和文本的深入理解与交互。
LLaVA 模型的核心在于其多模态架构,它将视觉编码器(如基于 Transformer 的视觉模型)与语言模型(如 LLaMA)结合起来,形成一个能够处理图文信息的集成系统。这种设计使得 LLaVA 能够执行包括图像标注、视觉问答、文本到图像生成等在内的多模态任务。
其主要特点包括:
- 多模态融合: LLaVA 在 LLaMA 的基础上引入了视觉特征,实现了语言和视觉信息的融合。具体来说, 它使用 BLIP-2 模型提取图像特征,然后将其与文本表示进行交互, 生成与图像相关的自然语言描述或回答。
- 图像理解能力: 得益于多模态融合, LLaVA 具备了强大的图像理解和分析能力。它可以根据输入的图像生成详细的描述,回答与图像相关的问题,甚至进行开放式的图像内容分析。
- 通用语言能力: 作为 LLaMA 的衍生物,LLaVA 继承了原模型优秀的自然语言处理能力。因此它不仅可以处理图像相关的任务,在通用的语言理解、生成等方面也有不俗的表现。
- 零样本学习: LLaVA 支持零样本学习(zero-shot learning), 即无需在下游任务上进行微调, 直接利用预训练的知识完成推理。这使得 LLaVA 可以灵活地应对各种形式的多模态任务, 无需重新训练模型。
- 大规模预训练: LLaVA 在大规模图文对数据上进行了预训练, 学习了丰富的视觉-语言对齐知识。这为其在下游任务上的优异表现奠定了基础。同时, 预训练也提高了模型的泛化能力和鲁棒性。
- 开源开放: 与其他 LLaMA 衍生物类似,LLaVA 的代码和预训练模型权重都已经开源。这为进一步研究和应用多模态大模型提供了宝贵的资源和参考。
总的来说, LLaVA 代表了 LLaMA 在多模态领域的重要拓展。
小结
通过引入视觉特征与语言表示的交互, LLaVA 实现了对图像内容的深度理解和分析。
同时, 它在通用语言任务上的出色表现, 证明了多模态学习对于提升语言模型性能的积极作用。
可以预见, LLaVA 将为多模态大模型的研究和应用开辟新的方向, 推动人工智能向更加全面、贴近人类智能的方向发展。
Llama3原理解读
Llama 3 架构图
Llama 3 模型架构
Llama 3模型基于标准的Transformer架构进行了多项改进,包括更高的效率和更好的性能。
接下来,我们详细探讨Llama 3架构的主要特点。
模型概述
-
模型类型:基于解码器的Transformer
-
参数量:8B和70B
-
上下文长度:8192 (LLaMA-1和LLaMA-2的上下文长度分别为2048,4096)
-
注意力机制:分组查询注意力 (Grouped Query Attention, GQA)
结构组成
Llama 3的模型架构主要包括以下组件:
分层堆叠
-
嵌入层:将输入的token转换为固定维度的嵌入表示。
-
自注意力层:包含多头自注意力机制和归一化。
-
前馈网络 (FFN) : 包含激活函数和两层全连接网络。
-
位置编码:采用RoPE(Rotary Position Embedding)位置编码。
Llama 3与传统Transformer架构的比较
相同之处
-
基本结构:两者都基于标准的Transformer解码器架构。
-
多头自注意力:都使用多头自注意力机制。
-
前馈网络 (FFN): 包含两层全连接网络。
不同之处
- 注意力机制:
- 传统Transformer: 使用标准多头自注意力机制。
- Llama 3:引入了分组查询注意力(GQA)以提高效率。
- 激活函数:
- 传统Transformer: 使用ReLU或GELU激活函数。
- LIama 3:使用更高效的SwiGLU激活函数。
- 位置编码:
- 传统Transformer: 使用正弦-余弦位置编码或Learned Positional Embedding。
- LIama 3:使用RoPE(Rotary Position Embedding)位置编码。
- 上下文长度:
- 传统Transformer: 通常为512或1024的上下文长度。
- Llama 3: 上下文长度增加到8192。
LLama3 模型组件详解
RMSNorm
归一化技术是一种在神经网络中用于归一化激活值的技术。
它可以提高训练的稳定性,加快收敛速度并改善模型的泛化性能。
下面是Transformer架构中常见的两种归一化技术:
层归一化 (Layer Normalization)
层归一化是Transformer架构中最早引入的归一化技术之一。
它将一个输入序列的所有特征归一化,使得每个样本的激活值在特征维度上具有均值为0、方差为1的分布。其公式如下:
x
^
i
=
x
i
−
μ
σ
+
ϵ
{\widehat{x}}_{i} = \frac{{x}_{i} - \mu }{\sigma + \epsilon}
x
i=σ+ϵxi−μ
其中,
μ
\mu
μ 是均值,
σ
\sigma
σ是标准差,
ϵ
\epsilon
ϵ是一个很小的数以避免除以零。
RMSNorm
RMSNorm是一种简化版的层归一化,它通过均方根(Root Mean Square)计算得到归一化尺度,不需要计算均值和标准差。
其公式如下:
x ^ i = x i 1 n ∑ j = 1 n x j 2 + ϵ {\widehat{x}}_{i} = \frac{{x}_{i}}{\sqrt{\frac{1}{n}\mathop{\sum }\limits_{{j = 1}}^{n}{x}_{j}^{2} + \epsilon }} x i=n1j=1∑nxj2+ϵxi
与层归一化相比, RMSNorm减少了计算复杂度,同时保持了良好的归一化效果。
RMSNorm在LLaMA中的应用
在LLaMA模型中, RMSNorm主要用于对每个时间步的隐藏状态向量进行归一化。
每个时间步的输入向量被独立归一化,适用于处理变长序列和保持时间步间的独立性。
具体来说, RMSNorm在Transformer架构中的应用通常在以下几个位置:
- 多头自注意力机制的输出:
- 对于每个时间步的隐藏状态,在经过多头自注意力机制后的输出进行归一化。
- 前馈神经网络的输出:
- 对于每个时间步的隐藏状态,在经过前馈神经网络 (Feed-Forward Neural Network, FFN) 后的输出进行归一化。
为什么选择RMSNorm?
在LLaMA中,选择RMSNorm作为主要的归一化策略,原因如下:
-
**提高训练效率:**RMSNorm在计算上比层归一化更高效。在大型语言模型中,效率的提升尤其显著。
-
**提高训练稳定性和泛化性能:**研究发现, RMSNorm可以提供更稳定的训练过程,减少梯度爆炸或消失的风险,同时保持较高的泛化能力。
-
**预归一化的优势:**LLaMA采用了RMSNorm的预归一化变体,将归一化操作放在Transformer模块的主要层之前,而不是之后。
这一策略进一步提升了模型的训练稳定性。
RMSNorm vs LayerNorm的对比
LLaMA中的实验结果显示, RMSNorm在效率和性能方面都有明显的提升:
-
效率提升:与层归一化相比, RMSNorm在效率上提升了10-50%。
-
性能对比:在实际应用中, RMSNorm在稳定性和泛化性能上与层归一化不相上下,有时甚至更优。
SwiGLU激活函数
激活函数在大型语言模型(LLM)的发展过程中扮演着关键角色,它影响着模型性能的稳定性与泛化能力。
LLaMA3在其前馈层中采用了一种特殊的激活函数——SwiGLU。
激活函数是深度神经网络中的关键组件,负责引入非线性,使模型能够学习复杂的数据模式。
常见的激活函数包括:
ReLU (Rectified Linear Unit)
ReLU是目前使用最广泛的激活函数之一。其计算公式如下:
ReLU
(
x
)
=
max
(
0
,
x
)
\operatorname{ReLU}\left( x\right) = \max \left( {0, x}\right)
ReLU(x)=max(0,x)
特点:
-
计算简单,梯度为0或1。
-
可能导致“神经元死亡”问题。
GELU (Gaussian Error Linear Unit)
GELU是一种基于高斯分布的激活函数,在GPT-3等模型中使用。
其公式如下:
GELU
(
x
)
=
x
⋅
0.5
⋅
(
1
+
erf
(
x
2
)
)
\operatorname{GELU}\left( x\right) = x \cdot {0.5} \cdot \left( {1 + \operatorname{erf}\left( \frac{x}{\sqrt{2}}\right) }\right)
GELU(x)=x⋅0.5⋅(1+erf(2x))
其中, erf为误差函数。
特点:
- 与ReLU相比, GELU具有平滑的激活曲线,能提供更高的性能。
Swish
Swish是一种平滑且连续的激活函数,在Transformer等模型中应用广泛。其公式如下:
Swish ( x ) = x ⋅ σ ( x ) \operatorname{Swish}\left( x\right) = x \cdot \sigma \left( x\right) Swish(x)=x⋅σ(x)
其中, σ ( x ) = 1 1 + e − x \sigma \left( x\right) = \frac{1}{1 + {e}^{-x}} σ(x)=1+e−x1 为标准Sigmoid函数。
特点:
- 平滑的曲线有助于稳定梯度流。
- 通常比ReLU表现更好。
GLU (Gated Linear Unit)
GLU(Gated Linear Unit)是一种用于神经网络中的激活函数,最初由 Yann Dauphin 等人在论文《Language Modeling with Gated Convolutional Networks》里提出。
GLU 的设计灵感来自门控机制,它通过引入门控操作来控制信息的流动。
该操作可以看作是引入了一种动态的选择机制,以在模型中选择性地传递信息。
GLU 的计算公式如下:
GLU
(
X
)
=
(
X
⋅
W
1
+
b
1
)
⊙
σ
(
X
⋅
W
2
+
b
2
)
\operatorname{GLU}\left( X\right) = \left( {X \cdot {W}_{1} + {b}_{1}}\right) \odot \sigma \left( {X \cdot {W}_{2} + {b}_{2}}\right)
GLU(X)=(X⋅W1+b1)⊙σ(X⋅W2+b2)
其中:
-
X X X 是输入特征。
-
W 1 , W 2 {W}_{1},{W}_{2} W1,W2 是两个不同的权重矩阵。
-
b 1 , b 2 {b}_{1},{b}_{2} b1,b2 是偏置。
-
σ \sigma σ 是 Sigmoid 函数,作为门控信号的激活函数。
-
⊙ \odot ⊙ 表示元素级别的乘法(Hadamard 积)。
在 GLU 中,输入特征 ( X ) 被分成两个部分:
-
线性部分: X ⋅ W 1 + b 1 X \cdot {W}_{1} + {b}_{1} X⋅W1+b1
-
门控部分: σ ( X ⋅ W 2 + b 2 ) \sigma \left( {X \cdot {W}_{2} + {b}_{2}}\right) σ(X⋅W2+b2)
门控部分使用 Sigmoid 函数输出一个介于 0 和 1 之间的权重矩阵,从而选择性地让部分信息通过。
因此, GLU 实际上是在学习一种信息过滤机制,以根据输入数据的特征动态调整信息的流动。
优势
-
**动态信息选择:**门控机制允许模型根据输入特征选择性地传递信息,提高模型的灵活性。
-
**性能提升:**GLU 在自然语言处理任务中证明可以比传统激活函数(如 ReLU)获得更好的性能。
SwiGLU激活函数
SwiGLU(Swish-Gated Linear Unit)是一种结合Swish激活函数与GLU(Gated Linear Unit)机制的激活函数。其公式如下:
SwiGLU
(
x
)
=
(
Swish
(
W
1
x
)
)
⋅
(
W
2
x
)
\operatorname{SwiGLU}\left( x\right) = \left( {\operatorname{Swish}\left( {{W}_{1}x}\right) }\right) \cdot \left( {{W}_{2}x}\right)
SwiGLU(x)=(Swish(W1x))⋅(W2x)
其中:
- ${W}_{1} $ 和 W 2 {W}_{2} W2 分别是输入 X X X 的两个线性变换矩阵。
- $Swish \left( z\right) = z \cdot \sigma \left( z\right) $,其中 σ ( z ) \sigma \left( z\right) σ(z) 是标准的Sigmoid函数。
该公式可以理解为:
-
对输入 $ x $ 进行两次线性变换。
-
将其中一个结果通过Swish激活函数。
-
将两个结果逐元素相乘,形成最终输出。
对于 β = 1 \beta=1 β=1 的 Swish函数,称之为 SiLU
SwiGLU的计算代价与性能优势
class SwiGLU(nn.Module):
def __init__(self, w1, w2, w3) -> None:
super().__init__()
self.w1 = w1
self.w2 = w2
self.w3 = w3
def forward(self, x):
x1 = F.linear(x, self.w1.weight)
x2 = F.linear(x, self.w2.weight)
hidden = F.silu(x1) * x2
return F.linear(hidden, self.w3.weight)
SwiGLU激活函数需要进行三次矩阵乘法运算,相比于ReLU等传统激活函数计算复杂度更高。
然而,研究发现,尽管计算量增加, SwiGLU带来的性能提升显著:
-
**更好的泛化性能:**SwiGLU在处理复杂的文本数据时表现出更好的泛化能力。
-
**稳定的梯度流:**Swish激活函数平滑的曲线特性有助于稳定梯度,减少梯度消失和爆炸的可能性。
RoPE旋转位置编码
位置编码是Transformer中确保模型能够理解序列顺序信息的重要部分。
传统的绝对和相对位置编码方案各有优缺点,然而RoPE(Rotary Position Embedding)作为一种新型的位置编码方法,平衡了绝对和相对位置编码的优点。
位置编码的背景
位置编码在Transformer中至关重要,因为自注意力机制本质上是无序的。 常见的两种位置编码方法是:
-
绝对位置编码 (Absolute Positional Encoding):为序列中的每个位置提供一个固定的嵌入。
-
相对位置编码 (Relative Positional Encoding) : 表示序列中每两个token之间的相对位置信息。
绝对位置编码的特性
绝对位置编码是指在序列中的每个位置直接关联一个固定的嵌入。
例如, Transformer中常用的正弦-余弦绝对位置编码如下:
P
E
(
p
o
s
,
2
i
)
=
sin
(
p
o
s
10000
2
i
/
d
)
P
E
(
p
o
s
,
2
i
+
1
)
=
cos
(
p
o
s
10000
2
i
/
d
)
P{E}_{\left( pos,2i\right) } = \sin \left( \frac{pos}{{10000}^{{2i}/d}}\right) \\ P{E}_{\left( pos,2i + 1\right) } = \cos \left( \frac{pos}{{10000}^{{2i}/d}}\right)
PE(pos,2i)=sin(100002i/dpos)PE(pos,2i+1)=cos(100002i/dpos)
其中, pos 表示序列中的位置, d 表示嵌入维度。
这种方法通过给定的位置索引为每个位置提供唯一的编码,确保模型能够理解token之间的顺序。
相对位置编码的特性
相对位置编码关注的是序列中两个token之间的相对距离,而非绝对位置。
相对位置编码可以帮助模型在不同长度的输入序列之间共享信息,具体实现方式多种多样。例如,在 Transformer-XL模型中:
A
i
j
=
Q
i
⋅
K
j
+
Q
i
⋅
r
i
−
j
{A}_{ij} = {Q}_{i} \cdot {K}_{j} + {Q}_{i} \cdot {r}_{i - j}
Aij=Qi⋅Kj+Qi⋅ri−j
其中,
r
i
−
j
{r}_{i - j}
ri−j 表示相对距离的嵌入,确保模型对不同长度的输入都能够表现良好。
RoPE的原理
RoPE (Rotary Position Embedding) 结合了绝对和相对位置编码的优点。
它使用旋转矩阵对每个位置进行编码,并直接将相对位置信息引入自注意力操作中。
旋转矩阵(Rotation Matrix)
RoPE的数学公式
假设输入向量为 ( x \in {\mathbb{R}}^{d} ) ,其第 ( i ) 个位置的编码向量可以表示为:
x i ′ = RoPE ( x i , i ) {x}_{i}^{\prime } = \operatorname{RoPE}\left( {{x}_{i}, i}\right) xi′=RoPE(xi,i)
其中:
RoPE
(
x
i
,
i
)
=
R
(
i
)
x
i
\operatorname{RoPE}\left( {{x}_{i}, i}\right) = R\left( i\right) {x}_{i}
RoPE(xi,i)=R(i)xi
这里的旋转矩阵
R
(
i
)
R\left( i\right)
R(i) 用二维旋转矩阵的张量乘积进行定义。
将向量 x i x_i xi 拆分为一系列长度为2的子向量 ( x i , 2 k , x i , 2 k + 1 ) \left( {{x}_{i,{2k}},{x}_{i,{2k} + 1}}\right) (xi,2k,xi,2k+1),
其旋转形式为:
R
(
i
)
[
x
i
,
2
k
x
i
,
2
k
+
1
]
=
[
cos
(
θ
k
i
)
−
sin
(
θ
k
i
)
sin
(
θ
k
i
)
cos
(
θ
k
i
)
]
[
x
i
,
2
k
x
i
,
2
k
+
1
]
R\left( i\right) \left\lbrack \begin{matrix} {x}_{i,{2k}} \\ {x}_{i,{2k} + 1} \end{matrix}\right\rbrack = \left\lbrack \begin{matrix} \cos \left( {{\theta }_{k}i}\right) & - \sin \left( {{\theta }_{k}i}\right) \\ \sin \left( {{\theta }_{k}i}\right) & \cos \left( {{\theta }_{k}i}\right) \end{matrix}\right\rbrack \left\lbrack \begin{matrix} {x}_{i,{2k}} \\ {x}_{i,{2k} + 1} \end{matrix}\right\rbrack
R(i)[xi,2kxi,2k+1]=[cos(θki)sin(θki)−sin(θki)cos(θki)][xi,2kxi,2k+1]
其中:
θ k = 10000 − 2 k / d π {\theta }_{k} = \frac{{10000}^{-{2k}/d}}{\pi } θk=π10000−2k/d
经过RoPE编码后的输入向量与旋转矩阵结合,使得位置信息被直接嵌入到输入向量中。
RoPE的相对位置编码特性
在自注意力计算中, RoPE能够引入相对位置信息。
令 ${q}_{i} $ 和 k j {k}_{j} kj分别表示位置 i i i 和 j j j 处的查询和键向量,经过RoPE编码后,其点积为:
⟨
q
i
′
,
k
j
′
⟩
=
⟨
R
(
i
)
q
i
,
R
(
j
)
k
j
⟩
=
⟨
q
i
,
R
(
j
−
i
)
k
j
⟩
\left\langle {{q}_{i}^{\prime },{k}_{j}^{\prime }}\right\rangle = \left\langle {R\left( i\right) {q}_{i}, R\left( j\right) {k}_{j}}\right\rangle = \left\langle {{q}_{i}, R\left( {j - i}\right) {k}_{j}}\right\rangle
⟨qi′,kj′⟩=⟨R(i)qi,R(j)kj⟩=⟨qi,R(j−i)kj⟩
这意味着,通过RoPE编码,查询和键向量之间的相对位置信息直接被融入自注意力操作中。 RoPE (Rotary Position Embedding) 结合了绝对和相对位置编码的优点,因为它既能够编码绝对位置信息,又能够在自注意力操作中有效融入相对位置信息。
RoPE如何结合绝对和相对位置编码
RoPE通过旋转矩阵为每个位置编码,实现了绝对位置信息的嵌入,同时在自注意力操作中有效地引入了相对位置信息。
RoPE的编码过程
给定一个嵌入维度为 ( \mathrm{d} ) 的输入向量 ( x \in {\mathbb{R}}^{d} ) ,其位置编码可以表示为:
x
i
′
=
RoPE
(
x
i
,
i
)
=
R
(
i
)
⋅
x
i
{x}_{i}^{\prime } = \operatorname{RoPE}\left( {{x}_{i}, i}\right) = R\left( i\right) \cdot {x}_{i}
xi′=RoPE(xi,i)=R(i)⋅xi
其中,旋转矩阵 $ R\left( i\right) $ 用二维旋转矩阵的张量乘积形式定义。
假设 $ {x}{i} $ 可以分解为一系列长度为2的子向量 $ \left( {{x}{i,{2k}},{x}_{i,{2k} + 1}}\right)$ ,其旋转矩阵为:
R
(
i
)
[
x
i
,
2
k
x
i
,
2
k
+
1
]
=
[
cos
(
θ
k
⋅
i
)
−
sin
(
θ
k
⋅
i
)
sin
(
θ
k
⋅
i
)
cos
(
θ
k
⋅
i
)
]
[
x
i
,
2
k
x
i
,
2
k
+
1
]
R\left( i\right) \left\lbrack \begin{matrix} {x}_{i,{2k}} \\ {x}_{i,{2k} + 1} \end{matrix}\right\rbrack = \left\lbrack \begin{matrix} \cos \left( {{\theta }_{k} \cdot i}\right) & - \sin \left( {{\theta }_{k} \cdot i}\right) \\ \sin \left( {{\theta }_{k} \cdot i}\right) & \cos \left( {{\theta }_{k} \cdot i}\right) \end{matrix}\right\rbrack \left\lbrack \begin{matrix} {x}_{i,{2k}} \\ {x}_{i,{2k} + 1} \end{matrix}\right\rbrack
R(i)[xi,2kxi,2k+1]=[cos(θk⋅i)sin(θk⋅i)−sin(θk⋅i)cos(θk⋅i)][xi,2kxi,2k+1]
其中:
θ
k
=
10000
−
2
k
/
d
π
{\theta }_{k} = \frac{{10000}^{-{2k}/d}}{\pi }
θk=π10000−2k/d
这个矩阵表示每个位置
i
i
i 的绝对编码。
RoPE引入相对位置编码
经过RoPE编码的查询和键向量,在自注意力机制中的点积计算如下:
⟨
q
i
′
,
k
j
′
⟩
=
⟨
R
(
i
)
q
i
,
R
(
j
)
k
j
⟩
=
⟨
q
i
,
R
(
j
−
i
)
k
j
⟩
\left\langle {{q}_{i}^{\prime },{k}_{j}^{\prime }}\right\rangle = \left\langle {R\left( i\right) {q}_{i}, R\left( j\right) {k}_{j}}\right\rangle = \left\langle {{q}_{i}, R\left( {j - i}\right) {k}_{j}}\right\rangle
⟨qi′,kj′⟩=⟨R(i)qi,R(j)kj⟩=⟨qi,R(j−i)kj⟩
其中:
R ( j − i ) = [ cos ( θ k ⋅ ( j − i ) ) − sin ( θ k ⋅ ( j − i ) ) sin ( θ k ⋅ ( j − i ) ) cos ( θ k ⋅ ( j − i ) ) ] R\left( {j - i}\right) = \left\lbrack \begin{matrix} \cos \left( {{\theta }_{k} \cdot \left( {j - i}\right) }\right) & - \sin \left( {{\theta }_{k} \cdot \left( {j - i}\right) }\right) \\ \sin \left( {{\theta }_{k} \cdot \left( {j - i}\right) }\right) & \cos \left( {{\theta }_{k} \cdot \left( {j - i}\right) }\right) \end{matrix}\right\rbrack R(j−i)=[cos(θk⋅(j−i))sin(θk⋅(j−i))−sin(θk⋅(j−i))cos(θk⋅(j−i))]
这意味着RoPE在计算查询和键之间的点积时,将相对位置信息直接融入了自注意力操作中。
RoPE的优点
-
**引入相对位置信息:**RoPE能够在自注意力操作中直接编码相对位置信息,使模型具有较好的相对位置感知能力。
-
**保持绝对位置信息:**RoPE的旋转矩阵编码每个位置的绝对信息,保持绝对位置感知能力。
-
**高效处理长序列:**在长序列任务中, RoPE相较于其他位置编码方案表现出更高的效率和性能。
RoPE通过旋转矩阵将绝对位置与相对位置信息相结合,既提供了绝对位置信息的精确性,又具备相对位置感知的灵活性,成为LLaMA等大型语言模型中重要的位置编码方法。
RoPE的优化与实现
优化策略
-
**分块并行处理:**RoPE可以通过分块并行化矩阵运算来提高效率。
-
**预计算旋转矩阵:**对于固定序列长度的任务,可以预计算并缓存旋转矩阵,减少实时计算量。
分组查询注意力(GQA)
分组查询注意力(Grouped Query Attention,简称GQA)作为一种改进的多头自注意力机制,最近在 Llama 3模型中得到了应用。
GQA旨在平衡多头自注意力和多查询注意力的优缺点,以提高推理效率并保持高性能。
多头自注意力机制 (Multi-Head Self-Attention)
多头自注意力机制是Transformer模型的核心组件,旨在通过并行化的方式捕获序列中的复杂关系。其计算公式如下:
- 查询(Query)、键(Key)和值(Value)向量:
Q = X W Q , K = X W K , V = X W V Q = X{W}_{Q},\;K = X{W}_{K},\;V = X{W}_{V} Q=XWQ,K=XWK,V=XWV
其中, W Q {W}_{Q} WQ、 W K {W}_{K} WK 和 W V {W}_{V} WV 分别是查询、键和值的投影矩阵。
- 自注意力计算:
A = softmax ( Q K T d k ) V A = \operatorname{softmax}\left( \frac{Q{K}^{T}}{\sqrt{{d}_{k}}}\right) V A=softmax(dkQKT)V
其中, d k d_k dk 为查询和键向量的维度。
- 多头组合:
Multi-Head ( Q , K , V ) = Concat ( head 1 , … , head h ) W O \operatorname{Multi-Head}\left( {Q, K, V}\right) = \operatorname{Concat}\left( {{\operatorname{head}}_{1},\ldots ,{\operatorname{head}}_{h}}\right) {W}_{O} Multi-Head(Q,K,V)=Concat(head1,…,headh)WO
其中, $ {\text{head}}{i} = Attention \left( {Q{W}{{Q}{i}}, K{W}{{K}{i}}, V{W}{{V}_{i}}}\right) $ , W O W_O WO 是输出矩阵。
多查询注意力机制 (Multi-Query Attention)
多查询注意力机制简化了多头自注意力,将所有注意力头共享相同的键和值投影,从而减少计算和内存开销。其结构如下:
- 查询向量:
Q = X W Q Q = X{W}_{Q} Q=XWQ
- 共享的键和值向量:
K = X W K , V = X W V K = X{W}_{K},\;V = X{W}_{V} K=XWK,V=XWV
- 自注意力计算:
A i = softmax ( Q W Q i K T d k ) V {A}_{i} = \operatorname{softmax}\left( \frac{Q{W}_{{Q}_{i}}{K}^{T}}{\sqrt{{d}_{k}}}\right) V Ai=softmax(dkQWQiKT)V
- 组合:
Multi-Query ( Q , K , V ) = Concat ( A 1 , … , A h ) W O \operatorname{Multi-Query}\left( {Q, K, V}\right) = \operatorname{Concat}\left( {{A}_{1},\ldots ,{A}_{h}}\right) {W}_{O} Multi-Query(Q,K,V)=Concat(A1,…,Ah)WO
GQA (Grouped Query Attention) 的原理
GQA结合了多头自注意力和多查询注意力的优点,通过分组的方式共享键和值投影。
其结构如下:
- 查询向量:
Q = X W Q Q = X{W}_{Q} Q=XWQ
- 分组的键和值向量:
将总计 ( N ) 个注意力头划分为 ( G ) 组,每组共享相同的键和值投影:
K g = X W K g , V g = X W V g , g = 1 , … , G {K}_{g} = X{W}_{{K}_{g}},\;{V}_{g} = X{W}_{{V}_{g}},\;g = 1,\ldots , G Kg=XWKg,Vg=XWVg,g=1,…,G
- 组内自注意力计算:
A g , i = softmax ( Q W Q g , i K g T d k ) V g {A}_{g, i} = \operatorname{softmax}\left( \frac{Q{W}_{{Q}_{g, i}}{K}_{g}^{T}}{\sqrt{{d}_{k}}}\right) {V}_{g} Ag,i=softmax(dkQWQg,iKgT)Vg
- 组合:
Grouped-Query ( Q , K , V ) = Concat ( A 1 , 1 , … , A G , h / G ) W O \text{Grouped-Query}\left( {Q, K, V}\right) = \operatorname{Concat}\left( {{A}_{1,1},\ldots ,{A}_{G, h/G}}\right) {W}_{O} Grouped-Query(Q,K,V)=Concat(A1,1,…,AG,h/G)WO
其中:
-
$ {W}{{Q}{g, i}} $为每个查询头的投影矩阵
-
${W}{{K}{g}} 和 {W}{{V}{g}} $ 为每组共享的键和值投影矩阵
GQA的优势
-
**计算效率:**减少了键和值的计算和内存开销。
-
**性能保持:**在性能上与多头自注意力相当,同时在效率上接近多查询注意力。
KVCache
在大型语言模型(LLM)的推理和训练过程中,缓存机制对于提高模型的效率和性能至关重要。
KVCache(Key-Value Cache)是一种用于缓存自注意力机制中键和值的优化策略,可以显著加速推理过程。
自注意力机制中的键和值
在Transformer的自注意力机制中,键和值的计算是核心步骤:
- 输入向量投影:
-
查询向量 Q = X W Q Q = X{W}_{Q} Q=XWQ
-
键向量 $K = X{W}_{K} $
-
值向量 V = X W V V = X{W}_{V} V=XWV
- 自注意力计算:
A = softmax ( Q K T d k ) V A = \operatorname{softmax}\left( \frac{Q{K}^{T}}{\sqrt{{d}_{k}}}\right) V A=softmax(dkQKT)V
其中, d k d_k dk 为键向量的维度。
多头自注意力机制
多头自注意力机制将上述计算扩展为并行计算多个注意力头:
Multi-Head
(
Q
,
K
,
V
)
=
Concat
(
head
1
,
…
,
head
h
)
W
O
\operatorname{Multi-Head}\left( {Q, K, V}\right) = \operatorname{Concat}\left( {{\operatorname{head}}_{1},\ldots ,{\operatorname{head}}_{h}}\right) {W}_{O}
Multi-Head(Q,K,V)=Concat(head1,…,headh)WO
每个注意力头使用不同的投影矩阵,捕获序列中的不同关系。
推理过程中的重复计算
在推理过程中,模型可能会接收一系列的连续输入。
对于每个新的输入,模型会重新计算键和值,导致大量的重复计算。
这种情况在长文本生成或对话任务中尤其明显。
KVCache的原理
KVCache旨在减少自注意力机制中的重复计算。
它通过缓存先前计算的键和值,在推理过程中复用这些信息,从而显著提高推理效率。
KVCache的工作流程
-
**初始化缓存:**初始化一个空的键和值缓存。
-
计算当前输入的键和值:
K t = X t W K , V t = X t W V {K}_{t} = {X}_{t}{W}_{K},\;{V}_{t} = {X}_{t}{W}_{V} Kt=XtWK,Vt=XtWV
-
**更新缓存:**将当前输入的键和值追加到缓存中。
-
**自注意力计算:**在新的键和值上进行自注意力计算:
A = softmax ( Q K cache T d k ) V cache A = \operatorname{softmax}\left( \frac{Q{K}_{\text{cache }}^{T}}{\sqrt{{d}_{k}}}\right) {V}_{\text{cache }} A=softmax(dkQKcache T)Vcache
- 清理缓存:
在某些情况下(如达到上下文长度上限),需要清理缓存或进行部分替换。
KVCache在LIama 3中的应用
Llama 3作为Meta发布的最新大型语言模型,采用了KVCache机制以提高推理效率。
Llama 3通过引入KVCache机制,将推理过程中的键和值缓存下来,在后续的注意力计算中复用这些缓存,以减少重复计算。
-
**键和值缓存:**在每一层的自注意力机制中,键和值的计算结果都会被缓存。
-
**缓存更新策略:**每次新的输入到来时,缓存会根据新的键和值进行更新,确保自注意力计算的准确性。
-
**缓存清理:**当达到上下文长度上限时,缓存会被部分或全部清理。
KVCache作为一种优化策略,通过缓存自注意力机制中的键和值,有效减少了推理过程中的重复计算。
LLM文本生成的两个阶段:预填充(prefill)和解码(decode)
在大型语言模型(LLM)生成文本的过程中,通常会涉及两个阶段:预填充(prefill)和解码(decode) 。
预填充(Prefill)阶段
预填充阶段主要用于准备初始上下文,为模型生成后续文本提供基础。
在这个阶段,模型会处理输入的初始文本(通常是用户提供的提示或上下文),并生成相应的内部状态 (如KV缓存)。
具体过程
- 初始上下文输入:
- 用户提供一个初始文本作为提示,例如问题、句子或段落。
- 该文本通常被编码为一系列token(tokens)。
- 初始上下文的注意力计算:
-
模型对输入的所有token进行处理,生成键和值,并存储在KV缓存中。 3. 初始概率分布:
-
计算最后一个token的概率分布,作为生成下一个token的起始点。
解码 (Decode) 阶段
解码阶段是模型根据预填充阶段准备好的上下文,生成后续文本的过程。
在这个阶段,模型会逐步生成每个新token,并在每次生成后更新KV缓存。
具体过程
- token生成:
-
从预填充阶段的初始概率分布中采样一个新token,作为下一个输入。
-
根据用户选择的策略(如贪心、采样、温度)进行采样。
- KV缓存更新:
-
使用新token进行前向传递,更新KV缓存,生成新的键和值。
-
生成新的概率分布,用于采样下一个token。
- 循环迭代:
- 重复步骤1和步骤2,直到达到生成的最大长度或满足停止条件。
下面详细解释这两个阶段KVCache的大小和更新方法,以及它们的不同之处。
预填充阶段的KVCache
大小和布局
在预填充阶段, KVCache的大小根据输入的上下文长度和模型层数确定。
- 键和值的形状:
values: (num_layers, batch_size, num_heads, seq_len, head_dim)
其中:
-
num_layers
:模型的总层数 -
batch_size
:输入的批次大小 -
num_heads
: 多头注意力机制的头数 -
seq_len
: 输入序列的长度 -
head_dim
:每个头的嵌入维度
- KVCache大小示例:
假设一个模型有 12 层,每层有 16 个头,每个头的维度为 64 ,批次大小为 1 ,上下文长度为512 ,则KVCache的大小如下:
-
键缓存大小:(12,1,16,512,64)
-
值缓存大小:(12,1,16,512,64)
更新方法
预填充阶段, KVCache会随着输入的每个token进行填充和更新。
初始填充: 对于给定的输入上下文, 模型计算每一层的键和值向量, 并将它们存储到缓存中
解码阶段的KVCache
大小和布局
在解码阶段, KVCache的大小主要取决于上下文的长度和逐步生成的长度。
与预填充阶段不同的是,解码阶段的KVCache随着生成的序列不断增长。
- 键和值的形状:
values: (num_layers, batch_size, num_heads, max_seq_len, head_dim)
其中:
max_seq_len
: 预填充序列长度 + 生成序列长度
- KVCache大小示例:
假设初始输入长度为 512 ,逐步生成长度为 100 ,则KVCache的最大长度为 612。
-
键缓存大小:(12,1,16,612,64)
-
值缓存大小:(12,1,16,612,64)
总结
-
KVCache的大小:
- 预填充阶段:缓存的大小取决于初始输入的长度。
- 解码阶段:缓存的大小随着生成序列长度的增长而不断增加。
-
KVCache的更新方式:
- 预填充阶段:一次性将整个初始上下文填充到缓存中。
- 解码阶段:逐步生成并更新缓存,每次生成新token时追加新的键和值。
逐步生成长度(Incremental Generation Length)指的是语言模型在解码阶段, 根据输入的初始文本(如提示词prompt), 一个token一个token地逐步生成输出序列的过程。
在解码开始时, 语言模型只有输入的初始文本,如"用英文写一首关于春天的诗"。然后模型根据这个输入, 预测并生成下一个最可能的token,如"Spring"。之后"Spring"会被添加到之前的输入中, 形成新的输入序列,模型再根据这个更新的输入去预测下一个token,如"is"。如此重复,直到满足某个停止条件(如遇到截止符<\s>, 或达到最大生成长度), 整个解码过程就结束了。
这个过程中, 生成的token序列(如"Spring is coming")的长度, 就被称为逐步生成长度。它一般远小于模型的最大序列长度限制。比如最大序列长度为2048的模型, 在实际解码时, 往往只需要生成几十到几百个 token就可以得到满意的结果。
在解码阶段引入键值缓存(KV Cache), 主要就是为了提高处理这些逐步生成token的效率。随着解码的进行, 每生成一个新的token, 就动态地将其添加到KV Cache中, 更新键值对。这样模型就可以得到到不断增长的已生成序列,从而实现循环机制。
所以总结来说, 逐步生成长度就是语言模型解码时动态生成token序列的长度, 反映了解码的进度。
合理设置逐步生成长度, 优化解码算法和KV Cache机制, 对于提升语言模型推理速度和生成质量至关重要。
总结
使用KV Cache,在 token-by-token 推理中:
-
初始token 的计算复杂度为 O ( d 2 ) O\left( {d}^{2}\right) O(d2)。
-
后续每个token 的计算复杂度为 O ( N d ) O\left( {Nd}\right) O(Nd) ,而不是重新计算整个序列,从而避免了二次计算复杂度。
其中 N N N 是序列长度, d d d 是每个序列元素的维度。
因此,通过使用KV Cache,可以将token-by-token推理的计算复杂度从 O ( N 2 d ) O\left({{N}^{2}d}\right) O(N2d) 降低到接近线性关系,
这种优化使得大模型能够处理更长的序列,并在推理时显著提高了效率。
需要注意的是,这里讨论的是推理阶段的计算复杂度。
在训练阶段,自注意力的计算复杂度仍然是 O ( N 2 ) O\left({{N}^{2}}\right) O(N2),但是可以对整个序列进行并行计算。
在实际应用中, 推理阶段的效率往往更为重要, 因为它直接影响了模型的响应速度和吞吐量。