当前位置: 首页 > article >正文

LLaMA详解

LLaMA 进化史

LLaMA

大规模语言模型(Large Language Model, LLM)的快速发展正在以前所未有的速度推动人工智能(AI)技术的进步。

作为这一领域的先行者, Meta在其LLaMA(Large Language Model Meta AI)系列模型上取得了一系列重大突破。

近日, Meta官方正式宣布推出LLaMA-3, 作为继LLaMA-1、LLaMA-2和Code-LLaMA之后的第三代旗舰模型。

LLaMA-3在多个权威基准测试中均取得了全面领先的成绩, 超越了业界同类最先进模型, 再次刷新了LLM的性能上限。

LLaMA-1 系列

LLaMA-1是Meta公司在2023年2月发布的首款大规模语言模型,它提供了7B、13B、30B和65B四种不同的参数规模版本。

这些版本均在超过1T(万亿)个tokens的广泛语料库上进行了预训练,有效利用了海量数据的优势。

在这些版本中,参数量最大的65B版本在2048张A100 80G GPU上训练了21天,表现出色,在多数基准测试中超越了具有175B参数的GPT-3。

除了卓越的性能,LLaMA-1还采取了开源策略,极大地方便了研究者和开发者的获取与使用,迅速成为开源社区中极受欢迎的大规模语言模型之一。

围绕LLaMA-1,形成了一个活跃的生态系统,涌现了许多基于该模型的下游应用、微调模型和衍生变体。

LLaMA-2 系列

LLaMA-2 系列,Meta在2023年7月推出的第二代大规模语言模型,是对LLaMA-1系列功能的扩展和改进。

此系列包括7B、13B和70B三个不同规模的版本,并特别推出了经过指令微调的对话模型LLaMA-2-Chat,旨在增强模型在对话和任务执行方面的能力。

主要特点包括:

  • 增强的模型性能: 相较于LLaMA-1,LLaMA-2系列在更广泛的语料上进行了预训练,这大幅增强了其文本生成和理解能力,确保模型在各类NLP任务中的领先性能。
  • 指令微调技术: LLaMA-2-Chat通过对大量的高质量指令-回答对进行微调,优化了模型对复杂指令的响应能力,使其在对话和具体任务执行中表现更为精确和自然。
  • 开源政策: 继承LLaMA-1的开放精神,LLaMA-2以开源形式发布,不仅丰富了LLaMA的生态系统,也激励了基于这一平台的创新性研究和应用开发。

通过这些改进,LLaMA-2系列不仅在技术层面上提供了显著的升级,也在开放源代码和社区参与方面持续发挥着积极的推动作用。此举确保了LLaMA系列在全球语言模型领域的持续领先和影响力扩展。

Llama2

LLaMA-3 系列

2024年4月, Meta 发布了最新一代的大规模语言模型系列 LLaMA-3。

作为 LLaMA 系列的第三代产品, LLaMA-3 在 LLaMA-1、LLaMA-2 和 Code-LLaMA 的基础上实现了全面升级和优化。

在多个权威基准测试中, LLaMA-3 的表现全面领先于业界同类最先进模型,再次刷新了大模型的性能上限。

LLaMA-3 在长序列建模能力上取得了重大突破。

相比 LLaMA-2 只能处理最长 2048 个 token 的文本, LLaMA-3 将这一上限提高到了 8192 个 token, 使其可以轻松应对超长文档、多轮对话等复杂场景。

同时, LLaMA-3 还采用了全新的 tokenizer, 将分词器更换为tiktoken,与GPT4保持一致。

其词表大小达到了 128K, 远超 LLaMA-2 的 32K。更大的词表意味着更精细的文本表示和更强的语言泛化能力。

注:

  • Llama-1模型采用了BPE算法进行分词,使用sentencepiece实现。其词表大小为32k。
  • Llama-2维持了与Llama-1相同的架构和分词器,但将上下文长度扩展到了4k。

数据规模是大模型取得突破性进展的另一个关键因素。

LLaMA-3 的预训练语料规模超过了 15 T (150 万亿) token, 是 LLaMA-2 的 7 倍多。

海量高质量的预训练数据, 为 LLaMA-3 提供了更全面、更深入的世界知识和语言理解能力。

同时, Meta 还对预训练数据进行了更细粒度的筛选和清洗, 进一步提高了数据质量。

凭借优秀的模型架构和海量的预训练数据, LLaMA-3 在下游任务上实现了全面领先。

在自然语言推理、机器阅读理解、常识问答等多个基准测试中, LLaMA-3 的表现都超越了业界最先进的同规模模型。

此外, LLaMA-3 在推理能力、代码生成、指令跟随、多语言支持等方面也有长足进步, 使其成为更加全能和可控的通用人工智能助手。

随着 LLaMA-3 的发布, Meta 再次引领了大模型技术的发展潮流。

作为当前最先进的开源大模型, LLaMA-3 必将掀起新一轮的研究和应用热潮, 为人工智能的进步注入强大动力。

Llama3

Llam3 400B+

LLaMA-3 的基础版模型和指令微调模型

Llama3 Model

Meta发布了LLaMA 3模型的几个变体:

  1. meta-llama/Meta-LLama-3-8B: 这是基础版的LLaMA 3模型, 有80亿个参数。

  2. meta-llama/Meta-LLama-3-8B-Instruct: 这是一个经过指令微调的8B参数LLaMA 3模型变体。

  3. meta-llama/Meta-LLama-3-70B-Instruct: 这是一个更大的经过指令微调的LLaMA 3模型, 有700亿个参数。

  4. meta-llama/Meta-LLama-3-70B: 这是没有经过指令微调的基础版70B参数LLaMA 3模型。

总的来说, Meta发布了8B和70B两种参数规模的LLaMA 3模型, 每种规模都有提供经过指令微调的变体。

指令微调(Instruction Tuning)的变体和基础版模型之间有以下几个主要区别:

  1. 训练目标不同: 基础版模型通常使用语言建模(Language Modeling)作为训练目标, 即根据上文预测下一个单词。而指令微调的变体则引入了额外的指令数据, 通过监督学习让模型学会理解和执行自然语言指令。
  2. 训练数据不同: 基础版模型主要在大规模无标签文本语料上进行预训练。而指令微调的变体需要构建专门的指令数据集, 其中包含大量的自然语言指令及其对应的执行结果。
  3. 应用场景不同: 基础版模型可以用于各种NLP任务, 但在执行具体指令时可能表现欠佳。指令微调的变体则专门针对指令理解和执行进行了优化, 在问答、对话、任务完成等场景中表现更加出色。
  4. 交互方式不同: 使用基础版模型时, 用户需要根据具体任务设计prompts。而指令微调的变体允许用户使用自然语言直接下达指令, 交互更加直观和方便。
  5. 可控性不同: 基础版模型生成的内容可能不够可控, 容易出现幻觉或不合适的言论。指令微调引入了人类反馈, 可以更好地引导模型生成安全、可靠的内容。
  6. 推理效率不同: 指令微调通常会引入一些控制机制, 如提示工程、示例学习等, 可能降低推理速度。但一些高效的微调方法如LoRA、Prefix Tuning等可以在保证性能的同时加快推理。

总的来说, 指令微调赋予了语言模型更强的指令理解和执行能力, 使其在实际应用中更加智能、高效和可控。

LLaMA衍生模型和生态

总览

LLaMA系列模型的开源性和可访问性, 使其成为了LLM领域的重要研究和应用平台。

围绕LLaMA, 涌现出了一个繁荣的开源生态, 催生了如 Alpaca、Vicuna、LLaVA 等一系列优秀的衍生模型。

这些模型在LLaMA的基础上,结合领域数据和下游任务进行针对性优化, 在对话、多模态、开放域问答等方面取得了瞩目成绩。

LLaMA的开源, 极大地降低了LLM研究的门槛, 推动了AI技术的普惠进程。

LLaMA Eco

图中展示的是LLaMA模型及其衍生模型的概览,包括了不同的训练方法、数据类型、以及一些特定的应用场景。

以下是对图中内容的描述和解释:

  1. **继续预训练(Continue pre-training):**指的是在现有模型基础上,使用更多数据继续训练以提升模型性能的过程。

  2. **LLaMA:**是Meta公司开发的一个大型语言模型,图中显示了基于LLaMA模型的多种扩展和优化路径。

  3. **参数高效微调(Parameter-efficient fine-tuning):**这是一种微调技术,通过调整模型中较少的参数来适应新的任务,以提高效率。

  4. **模型继承(Model inheritance):**指新模型继承或基于旧模型的架构和参数进行开发。

  5. **指令微调(Instruction tuning):**通过指令来指导模型微调,使其更好地遵循给定的任务指令。

  6. **全参数微调(Full parameter fine-tuning):**指在微调过程中调整模型的所有参数。

  7. **数据继承(Data inheritance):**新模型使用旧模型训练过的数据集作为训练数据的一部分。

  8. **中文数据(Chinese data):**指使用了中文语言数据进行训练的模型。

  9. **Open-Chinese-LLaMA:**可能是一个针对中文优化的LLaMA模型。

  10. **合成数据(synthetic data):**指通过技术手段生成的非真实世界数据,用于训练。

  11. **Vicuna, Panda, Alpaca, Goat:**基于LLaMA模型的不同变种或特定用途的模型名称。

  12. **RLHF:**指的是通过人类反馈进行强化学习的微调方法(Reinforcement Learning from Human Feedback)。

  13. **Yulan-Chat:**专门为中文对话优化的模型。

  14. **PKU-Beaver, BiLLa:**指北京大学开发的模型或者是双语语言模型。

  15. **Cornucopia:**指包含多种数据类型的综合数据集。

  16. **Lawyer, LLaVA, [BELLE]:**针对特定领域(如法律)或具有特定功能(如视觉语言模型)的模型。

  17. **MiniGPT-4, Ziya, QiZhenGPT, Baize:**不同大小或针对特定任务优化的模型。

  18. **Guanaco, Chatbridge, Koala:**特定的多模态模型或其他应用领域的模型。

  19. **VisionLLM, TaoLi, InstructBLIP:**结合了视觉和语言任务的模型。

  20. **ChatMed, Adapter, PandaGPT, LAWGPT, BenTsao:**针对医疗、适配器技术、法律等特定领域的模型。

  21. **多模态模型(Multimodal models):**能够处理并整合来自多种感官模式(如视觉、听觉、文本)的模型。

  22. **数学(Math)、金融(Finance)、医学(Medicine)、法律(Law)、双语(Bilingualism)、教育(Education):**模型应用的领域。

这张图展示了以 LLaMA 系列模型为核心的大语言模型生态系统。

  1. 继续预训练(Continue pre-training): 一些模型如 Chinese LLaMA、Chinese Alpaca 等在 LLaMA 的基础上加入了中文数据继续预训练,以提高中文任务的表现。
  2. 指令微调(Instruction tuning): 通过在指令数据集上微调 LLaMA, 衍生出了 Alpaca、Vicuna 等模型, 使其能够执行指令跟随和问答对话等任务。
  3. 参数高效微调(Parameter-efficient fine-tuning): 使用 LoRA、Prefix Tuning 等参数高效微调方法在下游任务数据上微调 LLaMA,得到 Alpaca Lora 等模型。
  4. 全参数微调(Full parameter fine-tuning): 在特定垂直领域数据上对 LLaMA 进行全参数微调, 如在医疗对话数据上微调得到的 BianTsao 模型。
  5. 多模态模型(Multimodal models): 将 LLaMA 扩展到多模态, 如支持图像输入的 LLaVA、MinGPT-4, 语音交互的 InstructBLIP 等。

总的来说, 该图全面地展示了以 LLaMA 为基础衍生出的丰富多样的大模型生态,涵盖了主要的优化训练范式、任务类型和具体模型。

值得注意的是, 中文模型在该生态中占据了重要地位。

Alpaca

Alpaca 是由斯坦福大学计算机科学系博士生 Eric Wang 等人开发的一个基于 LLaMA-7B 模型的衍生大模型。

其主要特点包括:

  1. 指令精调: Alpaca 在 LLaMA-7B 的基础上, 使用了一个包含 5.2 万条指令数据的数据集进行了监督微调(Supervised Fine-tuning)。这使得 Alpaca 能够很好地理解和执行自然语言指令, 具备类似ChatGPT 的对话交互能力。
  2. 开源共享: Alpaca 项目的代码和训练数据都已经在 GitHub 上完全开源, 允许研究者和开发者基于此进行二次开发。这极大地降低了构建指令跟随型对话系统的门槛。
  3. 性能优异: 在标准的指令跟随任务基准如 MMLU 上, Alpaca 的表现已经接近 ChatGPT 等封闭模型, 而参数量和计算开销却小很多。这说明了在 LLaMA 基础上进行指令精调的有效性。
  4. 多语言支持: 得益于 LLaMA 模型本身强大的多语言能力, Alpaca 也具备了一定的多语言处理能力, 尽管主要还是针对英文进行了优化。
  5. 可控性强: 由于 Alpaca 的训练数据是人工标注的高质量指令数据, 因此其生成的内容更加可控, 在事实性、安全性方面表现出色。
  6. 开源生态: Alpaca 的开源进一步推动了 LLaMA 周边生态的繁荣, 催生了一系列基于 Alpaca 的衍生模型和应用。

总的来说, Alpaca 是 LLaMA 家族中一个代表性的指令精调模型, 其开源性、可访问性和优异的性能, 使其成为了开源界 ChatGPT 的有力竞争者。

Alpaca 的成功证明了在一个强大的基础模型上, 利用高质量的指令数据进行针对性微调, 可以显著提升模型在对话交互任务上的表现, 同时还能保持较强的可控性。

Vicuna

Vicuna 是由加州大学伯克利分校、卡内基梅隆大学、斯坦福大学等机构的研究者联合开发的一个基于LLaMA 的开源对话语言模型。

其主要特点包括:

  1. 大规模指令精调: Vicuna 使用了一个包含 7 万多条对话数据的指令数据集对 LLaMA-13B 进行了微调。这些数据主要来自于 ShareGPT 收集的真实人类对话,质量相当高。相比 Alpaca 的 5 万条指令数据,Vicuna 的训练语料更加丰富和多样化。
  2. 多轮对话能力: 得益于大规模高质量对话数据的训练, Vicuna 具备了出色的多轮对话能力。它能够很好地理解对话的上下文, 根据之前的对话内容生成连贯且相关的回复。在这一点上, Vicuna 比Alpaca 表现得更为出色。
  3. 训练计算效率: Vicuna 在训练过程中采用了一系列优化手段, 如混合精度训练、梯度累积、DeepSpeed 等, 使得在有限的计算资源下也能高效地完成大模型的训练。这为开源社区提供了一个很好的模型训练范例。
  4. 开源共享: 与 Alpaca 类似, Vicuna 的代码和训练数据也已经完全开源, 供社区使用和研究。同时, 研究团队还发布了经过训练的检查点(checkpoint), 可以直接进行推理和微调。
  5. 人类对齐度高: 通过在大规模人类对话数据上的训练, Vicuna 学会了更加自然、人性化的交互方式。其生成的回复在流畅度、连贯性、同理心等方面都有更好的表现,给人一种更加亲切、自然的感受。
  6. 商业化应用: 进一步扩大了 Vicuna 的影响力和应用范围。

总的来说, Vicuna 代表了开源对话大模型的最新进展。其充分利用了大规模高质量的对话数据,在 LLaMA 的基础上实现了全面的性能提升,尤其是在多轮对话和人类对齐度方面表现突出。

Vicuna 的开源和商业应用,为构建更加智能、自然的对话系统提供了重要参考。

LLaVA

LLaVA,全称为 Large Language and Vision Assistant(大型语言和视觉助手),是一种新型的大型多模态模型。

它的目标是开发一种通用视觉助手,能够遵循语言和图像指令来完成各种现实世界任务。

LLaVA 结合了自然语言处理(NLP)和计算机视觉(CV)的能力,通过理解视觉内容并根据语言指令进行操作,从而实现对图像和文本的深入理解与交互。

LLaVA 模型的核心在于其多模态架构,它将视觉编码器(如基于 Transformer 的视觉模型)与语言模型(如 LLaMA)结合起来,形成一个能够处理图文信息的集成系统。这种设计使得 LLaVA 能够执行包括图像标注、视觉问答、文本到图像生成等在内的多模态任务。

其主要特点包括:

  1. 多模态融合: LLaVA 在 LLaMA 的基础上引入了视觉特征,实现了语言和视觉信息的融合。具体来说, 它使用 BLIP-2 模型提取图像特征,然后将其与文本表示进行交互, 生成与图像相关的自然语言描述或回答。
  2. 图像理解能力: 得益于多模态融合, LLaVA 具备了强大的图像理解和分析能力。它可以根据输入的图像生成详细的描述,回答与图像相关的问题,甚至进行开放式的图像内容分析。
  3. 通用语言能力: 作为 LLaMA 的衍生物,LLaVA 继承了原模型优秀的自然语言处理能力。因此它不仅可以处理图像相关的任务,在通用的语言理解、生成等方面也有不俗的表现。
  4. 零样本学习: LLaVA 支持零样本学习(zero-shot learning), 即无需在下游任务上进行微调, 直接利用预训练的知识完成推理。这使得 LLaVA 可以灵活地应对各种形式的多模态任务, 无需重新训练模型。
  5. 大规模预训练: LLaVA 在大规模图文对数据上进行了预训练, 学习了丰富的视觉-语言对齐知识。这为其在下游任务上的优异表现奠定了基础。同时, 预训练也提高了模型的泛化能力和鲁棒性。
  6. 开源开放: 与其他 LLaMA 衍生物类似,LLaVA 的代码和预训练模型权重都已经开源。这为进一步研究和应用多模态大模型提供了宝贵的资源和参考。

总的来说, LLaVA 代表了 LLaMA 在多模态领域的重要拓展。

小结

通过引入视觉特征与语言表示的交互, LLaVA 实现了对图像内容的深度理解和分析。

同时, 它在通用语言任务上的出色表现, 证明了多模态学习对于提升语言模型性能的积极作用。

可以预见, LLaVA 将为多模态大模型的研究和应用开辟新的方向, 推动人工智能向更加全面、贴近人类智能的方向发展。

Llama3原理解读

Llama 3 架构图

0193f875-8bf0-7806-ab37-494c1af37d61_0_289_502_1877_1252_0.jpg

Llama 3 模型架构

Llama 3模型基于标准的Transformer架构进行了多项改进,包括更高的效率和更好的性能。

接下来,我们详细探讨Llama 3架构的主要特点。

模型概述

  • 模型类型:基于解码器的Transformer

  • 参数量:8B和70B

  • 上下文长度:8192 (LLaMA-1和LLaMA-2的上下文长度分别为2048,4096)

  • 注意力机制:分组查询注意力 (Grouped Query Attention, GQA)

结构组成

Llama 3的模型架构主要包括以下组件:

分层堆叠

  • 嵌入层:将输入的token转换为固定维度的嵌入表示。

  • 自注意力层:包含多头自注意力机制和归一化。

  • 前馈网络 (FFN) : 包含激活函数和两层全连接网络。

  • 位置编码:采用RoPE(Rotary Position Embedding)位置编码。

Llama 3与传统Transformer架构的比较

相同之处

  • 基本结构:两者都基于标准的Transformer解码器架构。

  • 多头自注意力:都使用多头自注意力机制。

  • 前馈网络 (FFN): 包含两层全连接网络。

不同之处

  1. 注意力机制:
  • 传统Transformer: 使用标准多头自注意力机制。
  • Llama 3:引入了分组查询注意力(GQA)以提高效率。
  1. 激活函数:
  • 传统Transformer: 使用ReLU或GELU激活函数。
  • LIama 3:使用更高效的SwiGLU激活函数。
  1. 位置编码:
  • 传统Transformer: 使用正弦-余弦位置编码或Learned Positional Embedding。
  • LIama 3:使用RoPE(Rotary Position Embedding)位置编码。
  1. 上下文长度:
  • 传统Transformer: 通常为512或1024的上下文长度。
  • Llama 3: 上下文长度增加到8192。

LLama3 模型组件详解

RMSNorm

归一化技术是一种在神经网络中用于归一化激活值的技术。

它可以提高训练的稳定性,加快收敛速度并改善模型的泛化性能。

下面是Transformer架构中常见的两种归一化技术:

层归一化 (Layer Normalization)

0193f875-8bf0-7806-ab37-494c1af37d61_1_584_2435_1366_725_0.jpg

层归一化是Transformer架构中最早引入的归一化技术之一。

它将一个输入序列的所有特征归一化,使得每个样本的激活值在特征维度上具有均值为0、方差为1的分布。其公式如下:

x ^ i = x i − μ σ + ϵ {\widehat{x}}_{i} = \frac{{x}_{i} - \mu }{\sigma + \epsilon} x i=σ+ϵxiμ
其中, μ \mu μ 是均值, σ \sigma σ是标准差, ϵ \epsilon ϵ是一个很小的数以避免除以零。

RMSNorm

RMSNorm是一种简化版的层归一化,它通过均方根(Root Mean Square)计算得到归一化尺度,不需要计算均值和标准差。

其公式如下:

x ^ i = x i 1 n ∑ j = 1 n x j 2 + ϵ {\widehat{x}}_{i} = \frac{{x}_{i}}{\sqrt{\frac{1}{n}\mathop{\sum }\limits_{{j = 1}}^{n}{x}_{j}^{2} + \epsilon }} x i=n1j=1nxj2+ϵ xi

与层归一化相比, RMSNorm减少了计算复杂度,同时保持了良好的归一化效果。

RMSNorm在LLaMA中的应用

在LLaMA模型中, RMSNorm主要用于对每个时间步的隐藏状态向量进行归一化

每个时间步的输入向量被独立归一化,适用于处理变长序列和保持时间步间的独立性。

具体来说, RMSNorm在Transformer架构中的应用通常在以下几个位置:

  1. 多头自注意力机制的输出:
  • 对于每个时间步的隐藏状态,在经过多头自注意力机制后的输出进行归一化。
  1. 前馈神经网络的输出:
  • 对于每个时间步的隐藏状态,在经过前馈神经网络 (Feed-Forward Neural Network, FFN) 后的输出进行归一化。
为什么选择RMSNorm?

在LLaMA中,选择RMSNorm作为主要的归一化策略,原因如下:

  1. **提高训练效率:**RMSNorm在计算上比层归一化更高效。在大型语言模型中,效率的提升尤其显著。

  2. **提高训练稳定性和泛化性能:**研究发现, RMSNorm可以提供更稳定的训练过程,减少梯度爆炸或消失的风险,同时保持较高的泛化能力。

  3. **预归一化的优势:**LLaMA采用了RMSNorm的预归一化变体,将归一化操作放在Transformer模块的主要层之前,而不是之后。

这一策略进一步提升了模型的训练稳定性。

0193f875-8bf0-7806-ab37-494c1af37d61_3_599_94_1376_2047_0.jpg

RMSNorm vs LayerNorm的对比

LLaMA中的实验结果显示, RMSNorm在效率和性能方面都有明显的提升:

  • 效率提升:与层归一化相比, RMSNorm在效率上提升了10-50%。

  • 性能对比:在实际应用中, RMSNorm在稳定性和泛化性能上与层归一化不相上下,有时甚至更优。

SwiGLU激活函数

激活函数在大型语言模型(LLM)的发展过程中扮演着关键角色,它影响着模型性能的稳定性与泛化能力。

LLaMA3在其前馈层中采用了一种特殊的激活函数——SwiGLU。

激活函数是深度神经网络中的关键组件,负责引入非线性,使模型能够学习复杂的数据模式。

常见的激活函数包括:

ReLU (Rectified Linear Unit)

0193f875-8bf0-7806-ab37-494c1af37d61_4_319_232_1822_725_0.jpg

ReLU是目前使用最广泛的激活函数之一。其计算公式如下:

ReLU ⁡ ( x ) = max ⁡ ( 0 , x ) \operatorname{ReLU}\left( x\right) = \max \left( {0, x}\right) ReLU(x)=max(0,x)
特点:

  • 计算简单,梯度为0或1。

  • 可能导致“神经元死亡”问题。

GELU (Gaussian Error Linear Unit)

GELU是一种基于高斯分布的激活函数,在GPT-3等模型中使用。

其公式如下:

GELU ⁡ ( x ) = x ⋅ 0.5 ⋅ ( 1 + erf ⁡ ( x 2 ) ) \operatorname{GELU}\left( x\right) = x \cdot {0.5} \cdot \left( {1 + \operatorname{erf}\left( \frac{x}{\sqrt{2}}\right) }\right) GELU(x)=x0.5(1+erf(2 x))
其中, erf为误差函数。

特点:

  • 与ReLU相比, GELU具有平滑的激活曲线,能提供更高的性能。

0193f875-8bf0-7806-ab37-494c1af37d61_4_340_2233_1204_893_0.jpg

Swish

Swish是一种平滑且连续的激活函数,在Transformer等模型中应用广泛。其公式如下:

Swish ⁡ ( x ) = x ⋅ σ ( x ) \operatorname{Swish}\left( x\right) = x \cdot \sigma \left( x\right) Swish(x)=xσ(x)

其中, σ ( x ) = 1 1 + e − x \sigma \left( x\right) = \frac{1}{1 + {e}^{-x}} σ(x)=1+ex1 为标准Sigmoid函数。

特点:

  • 平滑的曲线有助于稳定梯度流。
  • 通常比ReLU表现更好。

0193f875-8bf0-7806-ab37-494c1af37d61_5_668_805_1129_821_0.jpg

GLU (Gated Linear Unit)

GLU(Gated Linear Unit)是一种用于神经网络中的激活函数,最初由 Yann Dauphin 等人在论文《Language Modeling with Gated Convolutional Networks》里提出。

GLU 的设计灵感来自门控机制,它通过引入门控操作来控制信息的流动。

该操作可以看作是引入了一种动态的选择机制,以在模型中选择性地传递信息。

GLU 的计算公式如下:

GLU ⁡ ( X ) = ( X ⋅ W 1 + b 1 ) ⊙ σ ( X ⋅ W 2 + b 2 ) \operatorname{GLU}\left( X\right) = \left( {X \cdot {W}_{1} + {b}_{1}}\right) \odot \sigma \left( {X \cdot {W}_{2} + {b}_{2}}\right) GLU(X)=(XW1+b1)σ(XW2+b2)
其中:

  • X X X 是输入特征。

  • W 1 , W 2 {W}_{1},{W}_{2} W1,W2 是两个不同的权重矩阵。

  • b 1 , b 2 {b}_{1},{b}_{2} b1,b2 是偏置。

  • σ \sigma σ 是 Sigmoid 函数,作为门控信号的激活函数。

  • ⊙ \odot 表示元素级别的乘法(Hadamard 积)。

在 GLU 中,输入特征 ( X ) 被分成两个部分:

  1. 线性部分: X ⋅ W 1 + b 1 X \cdot {W}_{1} + {b}_{1} XW1+b1

  2. 门控部分: σ ( X ⋅ W 2 + b 2 ) \sigma \left( {X \cdot {W}_{2} + {b}_{2}}\right) σ(XW2+b2)

门控部分使用 Sigmoid 函数输出一个介于 0 和 1 之间的权重矩阵,从而选择性地让部分信息通过。

因此, GLU 实际上是在学习一种信息过滤机制,以根据输入数据的特征动态调整信息的流动。

0193f875-8bf0-7806-ab37-494c1af37d61_6_640_98_1192_1229_0.jpg

优势

  • **动态信息选择:**门控机制允许模型根据输入特征选择性地传递信息,提高模型的灵活性。

  • **性能提升:**GLU 在自然语言处理任务中证明可以比传统激活函数(如 ReLU)获得更好的性能。

SwiGLU激活函数

SwiGLU(Swish-Gated Linear Unit)是一种结合Swish激活函数与GLU(Gated Linear Unit)机制的激活函数。其公式如下:

SwiGLU ⁡ ( x ) = ( Swish ⁡ ( W 1 x ) ) ⋅ ( W 2 x ) \operatorname{SwiGLU}\left( x\right) = \left( {\operatorname{Swish}\left( {{W}_{1}x}\right) }\right) \cdot \left( {{W}_{2}x}\right) SwiGLU(x)=(Swish(W1x))(W2x)
其中:

  • ${W}_{1} $ 和 W 2 {W}_{2} W2 分别是输入 X X X 的两个线性变换矩阵。
  • $Swish \left( z\right) = z \cdot \sigma \left( z\right) $,其中 σ ( z ) \sigma \left( z\right) σ(z) 是标准的Sigmoid函数。

该公式可以理解为:

  1. 对输入 $ x $ 进行两次线性变换。

  2. 将其中一个结果通过Swish激活函数。

  3. 将两个结果逐元素相乘,形成最终输出。

对于 β = 1 \beta=1 β=1 的 Swish函数,称之为 SiLU

image-20241224195224786

SwiGLU的计算代价与性能优势
class SwiGLU(nn.Module):

    def __init__(self, w1, w2, w3) -> None:
        super().__init__()
        self.w1 = w1
        self.w2 = w2
        self.w3 = w3

    def forward(self, x):
        x1 = F.linear(x, self.w1.weight)
        x2 = F.linear(x, self.w2.weight)
        hidden = F.silu(x1) * x2
        return F.linear(hidden, self.w3.weight)

SwiGLU激活函数需要进行三次矩阵乘法运算,相比于ReLU等传统激活函数计算复杂度更高。

然而,研究发现,尽管计算量增加, SwiGLU带来的性能提升显著:

  • **更好的泛化性能:**SwiGLU在处理复杂的文本数据时表现出更好的泛化能力。

  • **稳定的梯度流:**Swish激活函数平滑的曲线特性有助于稳定梯度,减少梯度消失和爆炸的可能性。

RoPE旋转位置编码

位置编码是Transformer中确保模型能够理解序列顺序信息的重要部分。

传统的绝对和相对位置编码方案各有优缺点,然而RoPE(Rotary Position Embedding)作为一种新型的位置编码方法,平衡了绝对和相对位置编码的优点。

位置编码的背景

位置编码在Transformer中至关重要,因为自注意力机制本质上是无序的。 常见的两种位置编码方法是:

  • 绝对位置编码 (Absolute Positional Encoding):为序列中的每个位置提供一个固定的嵌入。

  • 相对位置编码 (Relative Positional Encoding) : 表示序列中每两个token之间的相对位置信息。

绝对位置编码的特性

绝对位置编码是指在序列中的每个位置直接关联一个固定的嵌入。

例如, Transformer中常用的正弦-余弦绝对位置编码如下:
P E ( p o s , 2 i ) = sin ⁡ ( p o s 10000 2 i / d ) P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 10000 2 i / d ) P{E}_{\left( pos,2i\right) } = \sin \left( \frac{pos}{{10000}^{{2i}/d}}\right) \\ P{E}_{\left( pos,2i + 1\right) } = \cos \left( \frac{pos}{{10000}^{{2i}/d}}\right) PE(pos,2i)=sin(100002i/dpos)PE(pos,2i+1)=cos(100002i/dpos)
其中, pos 表示序列中的位置, d 表示嵌入维度。

这种方法通过给定的位置索引为每个位置提供唯一的编码,确保模型能够理解token之间的顺序。

相对位置编码的特性

相对位置编码关注的是序列中两个token之间的相对距离,而非绝对位置。

相对位置编码可以帮助模型在不同长度的输入序列之间共享信息,具体实现方式多种多样。例如,在 Transformer-XL模型中:

A i j = Q i ⋅ K j + Q i ⋅ r i − j {A}_{ij} = {Q}_{i} \cdot {K}_{j} + {Q}_{i} \cdot {r}_{i - j} Aij=QiKj+Qirij
其中, r i − j {r}_{i - j} rij 表示相对距离的嵌入,确保模型对不同长度的输入都能够表现良好。

RoPE的原理

RoPE (Rotary Position Embedding) 结合了绝对和相对位置编码的优点。

它使用旋转矩阵对每个位置进行编码,并直接将相对位置信息引入自注意力操作中。

旋转矩阵(Rotation Matrix)

image-20241224200444195

RoPE的数学公式

假设输入向量为 ( x \in {\mathbb{R}}^{d} ) ,其第 ( i ) 个位置的编码向量可以表示为:

x i ′ = RoPE ⁡ ( x i , i ) {x}_{i}^{\prime } = \operatorname{RoPE}\left( {{x}_{i}, i}\right) xi=RoPE(xi,i)

其中:

RoPE ⁡ ( x i , i ) = R ( i ) x i \operatorname{RoPE}\left( {{x}_{i}, i}\right) = R\left( i\right) {x}_{i} RoPE(xi,i)=R(i)xi
这里的旋转矩阵 R ( i ) R\left( i\right) R(i) 用二维旋转矩阵的张量乘积进行定义。

将向量 x i x_i xi 拆分为一系列长度为2的子向量 ( x i , 2 k , x i , 2 k + 1 ) \left( {{x}_{i,{2k}},{x}_{i,{2k} + 1}}\right) (xi,2k,xi,2k+1),

其旋转形式为:

R ( i ) [ x i , 2 k x i , 2 k + 1 ] = [ cos ⁡ ( θ k i ) − sin ⁡ ( θ k i ) sin ⁡ ( θ k i ) cos ⁡ ( θ k i ) ] [ x i , 2 k x i , 2 k + 1 ] R\left( i\right) \left\lbrack \begin{matrix} {x}_{i,{2k}} \\ {x}_{i,{2k} + 1} \end{matrix}\right\rbrack = \left\lbrack \begin{matrix} \cos \left( {{\theta }_{k}i}\right) & - \sin \left( {{\theta }_{k}i}\right) \\ \sin \left( {{\theta }_{k}i}\right) & \cos \left( {{\theta }_{k}i}\right) \end{matrix}\right\rbrack \left\lbrack \begin{matrix} {x}_{i,{2k}} \\ {x}_{i,{2k} + 1} \end{matrix}\right\rbrack R(i)[xi,2kxi,2k+1]=[cos(θki)sin(θki)sin(θki)cos(θki)][xi,2kxi,2k+1]
其中:

θ k = 10000 − 2 k / d π {\theta }_{k} = \frac{{10000}^{-{2k}/d}}{\pi } θk=π100002k/d

经过RoPE编码后的输入向量与旋转矩阵结合,使得位置信息被直接嵌入到输入向量中。

image-20241224200628304

0193f875-8bf0-7806-ab37-494c1af37d61_9_319_1628_1835_1030_0.jpg

RoPE的相对位置编码特性

在自注意力计算中, RoPE能够引入相对位置信息。

令 ${q}_{i} $ 和 k j {k}_{j} kj分别表示位置 i i i j j j 处的查询和键向量,经过RoPE编码后,其点积为:

⟨ q i ′ , k j ′ ⟩ = ⟨ R ( i ) q i , R ( j ) k j ⟩ = ⟨ q i , R ( j − i ) k j ⟩ \left\langle {{q}_{i}^{\prime },{k}_{j}^{\prime }}\right\rangle = \left\langle {R\left( i\right) {q}_{i}, R\left( j\right) {k}_{j}}\right\rangle = \left\langle {{q}_{i}, R\left( {j - i}\right) {k}_{j}}\right\rangle qi,kj=R(i)qi,R(j)kj=qi,R(ji)kj
这意味着,通过RoPE编码,查询和键向量之间的相对位置信息直接被融入自注意力操作中。 RoPE (Rotary Position Embedding) 结合了绝对和相对位置编码的优点,因为它既能够编码绝对位置信息,又能够在自注意力操作中有效融入相对位置信息。

RoPE如何结合绝对和相对位置编码

RoPE通过旋转矩阵为每个位置编码,实现了绝对位置信息的嵌入,同时在自注意力操作中有效地引入了相对位置信息。

RoPE的编码过程

给定一个嵌入维度为 ( \mathrm{d} ) 的输入向量 ( x \in {\mathbb{R}}^{d} ) ,其位置编码可以表示为:

x i ′ = RoPE ⁡ ( x i , i ) = R ( i ) ⋅ x i {x}_{i}^{\prime } = \operatorname{RoPE}\left( {{x}_{i}, i}\right) = R\left( i\right) \cdot {x}_{i} xi=RoPE(xi,i)=R(i)xi
其中,旋转矩阵 $ R\left( i\right) $ 用二维旋转矩阵的张量乘积形式定义。

假设 $ {x}{i} $ 可以分解为一系列长度为2的子向量 $ \left( {{x}{i,{2k}},{x}_{i,{2k} + 1}}\right)$ ,其旋转矩阵为:

R ( i ) [ x i , 2 k x i , 2 k + 1 ] = [ cos ⁡ ( θ k ⋅ i ) − sin ⁡ ( θ k ⋅ i ) sin ⁡ ( θ k ⋅ i ) cos ⁡ ( θ k ⋅ i ) ] [ x i , 2 k x i , 2 k + 1 ] R\left( i\right) \left\lbrack \begin{matrix} {x}_{i,{2k}} \\ {x}_{i,{2k} + 1} \end{matrix}\right\rbrack = \left\lbrack \begin{matrix} \cos \left( {{\theta }_{k} \cdot i}\right) & - \sin \left( {{\theta }_{k} \cdot i}\right) \\ \sin \left( {{\theta }_{k} \cdot i}\right) & \cos \left( {{\theta }_{k} \cdot i}\right) \end{matrix}\right\rbrack \left\lbrack \begin{matrix} {x}_{i,{2k}} \\ {x}_{i,{2k} + 1} \end{matrix}\right\rbrack R(i)[xi,2kxi,2k+1]=[cos(θki)sin(θki)sin(θki)cos(θki)][xi,2kxi,2k+1]
其中:

θ k = 10000 − 2 k / d π {\theta }_{k} = \frac{{10000}^{-{2k}/d}}{\pi } θk=π100002k/d
这个矩阵表示每个位置 i i i 的绝对编码。

RoPE引入相对位置编码

经过RoPE编码的查询和键向量,在自注意力机制中的点积计算如下:

⟨ q i ′ , k j ′ ⟩ = ⟨ R ( i ) q i , R ( j ) k j ⟩ = ⟨ q i , R ( j − i ) k j ⟩ \left\langle {{q}_{i}^{\prime },{k}_{j}^{\prime }}\right\rangle = \left\langle {R\left( i\right) {q}_{i}, R\left( j\right) {k}_{j}}\right\rangle = \left\langle {{q}_{i}, R\left( {j - i}\right) {k}_{j}}\right\rangle qi,kj=R(i)qi,R(j)kj=qi,R(ji)kj
其中:

R ( j − i ) = [ cos ⁡ ( θ k ⋅ ( j − i ) ) − sin ⁡ ( θ k ⋅ ( j − i ) ) sin ⁡ ( θ k ⋅ ( j − i ) ) cos ⁡ ( θ k ⋅ ( j − i ) ) ] R\left( {j - i}\right) = \left\lbrack \begin{matrix} \cos \left( {{\theta }_{k} \cdot \left( {j - i}\right) }\right) & - \sin \left( {{\theta }_{k} \cdot \left( {j - i}\right) }\right) \\ \sin \left( {{\theta }_{k} \cdot \left( {j - i}\right) }\right) & \cos \left( {{\theta }_{k} \cdot \left( {j - i}\right) }\right) \end{matrix}\right\rbrack R(ji)=[cos(θk(ji))sin(θk(ji))sin(θk(ji))cos(θk(ji))]

这意味着RoPE在计算查询和键之间的点积时,将相对位置信息直接融入了自注意力操作中。

RoPE的优点
  1. **引入相对位置信息:**RoPE能够在自注意力操作中直接编码相对位置信息,使模型具有较好的相对位置感知能力。

  2. **保持绝对位置信息:**RoPE的旋转矩阵编码每个位置的绝对信息,保持绝对位置感知能力。

  3. **高效处理长序列:**在长序列任务中, RoPE相较于其他位置编码方案表现出更高的效率和性能。

RoPE通过旋转矩阵将绝对位置与相对位置信息相结合,既提供了绝对位置信息的精确性,又具备相对位置感知的灵活性,成为LLaMA等大型语言模型中重要的位置编码方法。

RoPE的优化与实现

image-20241224200941087

优化策略
  1. **分块并行处理:**RoPE可以通过分块并行化矩阵运算来提高效率。

  2. **预计算旋转矩阵:**对于固定序列长度的任务,可以预计算并缓存旋转矩阵,减少实时计算量。

image-20241224200956950

image-20241224201012592

image-20241224201025758

分组查询注意力(GQA)

分组查询注意力(Grouped Query Attention,简称GQA)作为一种改进的多头自注意力机制,最近在 Llama 3模型中得到了应用。

GQA旨在平衡多头自注意力和多查询注意力的优缺点,以提高推理效率并保持高性能。

image-20241224201055789

多头自注意力机制 (Multi-Head Self-Attention)

多头自注意力机制是Transformer模型的核心组件,旨在通过并行化的方式捕获序列中的复杂关系。其计算公式如下:

  1. 查询(Query)、键(Key)和值(Value)向量:

Q = X W Q ,    K = X W K ,    V = X W V Q = X{W}_{Q},\;K = X{W}_{K},\;V = X{W}_{V} Q=XWQ,K=XWK,V=XWV

其中, W Q {W}_{Q} WQ W K {W}_{K} WK W V {W}_{V} WV 分别是查询、键和值的投影矩阵。

  1. 自注意力计算:

A = softmax ⁡ ( Q K T d k ) V A = \operatorname{softmax}\left( \frac{Q{K}^{T}}{\sqrt{{d}_{k}}}\right) V A=softmax(dk QKT)V

其中, d k d_k dk 为查询和键向量的维度。

  1. 多头组合:

Multi-Head ⁡ ( Q , K , V ) = Concat ⁡ ( head ⁡ 1 , … , head ⁡ h ) W O \operatorname{Multi-Head}\left( {Q, K, V}\right) = \operatorname{Concat}\left( {{\operatorname{head}}_{1},\ldots ,{\operatorname{head}}_{h}}\right) {W}_{O} Multi-Head(Q,K,V)=Concat(head1,,headh)WO

其中, $ {\text{head}}{i} = Attention \left( {Q{W}{{Q}{i}}, K{W}{{K}{i}}, V{W}{{V}_{i}}}\right) $ , W O W_O WO 是输出矩阵。

多查询注意力机制 (Multi-Query Attention)

多查询注意力机制简化了多头自注意力,将所有注意力头共享相同的键和值投影,从而减少计算和内存开销。其结构如下:

  1. 查询向量:

Q = X W Q Q = X{W}_{Q} Q=XWQ

  1. 共享的键和值向量:

K = X W K ,    V = X W V K = X{W}_{K},\;V = X{W}_{V} K=XWK,V=XWV

  1. 自注意力计算:

A i = softmax ⁡ ( Q W Q i K T d k ) V {A}_{i} = \operatorname{softmax}\left( \frac{Q{W}_{{Q}_{i}}{K}^{T}}{\sqrt{{d}_{k}}}\right) V Ai=softmax(dk QWQiKT)V

  1. 组合:

Multi-Query ⁡ ( Q , K , V ) = Concat ⁡ ( A 1 , … , A h ) W O \operatorname{Multi-Query}\left( {Q, K, V}\right) = \operatorname{Concat}\left( {{A}_{1},\ldots ,{A}_{h}}\right) {W}_{O} Multi-Query(Q,K,V)=Concat(A1,,Ah)WO

GQA (Grouped Query Attention) 的原理

GQA结合了多头自注意力和多查询注意力的优点,通过分组的方式共享键和值投影。

其结构如下:

  1. 查询向量:

Q = X W Q Q = X{W}_{Q} Q=XWQ

  1. 分组的键和值向量:

将总计 ( N ) 个注意力头划分为 ( G ) 组,每组共享相同的键和值投影:

K g = X W K g ,    V g = X W V g ,    g = 1 , … , G {K}_{g} = X{W}_{{K}_{g}},\;{V}_{g} = X{W}_{{V}_{g}},\;g = 1,\ldots , G Kg=XWKg,Vg=XWVg,g=1,,G

  1. 组内自注意力计算:

A g , i = softmax ⁡ ( Q W Q g , i K g T d k ) V g {A}_{g, i} = \operatorname{softmax}\left( \frac{Q{W}_{{Q}_{g, i}}{K}_{g}^{T}}{\sqrt{{d}_{k}}}\right) {V}_{g} Ag,i=softmax(dk QWQg,iKgT)Vg

  1. 组合:

Grouped-Query ( Q , K , V ) = Concat ⁡ ( A 1 , 1 , … , A G , h / G ) W O \text{Grouped-Query}\left( {Q, K, V}\right) = \operatorname{Concat}\left( {{A}_{1,1},\ldots ,{A}_{G, h/G}}\right) {W}_{O} Grouped-Query(Q,K,V)=Concat(A1,1,,AG,h/G)WO

其中:

  • $ {W}{{Q}{g, i}} $为每个查询头的投影矩阵

  • ${W}{{K}{g}} 和 {W}{{V}{g}} $ 为每组共享的键和值投影矩阵

GQA的优势
  • **计算效率:**减少了键和值的计算和内存开销。

  • **性能保持:**在性能上与多头自注意力相当,同时在效率上接近多查询注意力。

KVCache

在大型语言模型(LLM)的推理和训练过程中,缓存机制对于提高模型的效率和性能至关重要。

KVCache(Key-Value Cache)是一种用于缓存自注意力机制中键和值的优化策略,可以显著加速推理过程。

image-20241224201919132

image-20241224201922903

自注意力机制中的键和值

在Transformer的自注意力机制中,键和值的计算是核心步骤:

  1. 输入向量投影:
  • 查询向量 Q = X W Q Q = X{W}_{Q} Q=XWQ

  • 键向量 $K = X{W}_{K} $

  • 值向量 V = X W V V = X{W}_{V} V=XWV

  1. 自注意力计算:

A = softmax ⁡ ( Q K T d k ) V A = \operatorname{softmax}\left( \frac{Q{K}^{T}}{\sqrt{{d}_{k}}}\right) V A=softmax(dk QKT)V

其中, d k d_k dk 为键向量的维度。

多头自注意力机制

多头自注意力机制将上述计算扩展为并行计算多个注意力头:

Multi-Head ⁡ ( Q , K , V ) = Concat ⁡ ( head ⁡ 1 , … , head ⁡ h ) W O \operatorname{Multi-Head}\left( {Q, K, V}\right) = \operatorname{Concat}\left( {{\operatorname{head}}_{1},\ldots ,{\operatorname{head}}_{h}}\right) {W}_{O} Multi-Head(Q,K,V)=Concat(head1,,headh)WO
每个注意力头使用不同的投影矩阵,捕获序列中的不同关系。

推理过程中的重复计算

在推理过程中,模型可能会接收一系列的连续输入。

对于每个新的输入,模型会重新计算键和值,导致大量的重复计算。

这种情况在长文本生成或对话任务中尤其明显。

KVCache的原理

KVCache旨在减少自注意力机制中的重复计算。

它通过缓存先前计算的键和值,在推理过程中复用这些信息,从而显著提高推理效率。

KVCache的工作流程
  1. **初始化缓存:**初始化一个空的键和值缓存。

  2. 计算当前输入的键和值:

K t = X t W K ,    V t = X t W V {K}_{t} = {X}_{t}{W}_{K},\;{V}_{t} = {X}_{t}{W}_{V} Kt=XtWK,Vt=XtWV

  1. **更新缓存:**将当前输入的键和值追加到缓存中。

  2. **自注意力计算:**在新的键和值上进行自注意力计算:

A = softmax ⁡ ( Q K cache  T d k ) V cache  A = \operatorname{softmax}\left( \frac{Q{K}_{\text{cache }}^{T}}{\sqrt{{d}_{k}}}\right) {V}_{\text{cache }} A=softmax(dk QKcache T)Vcache 

  1. 清理缓存:

在某些情况下(如达到上下文长度上限),需要清理缓存或进行部分替换。

KVCache在LIama 3中的应用

Llama 3作为Meta发布的最新大型语言模型,采用了KVCache机制以提高推理效率。

Llama 3通过引入KVCache机制,将推理过程中的键和值缓存下来,在后续的注意力计算中复用这些缓存,以减少重复计算。

  • **键和值缓存:**在每一层的自注意力机制中,键和值的计算结果都会被缓存。

  • **缓存更新策略:**每次新的输入到来时,缓存会根据新的键和值进行更新,确保自注意力计算的准确性。

  • **缓存清理:**当达到上下文长度上限时,缓存会被部分或全部清理。

KVCache作为一种优化策略,通过缓存自注意力机制中的键和值,有效减少了推理过程中的重复计算。

LLM文本生成的两个阶段:预填充(prefill)和解码(decode)

在大型语言模型(LLM)生成文本的过程中,通常会涉及两个阶段:预填充(prefill)和解码(decode) 。

预填充(Prefill)阶段

预填充阶段主要用于准备初始上下文,为模型生成后续文本提供基础。

在这个阶段,模型会处理输入的初始文本(通常是用户提供的提示或上下文),并生成相应的内部状态 (如KV缓存)。

具体过程

  1. 初始上下文输入:
  • 用户提供一个初始文本作为提示,例如问题、句子或段落。
  • 该文本通常被编码为一系列token(tokens)。
  1. 初始上下文的注意力计算:
  • 模型对输入的所有token进行处理,生成键和值,并存储在KV缓存中。 3. 初始概率分布:

  • 计算最后一个token的概率分布,作为生成下一个token的起始点。

解码 (Decode) 阶段

解码阶段是模型根据预填充阶段准备好的上下文,生成后续文本的过程。

在这个阶段,模型会逐步生成每个新token,并在每次生成后更新KV缓存。

具体过程

  1. token生成:
  • 从预填充阶段的初始概率分布中采样一个新token,作为下一个输入。

  • 根据用户选择的策略(如贪心、采样、温度)进行采样。

  1. KV缓存更新:
  • 使用新token进行前向传递,更新KV缓存,生成新的键和值。

  • 生成新的概率分布,用于采样下一个token。

  1. 循环迭代:
  • 重复步骤1和步骤2,直到达到生成的最大长度或满足停止条件。

下面详细解释这两个阶段KVCache的大小和更新方法,以及它们的不同之处。

预填充阶段的KVCache

大小和布局

在预填充阶段, KVCache的大小根据输入的上下文长度和模型层数确定。

  1. 键和值的形状:
keys: (num_layers, batch_size, num_heads, seq_len, head_dim)
values: (num_layers, batch_size, num_heads, seq_len, head_dim)

其中:

  • num_layers:模型的总层数

  • batch_size:输入的批次大小

  • num_heads : 多头注意力机制的头数

  • seq_len: 输入序列的长度

  • head_dim:每个头的嵌入维度

  1. KVCache大小示例:

假设一个模型有 12 层,每层有 16 个头,每个头的维度为 64 ,批次大小为 1 ,上下文长度为512 ,则KVCache的大小如下:

  • 键缓存大小:(12,1,16,512,64)

  • 值缓存大小:(12,1,16,512,64)

更新方法

预填充阶段, KVCache会随着输入的每个token进行填充和更新。

初始填充: 对于给定的输入上下文, 模型计算每一层的键和值向量, 并将它们存储到缓存中

解码阶段的KVCache

大小和布局

在解码阶段, KVCache的大小主要取决于上下文的长度和逐步生成的长度。

与预填充阶段不同的是,解码阶段的KVCache随着生成的序列不断增长。

  1. 键和值的形状:
keys: (num_layers, batch_size, num_heads, max_seq_len, head_dim)
values: (num_layers, batch_size, num_heads, max_seq_len, head_dim)

其中:

  • max_seq_len: 预填充序列长度 + 生成序列长度
  1. KVCache大小示例:

假设初始输入长度为 512 ,逐步生成长度为 100 ,则KVCache的最大长度为 612。

  • 键缓存大小:(12,1,16,612,64)

  • 值缓存大小:(12,1,16,612,64)

总结
  • KVCache的大小:

    • 预填充阶段:缓存的大小取决于初始输入的长度。
    • 解码阶段:缓存的大小随着生成序列长度的增长而不断增加。
  • KVCache的更新方式:

    • 预填充阶段:一次性将整个初始上下文填充到缓存中。
    • 解码阶段:逐步生成并更新缓存,每次生成新token时追加新的键和值。

逐步生成长度(Incremental Generation Length)指的是语言模型在解码阶段, 根据输入的初始文本(如提示词prompt), 一个token一个token地逐步生成输出序列的过程。

在解码开始时, 语言模型只有输入的初始文本,如"用英文写一首关于春天的诗"。然后模型根据这个输入, 预测并生成下一个最可能的token,如"Spring"。之后"Spring"会被添加到之前的输入中, 形成新的输入序列,模型再根据这个更新的输入去预测下一个token,如"is"。如此重复,直到满足某个停止条件(如遇到截止符<\s>, 或达到最大生成长度), 整个解码过程就结束了。

这个过程中, 生成的token序列(如"Spring is coming")的长度, 就被称为逐步生成长度。它一般远小于模型的最大序列长度限制。比如最大序列长度为2048的模型, 在实际解码时, 往往只需要生成几十到几百个 token就可以得到满意的结果。

在解码阶段引入键值缓存(KV Cache), 主要就是为了提高处理这些逐步生成token的效率。随着解码的进行, 每生成一个新的token, 就动态地将其添加到KV Cache中, 更新键值对。这样模型就可以得到到不断增长的已生成序列,从而实现循环机制。

所以总结来说, 逐步生成长度就是语言模型解码时动态生成token序列的长度, 反映了解码的进度。

合理设置逐步生成长度, 优化解码算法和KV Cache机制, 对于提升语言模型推理速度和生成质量至关重要。

总结

使用KV Cache,在 token-by-token 推理中:

  • 初始token 的计算复杂度为 O ( d 2 ) O\left( {d}^{2}\right) O(d2)

  • 后续每个token 的计算复杂度为 O ( N d ) O\left( {Nd}\right) O(Nd) ,而不是重新计算整个序列,从而避免了二次计算复杂度。

其中 N N N 是序列长度, d d d 是每个序列元素的维度。

因此,通过使用KV Cache,可以将token-by-token推理的计算复杂度从 O ( N 2 d ) O\left({{N}^{2}d}\right) O(N2d) 降低到接近线性关系,

这种优化使得大模型能够处理更长的序列,并在推理时显著提高了效率。

需要注意的是,这里讨论的是推理阶段的计算复杂度。

在训练阶段,自注意力的计算复杂度仍然是 O ( N 2 ) O\left({{N}^{2}}\right) O(N2),但是可以对整个序列进行并行计算

在实际应用中, 推理阶段的效率往往更为重要, 因为它直接影响了模型的响应速度和吞吐量。


http://www.kler.cn/a/463572.html

相关文章:

  • Golang的容器编排实践
  • ubuntu切换到root用户
  • “AI智慧教学系统:开启个性化教育新时代
  • MT8788安卓核心板_MTK8788核心板参数_联发科模块定制开发
  • websocket-sharp:.NET平台上的WebSocket客户端与服务器开源库
  • 【GO基础学习】gin的使用
  • springboot520基于Spring Boot的民宿租赁系统的设计与实现(论文+源码)_kaic
  • 安卓入门四 Application Component
  • ubuntu2204 gpu 没接显示器,如何连接vnc
  • JnetPcap抓取数据包IP数据包
  • 3、redis的集群模式
  • selenium 安装Chrome驱动
  • 用点包图洞察医学数据:以血压分析为例
  • 服务器网卡绑定mode和交换机的对应关系
  • MySQL 索引优化实战 – 结合 Explain 深度解析慢查询
  • REST与RPC的对比:从性能到扩展性的全面分析
  • 非关系型数据库和关系型数据库的区别
  • 免登录游客卡密发放系统PHP网站源码
  • Excel 面试 01 “Highlight in red the 10 lowest orders”
  • 如何查看下载到本地的大模型的具体大小?占了多少存储空间:Llama-3.1-8B下载到本地大概15GB
  • 系统架构风险、敏感点和权衡点的理解
  • MySQL数据库笔记——主从复制
  • Redis 实战篇 ——《黑马点评》(上)
  • 关于内网服务器依托可上网电脑实现访问互联网
  • MyBatis使用的设计模式
  • 【每日学点鸿蒙知识】输入法按压效果、web组件回弹、H5回退问题、Flex限制两行、密码输入自定义样式