【AI学习】Transformer深入学习(二):从MHA、MQA、GQA到MLA
前面文章:
《Transformer深入学习(一):Sinusoidal位置编码的精妙》
一、MHA、MQA、GQA
为了降低KV cache,MQA、GQA作为MHA的变体,很容易理解。
多头注意力(MHA):
多头注意力是一种在Transformer架构中广泛使用的注意力机制,通过将查询、键和值分别投影到多个不同的空间上,然后并行计算这些空间上的注意力得分,从而获得更加丰富和细致的特征表示。
多查询注意力(MQA):
多查询注意力是MHA的一种变种,它通过共享单个key和value头来提升性能,但可能会导致质量下降和训练不稳定。MQA在保持速度的同时提高了模型的推理效率,但在某些情况下可能无法达到与MHA相同的效果。
分组查询注意力(GQA):
分组查询注意力是MQA和MHA之间的过渡方法,旨在同时保持MQA的速度和MHA的质量。GQA通过使用中间数量的键值头(大于一个,小于查询头的数量),实现了性能和速度的平衡。具体来说,GQA通过分组的方式减少了需要处理的头数,从而降低了内存需求和计算复杂度。
分组查询注意力(Grouped-Query Attention,简称GQA)是一种用于提高大模型推理可扩展性的机制。其具体实现机制如下:
1、基本概念:GQA是多头注意力(Multi-Head Attention,MHA)的变种,通过将查询头(query heads)分成多个组来减少内存带宽的需求。每个组共享一个键头(key head)和一个值头(value head),从而降低了每个处理步骤中加载解码器权重和注意力键/值的内存消耗。
2、实现方式:在实际应用中,GQA将查询头分成G个组,每组共享一个键头和一个值头。例如,GQA-G表示有G个组,而GQA-1则表示只有一个组,这相当于传统的MQA(Multi-Group Query Attention)。当GQA的组数等于查询头的数量时,它等同于标准的MHA。
3、性能与效率平衡:GQA通过这种方式有效地平衡了性能和内存需求。它允许模型在不显著降低性能的情况下,处理更多的请求并提高推理效率。此外,使用GQA可以避免由于加载大量解码器权重和注意力键/值而导致的内存瓶颈问题
二、MLA
2.1 基础原理
这张图,对从MHA、MQA、GQA到MLA,看的很清楚。
GQA就是用了多组KV Cahe,MQA只用了一组KV Cache。
那MLA呢?MLA看起来和MHA是一样的,只不过存的压缩后的隐KV,在计算的时候再通过投影倒多个KV参与注意力计算。
为什么会节省KV Cache?苏神的文章解释的很清楚。
看下面的公式,MLA公式如下:
其中的c就是压缩后的隐KV。
但是这样好像无法节省KV Cache,因为计算和MHA一样了,关键在于下面的转换公式:
这个公式把注意力的计算做了转换,k的投影矩阵这样就可以合并倒q的投影矩阵中。
另外,因为注意力之后的o还有一个投影矩阵,也可以合并到后面的投影矩阵中。
而c作为压缩后的隐KV,是所有头共享的,这样就实现了内存的节省。
2.2 增加RoPE
但是,如上面,矩阵合并之后,就和RoPE不兼容了,具体看苏神的分析文章。
MLA采取了一种混合的方法——每个 Attention Head的 Q、K 新增 dr个维度用来添加 RoPE,其中 K 新增的维度每个 Head 共享:
因为dr远小于dk,所以增加的内存空间不大。
2.3 最后的版本
MLA 的最终版本,还将 Q 的输入也改为了低秩投影形式,可以减少训练期间参数量和相应的梯度的显存。
MLA这种方法,在训练阶段还是照常进行,此时优化空间不大;在推理阶段,应该可以大幅减少显存。
见苏神的分析:“ MLA 在推理阶段做的这个转换,虽然能有效减少 KV Cache,但其推理的计算量是增加的。
那为什么还能提高推理效率呢?这又回到“瓶颈”一节所讨论的问题了,我们可以将 LLM 的推理分两部分:第一个 Token 的生成(Prefill)和后续每个 Token 的生成(Generation)。
Prefill 阶段涉及到对输入所有 Token 的并行计算,然后把对应的 KV Cache 存下来,这部分对于计算、带宽和显存都是瓶颈,MLA 虽然增大了计算量,但 KV Cache 的减少也降低了显存和带宽的压力,大家半斤八两;但是 Generation 阶段由于每步只计算一个 Token,实际上它更多的是带宽瓶颈和显存瓶颈,因此 MLA 的引入理论上能明显提高 Generation 的速度。”
三、参考文章
苏神:《缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA》
https://mp.weixin.qq.com/s/yCczYU0po0PvPTa-eh2pfg
《大模型KV Cache节省神器MLA学习笔记》
https://mp.weixin.qq.com/s/cBMrRUdM1IM0T1ji_ODxng
《注意力机制的变体之MLA》
https://mp.weixin.qq.com/s/dWZk8TBY89re207ZL3GjfA