Attention系列笔记
- Attention:
- self Attention:
- Multi-head Attention
- Multi-query Attention
1.Attention:
首先计算查询向量Q 和键向量 K 的点积。通过点积,可以衡量查询和键之间的相似性。然后对点积结果进行缩放,即除以根号dk,对点积结果进行 softmax 归一化,得到每个查询对所有键的权重;最后,用归一化后的权重对值 V 进行加权求和,得到最终的注意力输出.
除以根号d_k这个缩放步骤是为了避免在高维空间中,点积值过大,导致 softmax 的梯度过小,造成模型训练不稳定。
import torch
import torch.nn as nn
import torch.nn.functional as F
class ScaledDotProductAttention(nn.Module):
def __init__(self):
super(ScaledDotProductAttention, self).__init__()
def forward(self, Q, K, V, mask=None):
"""
:param Q: 查询矩阵,形状为 [batch_size, seq_len_q, d_k]
:param K: 键矩阵,形状为 [batch_size, seq_len_k, d_k]
:param V: 值矩阵,形状为 [batch_size, seq_len_v, d_v]
:param mask: 掩码矩阵,可选,用于遮挡某些位置 [batch_size, seq_len_q, seq_len_k]
:return: 注意力输出和权重矩阵
"""
# 计算 Q 和 K 的点积
d_k = Q.size(-1) # 获取 d_k 的维度
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
# scores: [batch_size, seq_len_q, seq_len_k]
# 如果有掩码,应用掩码,将被掩码的位置设置为一个大负数
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 计算注意力权重
attention_weights = F.softmax(scores, dim=-1) # [batch_size, seq_len_q, seq_len_k]
# 根据权重计算输出
output = torch.matmul(attention_weights, V) # [batch_size, seq_len_q, d_v]
return output, attention_weights
# 初始化
attention = ScaledDotProductAttention()
# 示例输入
batch_size = 2
seq_len_q = 3
seq_len_k = 4
d_k = 5
d_v = 6
Q = torch.randn(batch_size, seq_len_q, d_k) # 查询矩阵
K = torch.randn(batch_size, seq_len_k, d_k) # 键矩阵
V = torch.randn(batch_size, seq_len_k, d_v) # 值矩阵
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]) # 掩码(示例)
# 扩展掩码形状到 [batch_size, seq_len_q, seq_len_k]
mask = mask.unsqueeze(1).expand(batch_size, seq_len_q, seq_len_k)
# 前向传播
output, attention_weights = attention(Q, K, V, mask)
print("输出:", output)
print("注意力权重:", attention_weights)
2 self-attention
自注意力的目标是根据输入序列中的每个位置的特征,计算该位置与序列中其他位置的相关性。这种机制能够动态地关注输入序列中的重要部分,而不是局限于固定大小的上下文窗口。
import torch
import torch.nn as nn
import torch.nn.functional as F
class SingleHeadSelfAttention(nn.Module):
def __init__(self, embed_dim):
"""
:param embed_dim: 输入特征维度
"""
super(SingleHeadSelfAttention, self).__init__()
# 定义线性变换矩阵
self.W_Q = nn.Linear(embed_dim, embed_dim)
self.W_K = nn.Linear(embed_dim, embed_dim)
self.W_V = nn.Linear(embed_dim, embed_dim)
self.W_O = nn.Linear(embed_dim, embed_dim) # 输出投影
def forward(self, X, mask=None):
"""
:param X: 输入序列,形状为 [batch_size, seq_len, embed_dim]
:param mask: 掩码矩阵,形状为 [batch_size, seq_len, seq_len],可选
:return: 自注意力输出和注意力权重
"""
batch_size, seq_len, embed_dim = X.size()
# 线性变换生成 Q, K, V
Q = self.W_Q(X) # [batch_size, seq_len, embed_dim]
K = self.W_K(X) # [batch_size, seq_len, embed_dim]
V = self.W_V(X) # [batch_size, seq_len, embed_dim]
# 计算缩放点积注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(embed_dim, dtype=torch.float32))
# scores: [batch_size, seq_len, seq_len]
# 如果有掩码,应用掩码
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 注意力权重
attention_weights = F.softmax(scores, dim=-1) # [batch_size, seq_len, seq_len]
# 加权求和生成输出
attention_output = torch.matmul(attention_weights, V) # [batch_size, seq_len, embed_dim]
# 输出线性变换
output = self.W_O(attention_output) # [batch_size, seq_len, embed_dim]
return output, attention_weights
# 示例输入
batch_size = 2
seq_len = 5
embed_dim = 16
X = torch.randn(batch_size, seq_len, embed_dim) # 输入序列
# 初始化自注意力模块
self_attention = SingleHeadSelfAttention(embed_dim=embed_dim)
# 前向传播
output, attention_weights = self_attention(X)
print("自注意力输出:", output.shape) # [batch_size, seq_len, embed_dim]
print("注意力权重:", attention_weights.shape) # [batch_size, seq_len, seq_len]
3 多头注意力
多头注意力(Multi-Head Attention)是 Transformer 的核心组件之一,它通过并行多个独立的注意力头来提升模型的表达能力和捕获不同特征的能力,输入与参数定义:
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
# Learnable weights
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model)
def forward(self, X):
batch_size = X.size(0)
# Linear transformations
Q = self.W_Q(X) # [batch_size, seq_len, d_model]
K = self.W_K(X)
V = self.W_V(X)
# Split into multiple heads
Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2) # [batch_size, num_heads, seq_len, d_k]
K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# Scaled dot-product attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
attention = torch.softmax(scores, dim=-1) # [batch_size, num_heads, seq_len, seq_len]
head_output = torch.matmul(attention, V) # [batch_size, num_heads, seq_len, d_k]
# Concatenate heads and project
concat_heads = head_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
output = self.W_O(concat_heads) # [batch_size, seq_len, d_model]
return output
import torch
import torch.nn as nn
# 定义参数
embed_dim = 16 # 嵌入维度
num_heads = 4 # 注意力头数
seq_len = 5 # 序列长度
batch_size = 2 # 批大小
# 创建多头注意力模块
multihead_attn = MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)
# 示例输入
X = torch.randn(batch_size, seq_len, embed_dim) # 输入序列 [batch_size, seq_len, embed_dim]
# 注意力 mask,可选
# 用于屏蔽某些位置,防止注意力关注不需要的位置
attn_mask = torch.ones(seq_len, seq_len) # [seq_len, seq_len]
attn_mask = torch.triu(attn_mask, diagonal=1) # 上三角矩阵,掩盖后续时间步
attn_mask = attn_mask.masked_fill(attn_mask == 1, float('-inf')) # 将掩盖位置设置为 -inf
# 前向传播
output, attention_weights = multihead_attn(X, attn_mask=attn_mask)
# 打印结果
print("多头注意力输出:", output.shape) # [batch_size, seq_len, embed_dim]
print("注意力权重:", attention_weights.shape) # [batch_size, num_heads, seq_len, seq_len]
- multi-query Attention
多查询注意力的核心思想是,在多个注意力头(attention head)中,共享一组注意力键(keys)和值(values),而每个头仍然保留各自独立的查询(queries)。这种方法减少了内存使用和计算成本,同时性能损失较小。
优势
内存效率:减少了注意力机制的内存开销。
计算速度:加速了注意力计算,特别是在解码器为主的模型中。
可扩展性:适用于扩展Transformer模型到更大的数据集,或在资源有限的环境中部署。
应用场景
广泛应用于机器翻译、大型语言模型以及其他需要高效序列处理的任务。
在诸如OpenAI的GPT系列和Google的Transformer模型等现代大规模模型中非常重要。
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiQueryAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
"""
:param embed_dim: 输入的嵌入维度
:param num_heads: 注意力头的数量
"""
super(MultiQueryAttention, self).__init__()
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# 独立的查询线性变换
self.W_Q = nn.Linear(embed_dim, embed_dim)
# 共享的键和值线性变换
self.W_K = nn.Linear(embed_dim, self.head_dim)
self.W_V = nn.Linear(embed_dim, self.head_dim)
# 输出线性投影
self.W_O = nn.Linear(embed_dim, embed_dim)
def forward(self, X):
"""
:param X: 输入序列,形状为 [batch_size, seq_len, embed_dim]
:return: 输出和注意力权重
"""
batch_size, seq_len, embed_dim = X.size()
# 独立的 Query
Q = self.W_Q(X) # [batch_size, seq_len, embed_dim]
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# 共享的 Key 和 Value
K = self.W_K(X) # [batch_size, seq_len, head_dim]
V = self.W_V(X) # [batch_size, seq_len, head_dim]
K = K.unsqueeze(1) # [batch_size, 1, seq_len, head_dim]
V = V.unsqueeze(1) # [batch_size, 1, seq_len, head_dim]
# 缩放点积注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))
attention_weights = F.softmax(scores, dim=-1) # [batch_size, num_heads, seq_len, seq_len]
attention_output = torch.matmul(attention_weights, V) # [batch_size, num_heads, seq_len, head_dim]
# 合并多头输出
attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
# 输出投影
output = self.W_O(attention_output) # [batch_size, seq_len, embed_dim]
return output, attention_weights
补充:
Mask 的引入
在某些场景下,Attention 的权重计算需要引入 Mask,用于限制模型对某些位置的访问。两种主要的 Mask 是:
a. Padding Mask
用于避免对填充(padding)位置的注意力计算。
常用于处理不定长序列,防止对无意义的填充值进行建模。
b. Causal Mask (或 Look-Ahead Mask)
用于自回归模型,确保当前时间步 t 仅能关注时间步 ≤t 的输入。
防止未来信息泄露到当前步。
公式中的具体实现是在计算注意力分数之前,将不允许的位置设置为非常大的负值(如 −∞),这样 softmax 计算时,权重会接近零。
import torch
import torch.nn.functional as F
def masked_attention(Q, K, V, mask):
d_k = Q.size(-1) # Key 的维度
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf')) # 应用 Mask
weights = F.softmax(scores, dim=-1) # 注意力分布
return torch.matmul(weights, V) # 加权求和得到输出
具体步骤:
mask == 0:生成一个布尔数组,表示哪些位置应该被“屏蔽”。具体来说,mask == 0 生成一个与 mask 形状相同的布尔矩阵,在需要屏蔽的地方(即 Mask 为 0 的位置)值为 True,其余为 False。
scores.masked_fill(mask == 0, float(‘-inf’)):
scores 是计算出的注意力分数,表示查询(Query)与键(Key)之间的相似度。
masked_fill 是 PyTorch 的函数,能够根据条件填充数组中的某些位置。在这里,我们将 mask == 0 对应的 scores 的位置替换为负无穷(-inf)。
为什么使用 float(‘-inf’):
使用负无穷是因为在应用 softmax 时,负无穷值会导致该位置的 softmax 权重变为零。这保证了在计算注意力时,模型不会关注该位置,也就是说,模型不会使用被遮蔽的词的表示。
举个例子,假设有一个位置被 Mask 遮蔽了,如果不进行这种处理,softmax 计算时可能会为该位置分配一些注意力权重(例如,如果它没有被屏蔽)。但通过将该位置的分数设置为负无穷,softmax 计算时该位置的权重将被压缩到几乎为零,确保该位置不会影响最终结果。
应用场景
自回归语言模型:如 GPT 和 GPT-2 中,使用 Causal Mask 确保模型生成第 t+1 个词时不访问
t+2 及以后的词。
序列分类或机器翻译:结合 Padding Mask 处理可变长度输入。