当前位置: 首页 > article >正文

注意力机制:让AI拥有黄金七秒记忆的魔法--(注意力机制中的Q、K、V)

注意力机制:让AI拥有"黄金七秒记忆"的魔法–(注意力机制中的Q、K、V)

在注意⼒机制中,查询(Query)、键(Key)和值(Value)是三个关键部分。

■ 查询(Query):是指当前需要处理的信息。模型根据查询向量在输⼊序列中查找相关信息。

■ 键(Key):是指来⾃输⼊序列的⼀组表示。它们⽤于根据查询向量计算注意⼒权重。注意⼒权重反映了不同位置的输⼊数据与查询的相关性。

■ 值(Value):是指来⾃输⼊序列的⼀组表示。它们⽤于根据注意⼒权重计算加权和,得到最终的注意⼒输出向量,其包含了与查询最相关的输⼊信息。

用下面栗子打一个比方:

import torch # 导入 torch
import torch.nn.functional as F # 导入 nn.functional
# 1. 创建两个张量 x1 和 x2
x1 = torch.randn(2, 3, 4) # 形状 (batch_size, seq_len1, feature_dim)
x2 = torch.randn(2, 5, 4) # 形状 (batch_size, seq_len2, feature_dim)
# 2. 计算原始权重
raw_weights = torch.bmm(x1, x2.transpose(1, 2)) # 形状 (batch_size, seq_len1, seq_len2)
# 3. 用 softmax 函数对原始权重进行归一化
attn_weights = F.softmax(raw_weights, dim=2) # 形状 (batch_size, seq_len1, seq_len2)
# 4. 将注意力权重与 x2 相乘,计算加权和
attn_output = torch.bmm(attn_weights, x2)  # 形状 (batch_size, seq_len1, feature_dim)

我们可以将x1视为查询(Query,Q)向量,将x2视为键(Key,K)和值(Value,V)向量。这是因为我们直接使⽤x1和x2的点积作为相似度得分,并将权重应⽤于x2本身来计算加权信息。所以,在这个简化示例中,Q对应于x1,KV都对应于x2。

然⽽,在Transformer中,QKV通常是从相同的输⼊序列经过不同的线性变换得到的不同向量。

import torch
import torch.nn.functional as F
#1. 创建 Query、Key 和 Value 张量
q = torch.randn(2, 3, 4) # 形状 (batch_size, seq_len1, feature_dim)
k = torch.randn(2, 4, 4) # 形状 (batch_size, seq_len2, feature_dim)
v = torch.randn(2, 4, 4) # 形状 (batch_size, seq_len2, feature_dim)
# 2. 计算点积,得到原始权重,形状为 (batch_size, seq_len1, seq_len2)
raw_weights = torch.bmm(q, k.transpose(1, 2))
# 3. 将原始权重进行缩放(可选),形状仍为 (batch_size, seq_len1, seq_len2)
scaling_factor = q.size(-1) ** 0.5
scaled_weights = raw_weights / scaling_factor
# 4. 应用 softmax 函数,使结果的值在 0 和 1 之间,且每一行的和为 1
attn_weights = F.softmax(scaled_weights, dim=-1) # 形状仍为 (batch_size, seq_len1, seq_len2)
# 5. 与 Value 相乘,得到注意力分布的加权和 , 形状为 (batch_size, seq_len1, feature_dim)
attn_output = torch.bmm(attn_weights, v)

KV的维度是否需完全相同呢?

在缩放点积注意⼒中,KV向量的维度不⼀定需要完全相同。在这种注意⼒机制中,KV的序列⻓度维度(在这⾥是第2维)应该相同,因为它们描述了同⼀个序列的不同部分。然⽽,它们的特征(或隐藏层)维度(在这⾥是第3维)可以不同。V向量的第⼆个维度则决定了最终输出张量的特征维度,这个维度可以根据具体任务和模型设计进⾏调整。

K向量的序列⻓度维度(在这⾥是第2维)和Q向量的序列⻓度维度可以不同,因为它们可以来⾃不同的输⼊序列,但是,K向量的特征维度(在这⾥是第3维)需要与Q向量的特征维度相同,因为它们之间要计算点积。

在实践中,KV的各个维度通常是相同的,因为它们通常来⾃同⼀个输⼊序列并经过不同的线性变换。

在注意力机制中,k(Key)和 v(Value)的初始值并不是随机产生的,而是由输入数据经过各自的线性变换得到的。具体来说:

来源相同但变换不同:

  • 假设我们有一个输入序列的表示矩阵 X(例如编码器的输出或者词嵌入),

  • 我们通过三个不同的线性层(也就是不同的权重矩阵)分别计算 Query、Key 和 Value:

    • q = X W q q=XW_q q=XWq
    • k = X W k k=XW_k k=XWk
    • v = X W v v= XW_v v=XWv
  • 这里, W q W_q Wq W k W_k Wk W v W_v Wv 是模型在训练过程中学习到的参数矩阵。

  • 确定方式
    这些矩阵 W k W_k Wk W v W_v Wv 在模型设计时就被定义好,并在训练过程中通过反向传播进行更新。

  • 作用不同

    • Key (k):通过 W k W_k Wk 得到,用来与 Query 进行匹配,计算注意力分数,决定输入中哪些部分对当前 Query 最重要。
    • Value (v):通过 W v W_v Wv得到,它携带的是具体的信息内容,最终会根据注意力分数被加权求和,形成输出的上下文向量。
  • k 与 v 的初始值都源自相同的输入 X,但它们经过了各自独立的线性变换,参数 W k W_k Wk W v W_v Wv 决定了它们具体的数值和表示。

  • 这两个过程是在训练过程中自动学习并调整的,确保模型能够有效地捕捉和利用输入信息。

这样,通过学习到的权重矩阵,模型可以从输入中抽取出适合进行匹配(Key)和传递信息(Value)的表示。

现在,重写缩放点积注意⼒的计算过程,如下所述。

(1)计算Q向量和K向量的点积。

(2)将点积结果除以缩放因⼦(Q向量特征维度的平⽅根)。

(3)应⽤softmax函数得到注意⼒权重。

(4)使⽤注意⼒权重对V向量进⾏加权求和。

这个过程的图示如下⻚图所示:

image-20250315220205401

具体到编码器-解码器注意⼒来说,可以这样理解QKV向量。

Q向量代表了解码器在当前时间步的表示,⽤于和K向量进⾏匹配,以计算注意⼒权重Q向量通常是解码器隐藏状态的线性变换

K向量是编码器输出的⼀种表示,⽤于和Q向量进⾏匹配,以确定哪些编码器输出对于当前解码器时间步来说最相关K向量通常是编码器隐藏状态的线性变换

V向量是编码器输出的另⼀种表示,⽤于计算加权求和,⽣成注意⼒上下⽂向量。注意⼒权重会作⽤在V向量上,以便在解码过程中关注输⼊序列中的特定部分。V向量通常也是编码器隐藏状态的线性变换

在刚才的编码器-解码器注意⼒示例中,直接使⽤了编码器隐藏状态和解码器隐藏状态来计算注意⼒。这⾥的QKV向量并没有显式地表示出来(⽽且,此处KV是同⼀个向量),但它们的概念仍然隐含在实现中:

■ 编码器隐藏状态(encoder_hidden_states)充当了KV向量的⻆⾊。

■ 解码器隐藏状态(decoder_hidden_states)充当了Q向量的⻆⾊。

我们计算Q向量(解码器隐藏状态)与K向量(编码器隐藏状态)之间的点积来得到注意⼒权重,然后⽤这些权重对V向量(编码器隐藏状态)进⾏加权求和,得到上下⽂向量。

当然了,在⼀些更复杂的注意⼒机制(如Transformer中的多头⾃注意⼒机制)中,QKV向量通常会更明确地表示出来,因为我们需要通过使⽤不同的线性层将相同的输⼊序列显式地映射到不同的QKV向量空间。

V向量表示值,⽤于计算加权信息。通过将注意⼒权重应⽤于V向量,我们可以获取输⼊序列中与Q向量相关的信息。它们(QKV)其实都是输⼊序列,有时是编码器输⼊序列,有时是解码器输⼊序列,有时是神经⽹络中的隐藏状态(也来⾃输⼊序列)的线性表示,也都是序列的“嵌⼊向量”。


http://www.kler.cn/a/595865.html

相关文章:

  • 广度优先搜索(BFS)完全解析:从原理到 Java 实战
  • 分布式中间件:RabbitMQ确认消费机制
  • Ubuntu 22.04 上配置 ufw(Uncomplicated Firewall)防火墙的详细步骤
  • watch方法解析
  • win32汇编环境,网络编程入门之八
  • 20250319在荣品的PRO-RK3566开发板的buildroot系统下使用集成的QT应用调试串口UART3
  • 深度学习与计算机视觉方向
  • docker、docker-compose常用命令
  • 【C#高级编程】—表达式树详解
  • k8s自动弹性伸缩之HPA实践
  • 网络编程——套接字、创建服务器、创建客户端
  • 挑战用AI替代我的工作——从抢券困境到技术突破
  • C# System.Text.Encoding 使用详解
  • 支持向量机(Support Vector Machine)基础知识1
  • 普通鼠标的500连击的工具来了!!!
  • 【AIGC】Win10系统极速部署Docker+Ragflow+Dify
  • 最新!Ubuntu Docker 安装教程
  • C 语 言 --- 扫 雷 游 戏(初 阶 版)
  • 系统思考—链接组织效能提升与问题解决
  • OSPF 协议详解:从概念原理到配置实践的全网互通实现