Multi-Head Attention Illustrated: Tracking the Dimension Changes in One Take
Table of Contents
- I. Overview of Multi-Head Attention
- II. Code Implementations
  - 1. PyTorch Implementation
  - 2. TensorFlow Implementation
- III. Dimension Changes, Step by Step
  - 1. Parameter Settings
  - 2. Dimension-Change Flowchart
  - 3. Dimension Changes at Key Steps
- IV. Key Implementation Details
  - 1. Splitting and Merging Heads
  - 2. Computing Attention Scores
  - 3. Masking Techniques
- V. Complete Runnable Example
- VI. Summary and FAQ
  - 1. Core Advantages
  - 2. FAQ
I. Overview of Multi-Head Attention
Multi-Head Attention is the core building block of the Transformer. Its central idea is to process several subspaces in parallel so that the model can capture complex dependencies between different positions in a sequence. Key characteristics:
- Parallel computation: the high-dimensional embedding is split into several lower-dimensional subspaces (see the sketch right after this list)
- Multi-view learning: each attention head attends to a different feature pattern
- Efficiency: the whole computation reduces to highly parallelizable matrix multiplications
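The "split into subspaces" idea is easiest to see on concrete shapes. Below is a minimal, illustrative sketch; the sizes 512 and 8 are just the running example used throughout this article, and the variable names are made up for the illustration.

# Minimal sketch: a 512-dim embedding viewed as 8 heads of 64 dims each.
import torch

embed_dim, num_heads = 512, 8
head_dim = embed_dim // num_heads              # 64

token = torch.randn(embed_dim)                 # one token's embedding
heads = token.view(num_heads, head_dim)        # [8, 64]: one 64-dim slice per head
print(heads.shape)                             # torch.Size([8, 64])
# Each head later attends over the sequence using only its own 64-dim slice.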
II. Code Implementations
1. PyTorch Implementation
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        """
        Args:
            embed_dim: embedding dimension (e.g. 512)
            num_heads: number of attention heads (e.g. 8)
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # dimension per head (e.g. 512 // 8 = 64)
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
        # Linear projections
        self.query = nn.Linear(embed_dim, embed_dim)  # Q projection
        self.key = nn.Linear(embed_dim, embed_dim)    # K projection
        self.value = nn.Linear(embed_dim, embed_dim)  # V projection
        self.out = nn.Linear(embed_dim, embed_dim)    # output projection

    def transpose_for_scores(self, x):
        """Split into heads and reorder dimensions.
        Input:  [batch_size, seq_len, embed_dim]
        Output: [batch_size, num_heads, seq_len, head_dim]
        """
        new_shape = x.size()[:-1] + (self.num_heads, self.head_dim)
        x = x.view(*new_shape)        # add the head dimension
        return x.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]

    def forward(self, query, key, value, mask=None):
        """Forward pass.
        Input shape:  [batch_size, seq_len, embed_dim]
        Output shape: [batch_size, seq_len, embed_dim]
        """
        batch_size = query.size(0)
        # 1. Linear projections
        Q = self.query(query)  # [N, seq, D]
        K = self.key(key)      # [N, seq, D]
        V = self.value(value)  # [N, seq, D]
        # 2. Split into heads
        Q = self.transpose_for_scores(Q)  # [N, h, seq, d]
        K = self.transpose_for_scores(K)  # [N, h, seq, d]
        V = self.transpose_for_scores(V)  # [N, h, seq, d]
        # 3. Attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1))  # [N, h, seq_q, seq_k]
        scores /= math.sqrt(self.head_dim)             # scale
        # 4. Optional mask (positions where mask == 0 are blocked)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # 5. Attention weights
        attn_weights = F.softmax(scores, dim=-1)  # [N, h, seq_q, seq_k]
        # 6. Weighted sum of values
        out = torch.matmul(attn_weights, V)  # [N, h, seq_q, d]
        # 7. Merge heads
        out = out.permute(0, 2, 1, 3).contiguous()      # [N, seq_q, h, d]
        out = out.view(batch_size, -1, self.embed_dim)  # [N, seq, D]
        # 8. Output projection
        return self.out(out), attn_weights
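The forward method also accepts an optional mask. The sketch below drives the class defined above with a causal (look-ahead) mask; the tensor sizes are illustrative only, and positions where the mask is 0 are blocked.

# Minimal sketch: calling the PyTorch class above with a causal mask.
import torch

mha = MultiHeadAttention(embed_dim=512, num_heads=8)
x = torch.randn(2, 5, 512)             # [batch, seq_len, embed_dim]
causal = torch.tril(torch.ones(5, 5))  # lower-triangular [seq, seq]
causal = causal.view(1, 1, 5, 5)       # broadcast over [batch, heads]
out, attn = mha(x, x, x, mask=causal)
print(out.shape)    # torch.Size([2, 5, 512])
print(attn.shape)   # torch.Size([2, 8, 5, 5])
# Each query position can only attend to itself and earlier positions:
print(torch.allclose(attn[0, 0].triu(1), torch.zeros(5, 5)))  # True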
2. TensorFlow Implementation
# TensorFlow (TF 2.x)
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense

class MultiHeadAttention(Layer):
    def __init__(self, embed_dim, num_heads):
        """
        Args:
            embed_dim: embedding dimension (e.g. 512)
            num_heads: number of attention heads (e.g. 8)
        """
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
        # Linear projections
        self.query_dense = Dense(embed_dim)
        self.key_dense = Dense(embed_dim)
        self.value_dense = Dense(embed_dim)
        self.output_dense = Dense(embed_dim)

    def split_heads(self, x, batch_size):
        """Split into heads and reorder dimensions.
        Input:  [batch_size, seq_len, embed_dim]
        Output: [batch_size, num_heads, seq_len, head_dim]
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.head_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, query, key, value, mask=None):
        batch_size = tf.shape(query)[0]  # dynamic batch size (note the [0] index)
        # 1. Linear projections
        Q = self.query_dense(query)  # [N, seq, D]
        K = self.key_dense(key)      # [N, seq, D]
        V = self.value_dense(value)  # [N, seq, D]
        # 2. Split into heads
        Q = self.split_heads(Q, batch_size)  # [N, h, seq, d]
        K = self.split_heads(K, batch_size)  # [N, h, seq, d]
        V = self.split_heads(V, batch_size)  # [N, h, seq, d]
        # 3. Attention scores
        matmul_qk = tf.matmul(Q, K, transpose_b=True)  # [N, h, seq_q, seq_k]
        scaled_attention_logits = matmul_qk / tf.math.sqrt(tf.cast(self.head_dim, tf.float32))
        # 4. Optional mask; here mask == 1 marks positions to block
        #    (the opposite convention of the PyTorch version above)
        if mask is not None:
            scaled_attention_logits += (mask * -1e9)  # add a large negative value
        # 5. Attention weights
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        # 6. Weighted sum of values
        output = tf.matmul(attention_weights, V)  # [N, h, seq_q, d]
        # 7. Merge heads
        output = tf.transpose(output, perm=[0, 2, 1, 3])  # [N, seq_q, h, d]
        concat_attention = tf.reshape(output, (batch_size, -1, self.embed_dim))
        # 8. Output projection
        return self.output_dense(concat_attention), attention_weights
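For symmetry, a shape check for the TensorFlow version might look like the following. It assumes the Keras layer defined above and uses random inputs; the variable names are illustrative.

# Minimal sketch: shape check for the TensorFlow layer above.
import tensorflow as tf

mha = MultiHeadAttention(embed_dim=512, num_heads=8)
x = tf.random.normal((2, 5, 512))   # [batch, seq_len, embed_dim]
out, attn = mha(x, x, x)
print(out.shape)    # (2, 5, 512)
print(attn.shape)   # (2, 8, 5, 5)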
III. Dimension Changes, Step by Step
1. Parameter Settings
- batch_size = 2
- seq_len = 5
- embed_dim = 512
- num_heads = 8
- head_dim = 512 // 8 = 64
2. Dimension-Change Flowchart
Original input: [2, 5, 512]
│
├─ Linear projection ──── shape unchanged → [2, 5, 512]
│
├─ Split into heads ─────────────────────→ [2, 8, 5, 64]
│   (512 is split into 8 heads of 64 dims each)
│
├─ Attention scores ─────────────────────→ [2, 8, 5, 5]
│   (each head computes a 5x5 attention matrix)
│
├─ Softmax ──────────────────────────────→ [2, 8, 5, 5]
│   (normalized over the last dimension)
│
├─ Weighted sum of values ───────────────→ [2, 8, 5, 64]
│   (each head produces a new sequence representation)
│
├─ Merge heads ──────────────────────────→ [2, 5, 512]
│   (the 8 heads of 64 dims are concatenated back to 512)
│
└─ Output projection ────────────────────→ [2, 5, 512]
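To verify the flowchart, the intermediate shapes can be traced directly with a few tensor operations. The sketch below mirrors the PyTorch forward pass step by step using random tensors and made-up variable names; it only checks shapes, not learned behavior.

# Minimal sketch: trace the shapes from the flowchart with random tensors.
import math
import torch
import torch.nn.functional as F

batch_size, seq_len, embed_dim, num_heads = 2, 5, 512, 8
head_dim = embed_dim // num_heads

x = torch.randn(batch_size, seq_len, embed_dim)   # [2, 5, 512]
q = k = v = x                                     # linear projections keep the shape
q = q.view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)  # [2, 8, 5, 64]
k = k.view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
v = v.view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)

scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)   # [2, 8, 5, 5]
weights = F.softmax(scores, dim=-1)                      # [2, 8, 5, 5]
ctx = weights @ v                                        # [2, 8, 5, 64]
merged = ctx.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, embed_dim)  # [2, 5, 512]
print(scores.shape, weights.shape, ctx.shape, merged.shape)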
3. Dimension Changes at Key Steps
- Linear projection (Q/K/V): [2, 5, 512] → [2, 5, 512]
- Split into heads: [2, 5, 512] → [2, 8, 5, 64]
- Attention scores (QKᵀ/√d): [2, 8, 5, 64] → [2, 8, 5, 5]
- Softmax: [2, 8, 5, 5] → [2, 8, 5, 5]
- Weighted sum of values: [2, 8, 5, 5] × [2, 8, 5, 64] → [2, 8, 5, 64]
- Merge heads: [2, 8, 5, 64] → [2, 5, 512]
- Output projection: [2, 5, 512] → [2, 5, 512]
IV. Key Implementation Details
1. Splitting and Merging Heads
# Split into heads (core code)
new_shape = x.size()[:-1] + (num_heads, head_dim)
x = x.view(*new_shape).permute(0, 2, 1, 3)
# Merge heads (the inverse operation)
x = x.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, embed_dim)
- Why permute: moving num_heads forward lets one batched matrix multiplication process all heads in parallel (a round-trip check is sketched below)
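As a sanity check, splitting and then merging should reproduce the original tensor exactly, since both steps are pure reshapes and transpositions. A minimal sketch, with illustrative sizes:

# Minimal sketch: split-then-merge is a lossless round trip.
import torch

batch_size, seq_len, embed_dim, num_heads = 2, 5, 512, 8
head_dim = embed_dim // num_heads

x = torch.randn(batch_size, seq_len, embed_dim)
# Split: [N, seq, D] -> [N, h, seq, d]
split = x.view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
# Merge: [N, h, seq, d] -> [N, seq, D]
merged = split.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, embed_dim)
print(torch.equal(x, merged))  # True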
2. Computing Attention Scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
- Transpose: swapping K's seq_len and head_dim axes makes the matrix product [seq_q, d] x [d, seq_k] → [seq_q, seq_k]
- Scaling factor: dividing by √d_k keeps the dot products from growing so large that softmax saturates and its gradients vanish (a numerical illustration follows)
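To see why the scaling matters, compare softmax on raw versus scaled dot products of random vectors: without the 1/√d_k factor the distribution tends to collapse onto a single position. This is only an illustrative experiment with made-up sizes.

# Minimal sketch: effect of the 1/sqrt(d_k) scaling on softmax sharpness.
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k = 64
q = torch.randn(1, d_k)
k = torch.randn(10, d_k)                       # 10 key positions

raw = q @ k.T                                  # variance of entries grows with d_k
scaled = raw / math.sqrt(d_k)

print(F.softmax(raw, dim=-1).max().item())     # typically much closer to 1 (sharper)
print(F.softmax(scaled, dim=-1).max().item())  # flatter distribution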
3. Masking Techniques
scores = scores.masked_fill(mask == 0, -1e9)
- Purpose: drives the attention weights of masked positions (e.g. padding) toward 0
- Why -1e9: after softmax, exp(-1e9) ≈ 0, so those positions contribute essentially nothing (see the sketch below)
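A tiny numerical check of this trick, with made-up scores and a single masked position:

# Minimal sketch: a masked position receives ~0 attention weight after softmax.
import torch
import torch.nn.functional as F

scores = torch.tensor([[2.0, 1.0, 0.5, 3.0]])
mask = torch.tensor([[1, 1, 1, 0]])            # 0 marks the padding position
masked = scores.masked_fill(mask == 0, -1e9)
weights = F.softmax(masked, dim=-1)
print(weights)        # last entry is 0 (exp(-1e9) underflows to 0)
print(weights.sum())  # still sums to 1 over the remaining positions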
V. Complete Runnable Example
# Test case (uses the PyTorch class from Section II)
embed_dim = 512
num_heads = 8
model = MultiHeadAttention(embed_dim, num_heads)
# Generate test data
batch_size = 2
seq_len = 5
inputs = torch.randn(batch_size, seq_len, embed_dim)
# Forward pass
output, attn = model(inputs, inputs, inputs)
# Check output shapes
print(output.shape)  # torch.Size([2, 5, 512])
print(attn.shape)    # torch.Size([2, 8, 5, 5])
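For reference, PyTorch also ships a built-in nn.MultiheadAttention layer. The sketch below only compares output shapes: its internal weights are initialized independently of the hand-written class, and by default its returned attention weights are averaged over heads.

# Minimal sketch: shape comparison with the built-in nn.MultiheadAttention.
import torch
import torch.nn as nn

builtin = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
inputs = torch.randn(2, 5, 512)
out, attn = builtin(inputs, inputs, inputs)
print(out.shape)   # torch.Size([2, 5, 512])
print(attn.shape)  # torch.Size([2, 5, 5]) -- averaged over the 8 heads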
VI. Summary and FAQ
1. Core Advantages
- Parallel efficiency: all positions and all attention heads are processed at once through matrix operations
- Multi-view learning: different heads can attend to different kinds of features, such as syntactic versus semantic patterns
- Long-range dependencies: the association between any two positions is computed directly, regardless of distance
2. FAQ
- Q1: Why use multiple attention heads?
- A: By analogy with multiple convolution kernels in a CNN, different heads can capture different kinds of feature dependencies.
- Q2: Why is head_dim set to embed_dim / num_heads?
- A: It keeps the total dimensionality (and hence the parameter count) unchanged: num_heads * head_dim = embed_dim, so the product of dimensions before and after the split is the same.
- Q3: Why call contiguous() after permute?
- A: permute only changes the strides, so the tensor is no longer laid out contiguously in memory; contiguous() materializes a contiguous copy, without which the subsequent view call raises an error (demonstrated in the sketch below).
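The error mentioned in Q3 is easy to reproduce. The following sketch shows view failing on a non-contiguous tensor and succeeding after contiguous() (or when reshape, which copies if needed, is used instead):

# Minimal sketch: why contiguous() is needed before view after permute.
import torch

x = torch.randn(2, 8, 5, 64)
y = x.permute(0, 2, 1, 3)          # only the strides change, not the memory layout
print(y.is_contiguous())           # False

try:
    y.view(2, 5, 512)              # fails: view needs a contiguous layout
except RuntimeError as e:
    print("view failed:", e)

print(y.contiguous().view(2, 5, 512).shape)  # torch.Size([2, 5, 512])
print(y.reshape(2, 5, 512).shape)            # reshape copies when necessary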