当前位置: 首页 > article >正文

Transformer | 一文了解:缩放、批量、多头、掩码、交叉注意力机制(Attention)

源自: AINLPer(每日干货分享!!)
编辑: ShuYini
校稿: ShuYini
时间: 2025-3-27

更多:>>>>专注大模型/AIGC、学术前沿的知识分享!

引言

之前的文章:2万字长文!一文了解Attention,从MHA到DeepSeek MLA,大量图解,非常详细!主要分享了Attention算法的演化。但是对于基础的Attention算法的细节整理的不够详细,今天这篇文章填补上这一点,并利用纯Python和Numpy实现注意力模块,并解释了整个过程中的所有向量维度的变化,对刚入门的新手非常友好。文章安排如下:

  • 基础缩放Attention
  • 批量Attention
  • 多头Attention
  • 掩码Attention
  • 交叉Attention
  • 跨头维度向量化

另外,感谢卢卡3541对于文章:2万字长文!一文了解Attention… GQA部分的错误指正,留言回复已经置顶,大家看的时候可以结合,后续作者回整理到一块。

基础缩放Attention

我们先从最基础的缩放点积自注意力开始,它只针对单个序列的Token进行操作,不涉及掩码。

输入是一个形状为(N,D)的二维数组,如下图所示。N是序列的长度(即它包含的Token数量),D是嵌入维度——每个Token的嵌入向量长度。D的值可以是512,或者更大,具体取决于模型。

img

自注意力模块通过三个权重矩阵WkWqWv进行参数化。某些变体还配有偏置向量,这里为方便解释,因此这里将其省略了。一般情况下,每个权重矩阵的形状为(D,HS),其中HS是D的某个分数。HS代表“头大小”,稍后我们会明白它的含义。这是一个自注意力模块的示意图(图中假设N=6,D是一个较大的数字,HS也是如此,一般情况下这里的D=HS)。在图中,@表示矩阵乘法(Python/Numpy语法):

img

以下是该算法的基本Numpy实现:

# x is the input (N, D), each token in a row.
# Each of W* is a weight matrix of shape (D, HS)
# The result is (N, HS)
def self_attention(x, Wk, Wq, Wv):
    # Each of these is (N, D) @ (D, HS) = (N, HS)
    q = x @ Wq
    k = x @ Wk
    v = x @ Wv

    # kq: (N, N) matrix of dot products between each pair of q and k vectors.
    # The division by sqrt(HS) is the scaling.
    kq = q @ k.T / np.sqrt(k.shape[1])

    # att: (N, N) attention matrix. The rows become the weights that sum
    # to 1 for each output vector.
    att = softmax_lastdim(kq)
    return att @ v  # (N, HS)

“缩放”部分只是将kq除以HS的平方根,这样做的目的是为了使点积的值保持在可控范围内(否则它们会随着收缩维度的大小而增长)【面试的时候,面试官可能会问这个问题:为什么计算完 Q K T QK^T QKT要除以 d \sqrt{d} d 】。

唯一的依赖是一个用于计算输入数组最后一维Softmax的函数:

def softmax_lastdim(x):
    """Compute softmax across last dimension of x.

    x is an arbitrary array with at least two dimensions. The returned array has
    the same shape as x, but its elements sum up to 1 across the last dimension.
    """
    # Subtract the max for numerical stability
    ex = np.exp(x - np.max(x, axis=-1, keepdims=True))
    # Divide by sums across last dimension
    return ex / np.sum(ex, axis=-1, keepdims=True)

当输入是二维时,“最后一维”指的是列。通俗地说,这个Softmax函数分别对x的每一行进行操作它将Softmax公式应用于每行的元素,最终得到一行介于[0,1]之间的数字,这些数字加起来等于1。

再提一下维度的问题:Wv矩阵的第二维可以与WqWk不同。如果你看看示意图,你会发现这也能行得通,因为Softmax产生的结果是 ( N , N ) (N,N) NN,而不管V的第二维是什么,输出的第二维就会是什么。我们可以将k、v的维度分别记作 d k d_k dk d v d_v dv,但我们可以发现,目前几乎所有的Attention计算方法中,这两个维度通常也是相同的,即 d k = d v d_k=d_v dk=dv。因此为了简化,本文中将它们都设为D;如果需要,对代码进行修改以实现不同的 d k d_k dk d v d_v dv也是相当简单的。

批量自注意力

再进一步!在现实世界中,输入数组不太可能是二维的,因为模型是在输入序列的批次上进行训练的。为了利用现代硬件的并行性,通常会在同一个操作中处理整个批次。

img

批量缩放自注意力与非批量版本非常相似,这要归功于Numpy矩阵乘法和广播的魔力。现在输入的形状是(B,N,D),其中B是批次维度。W*矩阵的形状仍然是(D,HS);将一个(B,N,D)数组乘以(D,HS)会在第一个数组的最后一维和第二个数组的第一维之间进行收缩,结果为(B,N,HS)。以下是带有维度标注的代码:

# self_attention with inputs that have a batch dimension.
# x has shape (B, N, D)
# Each of W* has shape (D, D)
def self_attention_batched(x, Wk, Wq, Wv):
    q = x @ Wq  # (B, N, HS)
    k = x @ Wk  # (B, N, HS)
    v = x @ Wv  # (B, N, HS)

    kq = q @ k.swapaxes(-2, -1) / np.sqrt(k.shape[-1])  # (B, N, N)

    att = softmax_lastdim(kq)  # (B, N, N)
    return att @ v  # (B, N, HS)

与非批量版本唯一的区别在于计算kq的那行代码:

  • 由于k不再是二维的,“转置”的概念变得模糊不清,因此我们明确要求交换最后一维和倒数第二维,同时保留第一维度(B)。
  • 在计算缩放因子时,我们使用k.shape[-1]来选择k的_最后一维_,而不是k.shape[1](后者仅适用于二维数组)。

实际上,这个函数也可以计算非批量版本!从现在开始,我们将假设所有输入都是批量的,所有操作都是隐式的批量操作。后面将不会再在函数中使用“批量”前缀或后缀了。

自注意力模块的基本思想是将序列中Token的多维表示进行调整,以更好地表示整个序列。这些Token相互“关注”。具体来说,Softmax操作产生的矩阵被称为_注意力矩阵_。它是一个(N,N)的矩阵;对于每个Token,它指定了在序列中应考虑来自其它Token的多少信息。例如,矩阵中第(R,C)个单元格的值越高,就意味着序列中索引为R的Token与索引为C的Token之间的关系越强。

下面这个例子展示了单词序列以及两个注意力头(紫色和棕色)为输入序列中的某个位置产生的权重:

img

这个例子展示了模型是如何学习解决句子中“its”所指代的内容的。以紫色的头为例。序列中Token“its”的索引是8,而“Law”的索引是1。在这个头的注意力矩阵中,索引为(8,1)的值将非常高(接近1),而同一行中的其他值则会低得多。

虽然这种直观的解释对于理解注意力的实现并不是至关重要的,但在我们稍后讨论_掩码_自注意力时,它会变得更加重要。

多头注意力

上面看到的注意力机制只有一组K、Q和V矩阵。这被称为一个“头”的注意力。在当今的模型中,通常有多个头。每个头分别执行其注意力任务,最终将所有这些结果连接起来并通过一个线性层

在下文中,NH代表头的数量,HS代表头的大小。通常,NH乘以HS等于D;例如,D=512维度情况下,可能有以下几种配置:NH=8且HS=64,NH=32且HS=16,等等。然而,即使情况并非如此,数学计算仍然可行,因为最终的线性(“投影”)层将输出映射回(N,D)

假设前面的示意图展示的是一个具有输入(N,D)和输出(N,HS)的单头自注意力模块,以下是多个头的组合方式:

img

每个头都有其自己的Q、K和V权重矩阵。每个注意力头输出一个(N,HS)矩阵;这些矩阵沿着最后一维连接起来,形成(N,NH*HS),然后通过最终的线性投影。

以下是一个实现(批量)多头注意力的函数;下面可以暂时忽略do_mask条件内的代码:

# x has shape (B, N, D)
# In what follows:
#   NH = number of heads
#   HS = head size
# Each W*s is a list of NH weight matrices of shape (D, HS).
# Wp is a weight matrix for the final linear projection, of shape (NH * HS, D)
# The result is (B, N, D)
# If do_mask is True, each attention head is masked from attending to future
# tokens.
def multihead_attention_list(x, Wqs, Wks, Wvs, Wp, do_mask=False):
    # Check shapes.
    NH = len(Wks)
    HS = Wks[0].shape[1]
    assert len(Wks) == len(Wqs) == len(Wvs)
    for W in Wqs + Wks + Wvs:
        assert W.shape[1] == HS
    assert Wp.shape[0] == NH * HS

    # List of head outputs
    head_outs = []

    if do_mask:
        # mask is a lower-triangular (N, N) matrix, with zeros above
        # the diagonal and ones on the diagonal and below.
        N = x.shape[1]
        mask = np.tril(np.ones((N, N)))

    for Wk, Wq, Wv in zip(Wks, Wqs, Wvs):
        # Calculate self attention for each head separately
        q = x @ Wq  # (B, N, HS)
        k = x @ Wk  # (B, N, HS)
        v = x @ Wv  # (B, N, HS)

        kq = q @ k.swapaxes(-2, -1) / np.sqrt(k.shape[-1])  # (B, N, N)

        if do_mask:
            # Set the masked positions to -inf, to ensure that a token isn't
            # affected by tokens that come after it in the softmax.
            kq = np.where(mask == 0, -np.inf, kq)

        att = softmax_lastdim(kq)  # (B, N, N)
        head_outs.append(att @ v)  # (B, N, HS)

    # Concatenate the head outputs and apply the final linear projection
    all_heads = np.concatenate(head_outs, axis=-1)  # (B, N, NH * HS)
    return all_heads @ Wp  # (B, N, D)

这里是通过numpy实现的多头注意力,我们在源码中看到的基本上都是pytorch写的,可以对比一下面这个代码:

def forward(self, query, key, value):
        """
        多头注意力的前向传播。
        :param query: 查询张量,形状为 [batch_size, seq_len_q, embed_dim]
        :param key: 键张量,形状为 [batch_size, seq_len_k, embed_dim]
        :param value: 值张量,形状为 [batch_size, seq_len_k, embed_dim]
        :return: 输出张量,形状为 [batch_size, seq_len_q, embed_dim]
        """
        batch_size = query.shape[0]

        # 将输入映射到 Query、Key 和 Value
        Q = self.query_linear(query)  # [batch_size, seq_len_q, embed_dim]
        K = self.key_linear(key)      # [batch_size, seq_len_k, embed_dim]
        V = self.value_linear(value)  # [batch_size, seq_len_k, embed_dim]

        # 分割成多个头
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # [batch_size, num_heads, seq_len_q, head_dim]
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # [batch_size, num_heads, seq_len_k, head_dim]
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)  # [batch_size, num_heads, seq_len_k, head_dim]

        # 计算点积注意力分数
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # [batch_size, num_heads, seq_len_q, seq_len_k]

        # 应用 Softmax 函数,得到注意力权重
        attention_weights = F.softmax(attention_scores, dim=-1)  # [batch_size, num_heads, seq_len_q, seq_len_k]

        # 加权求和,得到每个头的输出
        output = torch.matmul(attention_weights, V)  # [batch_size, num_heads, seq_len_q, head_dim]

        # 合并所有头的输出
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)  # [batch_size, seq_len_q, embed_dim]

        # 通过输出的线性层
        output = self.out(output)  # [batch_size, seq_len_q, embed_dim]

        return output, attention_weights

掩码自注意力

注意力模块可用于编码器和解码器块。编码器块适用于语言理解或翻译等任务;对于这些任务,序列中的每个Token关注序列中的其他所有Token是有意义的。

然而,对于生成模型来说,这会带来一个问题:如果在训练过程中一个词关注了未来的词,模型就会“作弊”,而无法真正学会如何仅从过去的词生成下一个词。这在解码器块中完成,为此我们需要为注意力添加掩码。

从概念上讲,掩码非常简单。考虑以下句子:

People like watching funny cat videos

当我们的注意力代码生成att矩阵时,它是一个(N,N)的方阵,包含序列中每个Token到其他每个Token的注意力权重:
img
我们希望图中所有灰色单元格的值都为零,以确保一个Token不会关注未来的Token。蓝色单元格在经过Softmax操作后,每一行的值加起来等于1。

现在再看看前面的代码示例,看看当do_mask=True时会发生什么:

  1. 首先,准备一个(N,N)的下三角数组,其对角线上方的值为零,对角线及下方的值为一。
  2. 然后,在将缩放后的kq传递给Softmax之前,将掩码矩阵为0的位置的值设置为-np.inf。这确保了Softmax函数会将这些索引处的输出值设为零,同时仍然在行的其余部分产生正确的值。

掩码自注意力的另一个名称是因果自注意力

交叉注意力

到目前为止,我们一直在看自注意力模块,其中的“自”表明输入序列中的元素会关注同一输入序列中的其他元素

注意力的另一种变体是交叉注意力,其中一个序列的元素会关注另一个序列中的元素。这种变体一般会存在于具有解码器块的模型中。这是一个单头交叉注意力的示意图:

cross-attention with different Nq, Nv

这里我们有两个可能长度不同的序列:xqxvxq用于注意力的请求部分,而xv用于键和值部分。其余的维度保持不变。这种模块的输出形状为(Nq,HS),一般情况下D和HS是一样的,这里HS是为了方便解释。

以下是一个实现多头交叉注意力的函数;它没有包括掩码,因为在交叉注意力中通常不需要掩码,因为xq的元素可以关注xv的所有元素:

# Cross attention between two input sequences that can have different lengths.
# xq has shape (B, Nq, D)
# xv has shape (B, Nv, D)
# In what follows:
#   NH = number of heads
#   HS = head size
# Each W*s is a list of NH weight matrices of shape (D, HS).
# Wp is a weight matrix for the final linear projection, of shape (NH * HS, D)
# The result is (B, Nq, D)
def multihead_cross_attention_list(xq, xv, Wqs, Wks, Wvs, Wp):
    # Check shapes.
    NH = len(Wks)
    HS = Wks[0].shape[1]
    assert len(Wks) == len(Wqs) == len(Wvs)
    for W in Wqs + Wks + Wvs:
        assert W.shape[1] == HS
    assert Wp.shape[0] == NH * HS

    # List of head outputs
    head_outs = []

    for Wk, Wq, Wv in zip(Wks, Wqs, Wvs):
        q = xq @ Wq  # (B, Nq, HS)
        k = xv @ Wk  # (B, Nv, HS)
        v = xv @ Wv  # (B, Nv, HS)

        kq = q @ k.swapaxes(-2, -1) / np.sqrt(k.shape[-1])  # (B, Nq, Nv)

        att = softmax_lastdim(kq)  # (B, Nq, Nv)
        head_outs.append(att @ v)  # (B, Nq, HS)

    # Concatenate the head outputs and apply the final linear projection
    all_heads = np.concatenate(head_outs, axis=-1)  # (B, Nq, NH * HS)
    return all_heads @ Wp  # (B, Nq, D)

跨头维度向量化

前面展示的multihead_attention_list实现使用权重矩阵的列表作为输入。虽然这使代码更清晰,但对于优化实现(特别是在GPU和TPU等加速器上)来说,并不是一个特别友好的格式。我们可以通过为注意力头创建一个新维度来进一步向量化它。(目前几乎所有的Attention算法都是基于这种方式来实现的!

为了理解所使用的技巧,考虑一个基本的矩阵乘法,(8,6)乘以(6,2):
basic matrix multiplication
现在假设我们想将LHS乘以另一个(6,2)矩阵。我们可以通过沿着列将两个RHS矩阵连接起来,在同一个操作中完成它们的乘法:
concatenated basic matrix multiplication
如果两个图中的黄色RHS块是相同的,那么结果中的绿色块也将是相同的。紫色块只是LHS与RHS的红色块的矩阵乘法的结果。这源于矩阵乘法的语义,很容易能够理解。

现在回到我们的多头注意力。注意,我们将输入x乘以一整个权重矩阵——事实上,是乘以三个权重矩阵(一个用于Q,一个用于K,另一个用于V)。我们可以使用相同的向量化技巧,将所有这些权重矩阵连接成一个单独的矩阵。假设NH*HS=D,那么组合矩阵的形状为(D,3*D)。以下是向量化的实现:

# x has shape (B, N, D)
# In what follows:
#   NH = number of heads
#   HS = head size
#   NH * HS = D
# W is expected to have shape (D, 3 * D), with all the weight matrices for
# Qs, Ks, and Vs concatenated along the last dimension, in this order.
# Wp is a weight matrix for the final linear projection, of shape (D, D).
# The result is (B, N, D).
# If do_mask is True, each attention head is masked from attending to future
# tokens.
def multihead_attention_vec(x, W, NH, Wp, do_mask=False):
    B, N, D = x.shape
    assert W.shape == (D, 3 * D)
    qkv = x @ W  # (B, N, 3 * D)
    q, k, v = np.split(qkv, 3, axis=-1)  # (B, N, D) each

    if do_mask:
        # mask is a lower-triangular (N, N) matrix, with zeros above
        # the diagonal and ones on the diagonal and below.
        mask = np.tril(np.ones((N, N)))

    HS = D // NH
    q = q.reshape(B, N, NH, HS).transpose(0, 2, 1, 3)  # (B, NH, N, HS)
    k = k.reshape(B, N, NH, HS).transpose(0, 2, 1, 3)  # (B, NH, N, HS)
    v = v.reshape(B, N, NH, HS).transpose(0, 2, 1, 3)  # (B, NH, N, HS)

    kq = q @ k.swapaxes(-1, -2) / np.sqrt(k.shape[-1])  # (B, NH, N, N)

    if do_mask:
        # Set the masked positions to -inf, to ensure that a token isn't
        # affected by tokens that come after it in the softmax.
        kq = np.where(mask == 0, -np.inf, kq)

    att = softmax_lastdim(kq)  # (B, NH, N, N)
    out = att @ v  # (B, NH, N, HS)
    return out.transpose(0, 2, 1, 3).reshape(B, N, D) @ Wp  # (B, N, D)

这段代码通过单次矩阵乘法计算Q、K和V,然后将它们拆分成单独的数组。

Q、K和V最初是(B,N,D),因此通过首先将D拆分成(NH,HS),然后改变维度的顺序,将它们重塑成更方便的形状,得到(B,NH,N,HS)。以这种格式,kq的计算可以像以前一样进行,Numpy将自动在所有批次维度上执行矩阵乘法。

有时你会在论文中看到用于这些矩阵乘法的另一种符号:numpy.einsum 或者torch.einsum。例如,在我们最后一个代码示例中,kq的计算也可以写成:

kq = np.einsum("bhqd,bhkd->bhqk", q, k) / np.sqrt(k.shape[-1])

为了更好的理解,这里给一个torch.einsum使用样例:

import torch
A = torch.randn(3, 4)
B = torch.randn(4, 5)
C = torch.einsum("ij,jk->ik", A, B)

这里的 “ij,jk->ik” 表示:

  • 输入张量 A 的索引为 i,j。
  • 输入张量 B 的索引为 j,k。
  • 输出张量 C 的索引为 i,k。

更多:>>>>专注大模型/AIGC、学术前沿的知识分享!

推荐阅读

[1] 2025年的风口!| 万字长文让你了解大模型Agent
[2] 大模型Agent的 “USB”接口!| 一文详细了解MCP(模型上下文协议)
[3] 盘点一下!大模型Agent的花式玩法,涉及娱乐、金融、新闻、软件等各个行业
[4] 一文了解大模型Function Calling
[5] 万字长文!最全面的大模型Attention介绍,含DeepSeek MLA,含大量图示!
[6]一文带你详细了解:大模型MoE架构(含DeepSeek MoE详解)
[7] 颠覆大模型归一化!Meta | 提出动态Tanh:DyT,无归一化的 Transformer 性能更强


http://www.kler.cn/a/611830.html

相关文章:

  • DMA 之FIFO的作用
  • .NET开源的智能体相关项目推荐
  • c#的反射和特性
  • Docker实现MySQL主从复制配置【简易版】
  • 旅游纵览杂志旅游纵览杂志社旅游纵览编辑部2025年第2期目录
  • 微服务与分布式系统
  • Axure设计之中继器表格——拖动列调整位置教程(中继器)
  • python文件保存
  • Nextjs15 - 服务端组件(RSC)与客服端组件
  • SVTAV1热点函数-svt_ext_all_sad_calculation_8x8_16x16_avx2
  • python面试-基础
  • thinkphp8.0\swoole的websocket应用
  • vue配置.eslintrc、.prettierrc详解
  • Android 问真八字-v2.1.7[看八字APP]
  • Netty源码—8.编解码原理二
  • 2025年具有AI招聘管理系统选型及攻略分享
  • Rust从入门到精通之入门篇:8.基本数据结构
  • 快速入手-基于Django-rest-framework的mixins混合类(五)
  • 自然语言处理(NLP)技术的应用面有哪些
  • 如何卸载雷池 WAF