当前位置: 首页 > article >正文

LLM中的Attention实现及优化

作者:phynlp
原文:https://zhuanlan.zhihu.com/p/15348185464

Multi-Head Attention

图片

MHA 原理示意图

Attention的计算复杂度与文本长度的二次方成正比,相关的计算过程如下。

1.Embedding lookup table: 输入文本长度为n(n个token),经过embedding table后,每一个token返回一个大小为(1, d)的向量,d对应embedding的维度大小。对于长度为n的文本,embedding table给出embedding matrix矩阵 大小为(n, d)

2.MHA (Multi head Attention):

  • • 作用:MHA (Multi Head Attention layer)的目标在于重构文本中的token的embedding表示,重构的出发点在于(1)考虑context中token的语义相关性(2)token的位置相关性(通过position embedding 体现)。

  • • 映射计算:embedding matrix X 进入MHA后,会并行处理h个head的attention layer。处理过程如下,X通过线性映射为h个head,得到Attention heads的Q、K、V。对应的维度信息为Q(h, n, k)、K(h, n, k) 和V(h, n, v),其中hk=hv=O(d)。实现过程相当于X(n, d)与h个维度为(d, k)、(d, k)和(d, v)的矩阵相乘得到,对应的时间复杂为 。矩阵相乘(a, b)(b, c)的时间复杂度为  ,h对应矩阵个数也就是head个数。

  • • Attention计算:简单来讲,n个token每个都要与其他token进行Attention的计算,时间复杂度为 O(n^2) 。详细的计算过程为: QK转置相乘(经过 Softmax 归一)然后与V相乘,获取最终的embedding。Q(h, n, k)、K(h, n, k)计算复杂度为 。attention score维度为(h,n,n),与V(h, n, v)相乘的时间复杂度为

通过上述分析可以看出MHA的整体复杂度为  。

MHA的整体复杂度与context 长度 n的二次方成正比,与模型的规模d(embedding size)的二次方成正比。

增大context的长度,会带来计算复杂度的二次方增大。

Attention实现机制优化

Multi-Query Attention (MQA)

对于multi-head attention,每个head对应的k矩阵和v矩阵不同,所以对于每个token都有h(head数目)个k矩阵和v矩阵。

在模型推理的过程中,为了防止重新计算,会缓存之前token对应的Keys和Values。因此GPU显存占用会随着预测的token数目而增加。

Multi-Query Attention 通过在不同head中共享K和V,即不同的head具有相同的key和value,降低了存储的k矩阵和v矩阵的数目,对于每个token存储的matrix数目由2h个,降低为两个matrix。同事也降低了计算复杂度。

Multi-Query Attention极大的提高了推理速度。

Group Query Attention

图片

 

Group Query Attention是对所有head的Query分组为不同的group,对一个group内的query,共享key和value。GQA的效果与MHA的效果相当,训练速度与MQA相当,提高了训练速度的同时,效果相比MQA有提高。

GQA 的实现

图片

 

# init时k和v用self.num_key_value_heads * self.head_dim初始化,当self.num_key_value_heads小于self.num_heads时,参数量变少
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)

# forward时,通过repeat_kv方法,将hidden states 从(batch, num_key_value_heads, seqlen, head_dim) 变成 (batch, num_attention_heads, seqlen, head_dim),相当于是复制了self.num_key_value_groups份
self.num_key_value_groups = self.num_heads // self.num_key_value_heads

key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

内存开销计算

使用 MHA 结构的自回归模型,在推理过程中,会维护一个巨大的 k/v cache。它的内存开销公式为:

batch * max_seq_len * n_heads * head_dim * sizeof(half) * 2

而对于 GQA 来说,k/v cache 的内存开销公式变成:

batch * max_seq_len * n_kv_heads * head_dim * sizeof(half) * 2

n_heads / n_kv_heads 就是 group 的大小。可见,使用 GQA 可以把 k/v cache 降低到 MHA的1/group 的水平。非常利好 Attention 这种访存密集型的计算。

SWA (Sliding Window Attention)

通过优化attention的实现,降低attention与context length的长度依赖关系。这种对attention结构的优化,会同时提升训练和推理的性能。

sliding window attention 将计算复杂度由变为。

注意力的时间复杂度是序列长度的二次方,空间复杂度是序列长度的一次方。在推理时,由于缓存的可用性降低,会造成更高的延迟和更小的吞吐量。为了减少这样的问题,提出了窗口注意力机制,在每一个注意力层每个token最多能注意前W个token。

图片

 

图片

 

注意力的传递通过层数的增加而向后传递。每一层注意力层,信息可以传递W tokens。经过两层注意力层,信息可以传递2W tokens。比如对于16k 序列长度和4k的滑动窗口,通过4层,信息可以实现整个序列长度的传递。因此序列越长,在滑动窗口长度固定的情况下,为了实现整个序列长度的传递,需要的注意力层数越多。

Attention底层实现优化

FlashAttention

FlashAttention 解决attention计算过程中,频繁访问HBM的问题,将attention计算block化,直接在SRAM中进行。

在GPU中底层对算子的优化,会同时提升模型的训练和推理性能。考虑到Attention在GPU的计算过程以及GPU的结构,优化Attention在GPU中的实现。

GPU中的两个核心部分,SRAM运算速度快但是存储量小,HBM运算速度慢但是存储量大。GPU中的operation运算过程,是从HBM拷贝数据进行运算,完成运算后再将数据存储到HBM.

图片

 

FlashAttention通过以下两个操作实现了attention的加速实现。

  • • 1.利用了GPU中存储的差异性。将数据从HBM拷贝到SRAM中,计算时从SRAM中读取数据,SRAM相比HBM读取和写入速度更快。

  • • 2.SRAM相比HBM速度快,但是存储量小,因此采用分块block的形式计算QK的矩阵乘法。即实现了并行block的softmax计算。为了保证分块block计算的softmax值与原有的softmax值不变,采用了block的 softmax计算。

FlashAttention将Q、K和V切分为block,进行block的计算,提高operation的处理速度。

图片

 

PagedAttention

PagedAttention解决attention计算过程中的内存分配问题,防止内存的浪费,更好的分配内存,可以实现更大的batch size和吞吐量。

传统KV Cache存在的问题主要包括:

1.显存占用大:对于大型模型如LLaMA-13B中的单个序列,KV Cache可能占用高达1.7GB的内存。

2.动态变化:KV Cache的大小取决于序列长度,而序列长度具有高度可变和不可预测的特点,这对有效管理KV Cache构成挑战。

3.内存碎片化和过度预留:由于显存碎片和过度预留,现有系统浪费了60%-80%的显存。

4.内部碎片化:在静态批处理策略下,一个请求结束后,其剩余的空间就被浪费掉了

5.外部碎片化:由于KV Cache是一个巨大的矩阵,且必须占用连续内存,操作系统如果只分配大的连续内存,势必有很多小的内存空间被浪费掉

PagedAttention 的优势

相当于小空间内存的动态分配,可以实现非连续的内存存储,解决了传统KV Cache连续动态内存分配造成的内存空间浪费。

https://cloud.tencent.com/developer/article/2316226https://arxiv.org/pdf/2309.06180.pdfhttps://arxiv.org/pdf/2305.13245.pdfhttps://zhuanlan.zhihu.com/p/672698614https://zhuanlan.zhihu.com/p/626079753

赞赏二维码


http://www.kler.cn/a/469983.html

相关文章:

  • 行为树详解(6)——黑板模式
  • 小白学Pytorch
  • 面试经典150题——链表(二)
  • 【Java数据结构】二叉树
  • Gitee图形界面上传(详细步骤)
  • 软件工程大复习之(四)——面向对象与UML
  • 【 算法设计与分析-回顾算法知识点】福建师范大学数学与计算机科学学院 2006 — 2007学年第二学期考试 A 卷
  • Spark和Mapreduce对比
  • SpringBoot开发——内置的 ObjectUtils 工具类详解
  • 【C++】类和对象(下):友元、static成员、内部类、explicit 和 匿名对象
  • VUE3配置后端地址,实现前后端分离及开发、正式环境分离
  • 【C++】const关键字_运算符重载_继承
  • Spring Boot教程之四十七:Spring Boot ——JDBC
  • BMS应用软件开发 — 2 单体电池的基本结构和工作原理
  • linux redis/: Permission denied,当前用户对该目录没有访问权限
  • Jdbc笔记01
  • 探索报表软件的世界:山海鲸、Tableau与Power BI比较
  • 头文件iostream的一些函数使用
  • 半导体数据分析: 玩转WM-811K Wafermap 数据集(二) AI 机器学习
  • ElasticSearch基础-文章目录
  • mapreduce 工作流程
  • 头歌python实验:网络安全应用实践-恶意流量检测
  • 【FTP 协议】FTP主动模式
  • Rabbitmq消息补偿机制
  • 【机器学习】从监督学习的懵懂起步至迁移学习的前沿瞭望
  • iOS - 自定义引用计数(MRC)