当前位置: 首页 > article >正文

图解多头注意力机制:维度变化一镜到底



一、多头注意力机制概述

多头注意力(Multi-Head Attention)是Transformer模型的核心组件,其核心思想是通过 ‌并行处理多个子空间‌ 来捕捉序列中不同位置间的复杂依赖关系。主要特点:

  • 并行计算:将高维向量拆分为多个低维子空间
  • 多视角学习:每个注意力头关注不同特征模式
  • 高效性:矩阵运算高度可并行化

在这里插入图片描述

二、代码实现

1. pyTorch 实现
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        """
        Args:
            embed_dim: 词向量维度(如512)
            num_heads: 注意力头数量(如8)
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # 每个头的维度(如512//8=64)
        
        assert self.head_dim * num_heads == embed_dim, "维度不可整除"
        
        # 定义线性变换层
        self.query = nn.Linear(embed_dim, embed_dim)  # Q矩阵
        self.key = nn.Linear(embed_dim, embed_dim)    # K矩阵
        self.value = nn.Linear(embed_dim, embed_dim)  # V矩阵
        self.out = nn.Linear(embed_dim, embed_dim)    # 输出层

    def transpose_for_scores(self, x):
        """拆分多头并调整维度顺序
        输入: [batch_size, seq_len, embed_dim]
        输出: [batch_size, num_heads, seq_len, head_dim]
        """
        new_shape = x.size()[:-1] + (self.num_heads, self.head_dim)
        x = x.view(*new_shape)  # 新增头维度
        return x.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]

    def forward(self, query, key, value, mask=None):
        """前向传播流程
        输入形状: [batch_size, seq_len, embed_dim]
        输出形状: [batch_size, seq_len, embed_dim]
        """
        batch_size = query.size(0)
        
        # 1. 线性变换
        Q = self.query(query)  # [N, seq, D]
        K = self.key(key)      # [N, seq, D]
        V = self.value(value)  # [N, seq, D]

        # 2. 拆分多头
        Q = self.transpose_for_scores(Q)  # [N, h, seq, d]
        K = self.transpose_for_scores(K)  # [N, h, seq, d] 
        V = self.transpose_for_scores(V)  # [N, h, seq, d]

        # 3. 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1))  # [N, h, seq_q, seq_k]
        scores /= math.sqrt(self.head_dim)  # 缩放
        
        # 4. 应用掩码(可选)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
            
        # 5. 计算注意力权重
        attn_weights = F.softmax(scores, dim=-1)  # [N, h, seq_q, seq_k]
        
        # 6. 应用权重到Value
        out = torch.matmul(attn_weights, V)  # [N, h, seq_q, d]
        
        # 7. 合并多头
        out = out.permute(0, 2, 1, 3).contiguous()  # [N, seq_q, h, d]
        out = out.view(batch_size, -1, self.embed_dim)  # [N, seq, D]
        
        # 8. 输出层
        return self.out(out), attn_weights
2. tensorFlow实现
# TensorFlow (兼容TF2.x)

import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense

class MultiHeadAttention(Layer):
    def __init__(self, embed_dim, num_heads):
        """
        Args:
            embed_dim: 词向量维度(如512)
            num_heads: 注意力头数量(如8)
        """
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        assert self.head_dim * num_heads == embed_dim, "维度不可整除"
        
        # 定义线性变换层
        self.query_dense = Dense(embed_dim)
        self.key_dense = Dense(embed_dim)
        self.value_dense = Dense(embed_dim)
        self.output_dense = Dense(embed_dim)
        
    def split_heads(self, x, batch_size):
        """拆分多头并调整维度顺序
        输入: [batch_size, seq_len, embed_dim]
        输出: [batch_size, num_heads, seq_len, head_dim]
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.head_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])
    
    def call(self, query, key, value, mask=None):
        batch_size = tf.shape(query)
        
        # 1. 线性变换
        Q = self.query_dense(query)  # [N, seq, D]
        K = self.key_dense(key)      # [N, seq, D]
        V = self.value_dense(value)  # [N, seq, D]
        
        # 2. 拆分多头
        Q = self.split_heads(Q, batch_size)  # [N, h, seq, d]
        K = self.split_heads(K, batch_size)  # [N, h, seq, d]
        V = self.split_heads(V, batch_size)  # [N, h, seq, d]
        
        # 3. 计算注意力分数
        matmul_qk = tf.matmul(Q, K, transpose_b=True)  # [N, h, seq_q, seq_k]
        scaled_attention_logits = matmul_qk / tf.math.sqrt(tf.cast(self.head_dim, tf.float32))
        
        # 4. 应用掩码(可选)
        if mask is not None:
            scaled_attention_logits += (mask * -1e9)  # 添加极大负值
        
        # 5. 计算注意力权重
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        
        # 6. 应用权重到Value
        output = tf.matmul(attention_weights, V)  # [N, h, seq_q, d]
        
        # 7. 合并多头
        output = tf.transpose(output, perm=[0, 2, 1, 3])  # [N, seq_q, h, d]
        concat_attention = tf.reshape(output, (batch_size, -1, self.embed_dim))
        
        # 8. 输出层
        return self.output_dense(concat_attention), attention_weights

三、维度变化全流程详解

1. 参数设定
  • batch_size = 2
  • seq_len = 5
  • embed_dim = 512
  • num_heads = 8
  • head_dim = 512 // 8 = 64
2. 维度变化流程图
原始输入: [2, 5, 512]
    │
    ├─线性变换───────保持形状→ [2, 5, 512]
    │
    ├─拆分多头──────→ [2, 8, 5, 64]
    │                (拆分512为8个64维头)
    │
    ├─计算注意力分数──→ [2, 8, 5, 5]
    │                (每个头计算5x5的注意力矩阵)
    │
    ├─Softmax───────→ [2, 8, 5, 5]
    │                (最后一维归一化)
    │
    ├─应用权重到Value→ [2, 8, 5, 64]
    │                (每个头输出新的序列表示)
    │
    ├─合并多头───────→ [2, 5, 512]
    │                (拼接8个64维头恢复512维)
    │
    └─输出层────────→ [2, 5, 512]
3. 关键步骤维度变化

在这里插入图片描述

四、关键实现细节解析

1. 多头拆分与合并
# 拆分多头(核心代码)
new_shape = x.size()[:-1] + (num_heads, head_dim)
x = x.view(*new_shape).permute(0, 2, 1, 3)

# 合并多头(逆过程)
x = x.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, embed_dim)
  • 为什么要permute:将num_heads维度提前,便于后续矩阵乘法并行处理多个头
2. 注意力分数计算
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
  • 转置维度‌:将K的seq_len和head_dim维度交换,使矩阵乘法满足[seq_q, d] x [d, seq_k] → [seq_q, seq_k]
  • 缩放因子‌:防止点积结果过大导致softmax梯度消失
3. 掩码处理技巧

python

scores = scores.masked_fill(mask == 0, -1e9)
  • 作用‌:将填充位置(如)的注意力权重趋近于0
  • 为什么用-1e9‌:经过softmax后,exp(-1e9) ≈ 0

五、完整运行示例

# 测试用例
embed_dim = 512
num_heads = 8
model = MultiHeadAttention(embed_dim, num_heads)

# 生成测试数据
batch_size = 2
seq_len = 5
inputs = torch.randn(batch_size, seq_len, embed_dim)

# 前向传播
output, attn = model(inputs, inputs, inputs)

# 验证输出形状
print(output.shape)  # torch.Size([2, 5, 512])
print(attn.shape)    # torch.Size([2, 8, 5, 5])

六、总结与常见问题

1. 核心优势
  • 并行计算效率‌:通过矩阵运算同时处理所有位置和注意力头
  • 多视角学习‌:不同注意力头可关注语法、语义等不同特征
  • 长距离依赖‌:直接计算任意两个位置间的关联
2. FAQ
  • Q1:为什么需要多个注意力头?‌

  • A:类比CNN中多个卷积核,不同头可以捕捉不同类型的特征依赖

  • Q2:head_dim为什么要设置为embed_dim/num_heads?‌

  • A:保持总参数量不变,确保拆分前后的维度乘积相等(num_heads * head_dim = embed_dim)

  • Q3:permute之后为什么要调用contiguous()?‌

  • A:确保张量在内存中连续存储,避免后续view操作报错


http://www.kler.cn/a/588358.html

相关文章:

  • 完整的模型验证套路
  • NPU的工作原理:神经网络计算的流水线
  • 计算机二级Python资料
  • RabbitMQ(补档)
  • Hive SQL 精进系列:解锁 Hive SQL 中 KeyValue 函数的强大功能
  • 微信小程序刷题逻辑实现:技术揭秘与实践分享
  • sensor数据在整个rk平台的框架流程是怎么样,
  • 业务幂等性设计的六种方案
  • 蓝桥杯[阶段总结] 二分,前缀和
  • 华为云容器引擎应用场景
  • 游戏成瘾与学习动力激发研究——多巴胺脉冲式释放与奖赏预测误差机制的神经科学解析
  • ccf3501密码
  • 计算机操作系统进程(4)
  • 【网络】什么是反向代理Reverse Proxies?
  • matlab中如何集成使用python
  • Python中在类中创建对象
  • 基于Spring Boot的航司互售系统
  • Java中队列(Queue)和列表(List)的区别
  • 基于ssm+vue汽车租赁系统
  • 量化交易学习笔记02:双均线策略