当前位置: 首页 > article >正文

【AI学习】Transformer深入学习(二):从MHA、MQA、GQA到MLA

前面文章:
《Transformer深入学习(一):Sinusoidal位置编码的精妙》

一、MHA、MQA、GQA

为了降低KV cache,MQA、GQA作为MHA的变体,很容易理解。
多头注意力(MHA):
多头注意力是一种在Transformer架构中广泛使用的注意力机制,通过将查询、键和值分别投影到多个不同的空间上,然后并行计算这些空间上的注意力得分,从而获得更加丰富和细致的特征表示。

多查询注意力(MQA)
多查询注意力是MHA的一种变种,它通过共享单个key和value头来提升性能,但可能会导致质量下降和训练不稳定。MQA在保持速度的同时提高了模型的推理效率,但在某些情况下可能无法达到与MHA相同的效果。

分组查询注意力(GQA)
分组查询注意力是MQA和MHA之间的过渡方法,旨在同时保持MQA的速度和MHA的质量。GQA通过使用中间数量的键值头(大于一个,小于查询头的数量),实现了性能和速度的平衡。具体来说,GQA通过分组的方式减少了需要处理的头数,从而降低了内存需求和计算复杂度。

分组查询注意力(Grouped-Query Attention,简称GQA)是一种用于提高大模型推理可扩展性的机制。其具体实现机制如下:

1、基本概念:GQA是多头注意力(Multi-Head Attention,MHA)的变种,通过将查询头(query heads)分成多个组来减少内存带宽的需求。每个组共享一个键头(key head)和一个值头(value head),从而降低了每个处理步骤中加载解码器权重和注意力键/值的内存消耗。

2、实现方式:在实际应用中,GQA将查询头分成G个组,每组共享一个键头和一个值头。例如,GQA-G表示有G个组,而GQA-1则表示只有一个组,这相当于传统的MQA(Multi-Group Query Attention)。当GQA的组数等于查询头的数量时,它等同于标准的MHA。

3、性能与效率平衡:GQA通过这种方式有效地平衡了性能和内存需求。它允许模型在不显著降低性能的情况下,处理更多的请求并提高推理效率。此外,使用GQA可以避免由于加载大量解码器权重和注意力键/值而导致的内存瓶颈问题

二、MLA

2.1 基础原理

在这里插入图片描述
这张图,对从MHA、MQA、GQA到MLA,看的很清楚。
GQA就是用了多组KV Cahe,MQA只用了一组KV Cache。
那MLA呢?MLA看起来和MHA是一样的,只不过存的压缩后的隐KV,在计算的时候再通过投影倒多个KV参与注意力计算。
为什么会节省KV Cache?苏神的文章解释的很清楚。
看下面的公式,MLA公式如下:
在这里插入图片描述

其中的c就是压缩后的隐KV。
但是这样好像无法节省KV Cache,因为计算和MHA一样了,关键在于下面的转换公式:
在这里插入图片描述
这个公式把注意力的计算做了转换,k的投影矩阵这样就可以合并倒q的投影矩阵中。
另外,因为注意力之后的o还有一个投影矩阵,在这里插入图片描述也可以合并到后面的投影矩阵中。
c作为压缩后的隐KV,是所有头共享的,这样就实现了内存的节省

2.2 增加RoPE

但是,如上面,矩阵合并之后,就和RoPE不兼容了,具体看苏神的分析文章。
MLA采取了一种混合的方法——每个 Attention Head的 Q、K 新增 dr个维度用来添加 RoPE,其中 K 新增的维度每个 Head 共享:
在这里插入图片描述
因为dr远小于dk,所以增加的内存空间不大。

2.3 最后的版本

MLA 的最终版本,还将 Q 的输入也改为了低秩投影形式,可以减少训练期间参数量和相应的梯度的显存。
在这里插入图片描述
MLA这种方法,在训练阶段还是照常进行,此时优化空间不大;在推理阶段,应该可以大幅减少显存。
见苏神的分析:“ MLA 在推理阶段做的这个转换,虽然能有效减少 KV Cache,但其推理的计算量是增加的。
那为什么还能提高推理效率呢?这又回到“瓶颈”一节所讨论的问题了,我们可以将 LLM 的推理分两部分:第一个 Token 的生成(Prefill)和后续每个 Token 的生成(Generation)。
Prefill 阶段涉及到对输入所有 Token 的并行计算,然后把对应的 KV Cache 存下来,这部分对于计算、带宽和显存都是瓶颈,MLA 虽然增大了计算量,但 KV Cache 的减少也降低了显存和带宽的压力,大家半斤八两;但是 Generation 阶段由于每步只计算一个 Token,实际上它更多的是带宽瓶颈和显存瓶颈,因此 MLA 的引入理论上能明显提高 Generation 的速度。”

三、参考文章

苏神:《缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA》
https://mp.weixin.qq.com/s/yCczYU0po0PvPTa-eh2pfg

《大模型KV Cache节省神器MLA学习笔记》
https://mp.weixin.qq.com/s/cBMrRUdM1IM0T1ji_ODxng

《注意力机制的变体之MLA》
https://mp.weixin.qq.com/s/dWZk8TBY89re207ZL3GjfA


http://www.kler.cn/a/467192.html

相关文章:

  • 算法:两个升序单链表的合并
  • 【计算机网络】什么是AC和AP?
  • Which CAM is Better for Extracting Geographic Objects? A Perspective From参考文献
  • ArcGIS Server 10.2授权文件过期处理
  • 百度Android最新150道面试题及参考答案 (中)
  • 机器人对物体重定向操作的发展简述
  • 阿里云-通义灵码:在 PyCharm 中的强大助力(下)
  • 急需升级,D-Link 路由器漏洞被僵尸网络广泛用于 DDoS 攻击
  • GPIO、RCC库函数
  • 104周六复盘 (188)UI
  • perl包安装的CPAN大坑
  • SQL-【DDL+DML】
  • 30分钟学会HTML
  • vscode下载vetur和vue-helper插件之后删除键(backspace)失效
  • Java十六
  • 【Web】极简快速入门Vue 3
  • 05-spring-理-bean的生命周期
  • RuoYi-Vue从http升级为https(Jar+Nginx)
  • 金毛可以穷养吗?
  • GESP真题 | 2024年12月1级-编程题4《美丽数字》及答案(Python版)
  • SpringBoot框架开发中常用的注解
  • 工具学习_社区检测算法
  • 基于gin一个还算比较优雅的controller实现
  • 深度学习-80-大语言模型LLM之基于streamlit与ollama的API开发本地聊天工具
  • 使用MySQL SLES存储库安装MYSQL
  • 计算机网络:网络层知识点及习题(二)