当前位置: 首页 > article >正文

Transformer中Self-Attention以及Multi-Head Attention模块详解(附pytorch实现)

写在前面

最近在项目中需要使用Transformer模型来处理图像任务,所以稍微补充一下这部分的知识,本篇主要了解一下Self-Attention以及Multi-Head Attention模块。

原论文链接:https://arxiv.org/pdf/1706.03762

原文代码:tensor2tensor/tensor2tensor/models/transformer.py

自注意力 Self-Attention

Self-Attention(自注意力机制)是一种动态建模输入数据内部依赖关系的方法,能够让模型关注输入数据中不同部分之间的相关性。

注意力机制的公式表示如下:

Attention(Q,K,V) =softmax(\frac{QK^{T}}{\sqrt{d_{k}}})\times V

这里我们介绍一下这个公式,假设输入有一个序列X=[x_{1},x_{2},...,x_{n}],其中x_{1}为是输入的第 i 个元素,维度为 d ,那么对于 Self-Attention,关键的公式如下:

  • 计算Query、Key和Value输入

通过线性变换分别得到Query(Q)、Key(K)和Value(V):

Q=XW_{Q}

K=XW_{K}

V=XW_{V}

其中W_{Q},W_{K},W_{V}分别是训练的权重矩阵。

  • 计算注意力分数(Attention Scores)

利用 Query Key 计算注意力分数。注意力分数是 Query Key 的点积,然后经过缩放处理(除以\sqrt{d_{k}},其中d_{k}是 Key 向量的维度)

Attention\; Scores=\frac{QK^{T}}{\sqrt{d_{k}}}

这个分数反映了序列中每一对元素之间的相似度。接下来,我们对这些分数进行归一化处理(通过 Softmax 函数):

Attention\; Weights=Softmax(\frac{QK^{T}}{\sqrt{d_{k}}})

这个步骤确保了所有注意力权重的和为 1,使得它们可以作为概率分布。

  • 计算加权和(Weighted Sum)

将注意力权重与 Value 进行加权平均,得到最终的输出:

Attention\; Output=Attention\; Weights\times V

这一步的输出是一个新的表示,它是所有输入位置的加权求和结果,权重由其与 Query 的相关性决定。

以上为公式的原理,现在我举一个实际的例子来帮助大家理解这一部分。假设我们有一个长度为 3 的输入序列,每个元素是一个 2 维的嵌入:

X = \begin{bmatrix} 1 &0 \\ 0&1 \\ 1 & 1 \end{bmatrix}

其中,X 是一个形状为 (3,2) 的矩阵,表示每个位置的 2 维特征。为了计算 Query(Q)、Key(K)和 Value(V),我们首先需要通过一组权重矩阵对输入进行线性变换。假设权重矩阵如下(为简单起见,设定为 2 x 2 矩阵):

W_{Q}=\begin{bmatrix} 1 & 1\\ 1& 0 \end{bmatrix} \qquad W_{K}=\begin{bmatrix} 0 &1 \\ 1& 1 \end{bmatrix} \qquad W_{V}=\begin{bmatrix} 1 & 0\\ 0& 1 \end{bmatrix}

这些矩阵分别用于计算 QueryKeyValue

根据上面的公式可得到:

Q=\begin{bmatrix} 1 &0 \\ 0&1 \\ 1 & 1 \end{bmatrix} \times \begin{bmatrix} 1 & 1\\ 1& 0 \end{bmatrix}=\begin{bmatrix} 1 & 1\\ 1 &0 \\ 2 & 1 \end{bmatrix}

K=\begin{bmatrix} 1 &0 \\ 0&1 \\ 1 & 1 \end{bmatrix} \times \begin{bmatrix} 0 & 1\\ 1& 1 \end{bmatrix}=\begin{bmatrix} 0 & 1\\ 1 &1 \\ 1 & 2 \end{bmatrix}

V=\begin{bmatrix} 1 &0 \\ 0&1 \\ 1 & 1 \end{bmatrix} \times \begin{bmatrix} 1 & 0\\ 0& 1 \end{bmatrix}=\begin{bmatrix} 1 & 0\\ 0 &1 \\ 1 & 1 \end{bmatrix}

得到 QueryKey 后,我们可以计算注意力分数。具体来说,对于每个 Query,我们与所有 Key 进行点积运算。计算如下:

Attention\; Scores=QK^{T}=\begin{bmatrix} 1 & 1\\ 1 &0 \\ 2 & 1 \end{bmatrix} \times \begin{bmatrix} 0 & 1 & 1\\ 1 & 1& 2 \end{bmatrix} = \begin{bmatrix} 1 & 2 &3 \\ 0&1 &1 \\ 1 & 3& 4 \end{bmatrix} 

注意,我们在这里计算的是 Query 和 Key 的点积,然后将其结果进行缩放。缩放因子通常是\sqrt{d_{k}},其中d_{k}是 Key 向量的维度。在这个例子中,Key 的维度是 2,所以缩放因子是 \sqrt{2}=1.414

Scaled\; Attention\; Scores = \frac{1}{\sqrt{2}}\times \begin{bmatrix} 1 & 2 &3 \\ 0&1 &1 \\ 1 & 3& 4 \end{bmatrix}

 即是,

Scaled\; Attention\; Scores \approx \begin{bmatrix} 0.707 & 1.414 & 2.121\\ 0 & 0.707 &0.707 \\ 0.707 &2.121 & 2.828 \end{bmatrix}

然后,我们对每行的 Scaled Attention Scores 应用 softmax 操作,得到注意力权重:

Attention\; Weights=\begin{bmatrix} 0.1401 & 0.2840 &0.5759 \\ 0.1978 & 0.4011&0.4011 \\ 0.0743& 0.3057 & 0.6200 \end{bmatrix}

最后,使用注意力权重对 Value(V)进行加权求和:

Attention\; Output=\begin{bmatrix} 0.1401 & 0.2840 &0.5759 \\ 0.1978 & 0.4011&0.4011 \\ 0.0743& 0.3057 & 0.6200 \end{bmatrix}\times \begin{bmatrix} 1 & 0\\ 0 &1 \\ 1 & 1 \end{bmatrix} \newline = \begin{bmatrix} 0.7160 &0.8599 \\ 0.5989 &0.8022 \\ 0.6943 & 0.9257 \end{bmatrix}

 最后两步的计算结果用我下面给的代码跑一下就知道了。 

import torch
import torch.nn as nn

class Softmax(nn.Module):
    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def _softmax(self, x):
        exp_x = torch.exp(x)
        softmax = exp_x / torch.sum(exp_x, dim=self.dim, keepdim=True)
        return softmax

    def forward(self, x):
        return self._softmax(x)

matrix = torch.tensor([[0.707, 1.414, 2.121], [0, 0.707, 0.707], [0.707, 2.121, 2.828]])
Value = torch.tensor([[1, 0],[0, 1], [1, 1]], dtype=torch.float32)
print(matrix, "\n", Value)
softmax = Softmax()
score = softmax(matrix)
print(score)
result = torch.matmul(score, Value)
print(result)

多头注意力 Multi-Head Attention

在多头注意力机制中,模型会并行地计算多个不同的注意力头,每个头都有自己独立的 Query、Key 和 Value 权重,然后将每个头的输出连接起来,并通过一个线性变换得到最终的结果。

  • 分割 Query、Key 和 Value 成多个头

对于每个头i,分别计算独立的 QueryKey Value

Q_{i}=XW_{Q_{i}} \qquad K_{i}=XW_{K_{i}}\qquad V_{i}=XW_{V_{i}}

  • 计算每个头的注意力输出

Attention\; Output_{i} =Softmax(\frac{Q_{i}K_{i}^{T}}{\sqrt{d_{k}}})\times V_{i}

  • 拼接各头的输出

将所有头的输出拼接在一起,h 是头的数量

Multi-Head Output=Concat(Attention\;Output_{1}, Attention\;Output_{2},...,Attention\;Output_{h})

  • 最终输出

对拼接后的结果进行一次线性变换,得到最终的多头注意力输出。

这里我们不再举例了,下面你可以根据下面的代码进行测试。

注意力pytorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        """
        初始化 MultiHeadAttention 模块
        :param embed_dim: 输入嵌入的特征维度
        :param num_heads: 注意力头的数量
        """
        super(MultiHeadAttention, self).__init__()
        assert embed_dim % num_heads == 0, "embed_dim 必须能被 num_heads 整除"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # 每个头的特征维度

        # 定义 Query, Key 和 Value 的线性变换
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)

        # 输出的线性变换
        self.out_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """
        :param x: 输入张量,形状为 (batch_size, seq_len, embed_dim)
        :return: 注意力后的输出,形状为 (batch_size, seq_len, embed_dim)
        """
        batch_size, seq_len, embed_dim = x.size()
        # 生成 Query, Key, Value (batch_size, seq_len, embed_dim)
        Q = self.q_linear(x)
        K = self.k_linear(x)
        V = self.v_linear(x)

        # 分成多头 (batch_size, num_heads, seq_len, head_dim)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # 计算注意力分数 (batch_size, num_heads, seq_len, seq_len)
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(
            torch.tensor(self.head_dim, dtype=torch.float32))
        attention_weights = F.softmax(attention_scores, dim=-1)
        # 加权求和 (batch_size, num_heads, seq_len, head_dim)
        attention_output = torch.matmul(attention_weights, V)
        # 拼接多头输出 (batch_size, seq_len, embed_dim
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        # 输出线性变换 (batch_size, seq_len, embed_dim)
        output = self.out_linear(attention_output)
        return output

class SelfAttention(MultiHeadAttention):
    def __init__(self, embed_dim):
        """
        初始化 SelfAttention 模块
        :param embed_dim: 输入嵌入的特征维度
        """
        super(SelfAttention, self).__init__(embed_dim, num_heads=1)

    def forward(self, x):
        """
        :param x: 输入张量,形状为 (batch_size, seq_len, embed_dim)
        :return: 注意力后的输出,形状为 (batch_size, seq_len, embed_dim)
        """
        return super(SelfAttention, self).forward(x)

if __name__ == "__main__":
    embed_dim = 64  # 输入特征维度
    num_heads = 8  # 注意力头的数量
    model = SelfAttention(embed_dim)
    multi_model = MultiHeadAttention(embed_dim, num_heads)
    batch_size = 2
    seq_len = 10
    x = torch.rand(batch_size, seq_len, embed_dim)
    output = model(x)
    output2 = multi_model(x)
    print("输出形状:", output.shape)  # 应为 (batch_size, seq_len, embed_dim)
    print("输出形状:", output2.shape)

以上实现,与torch官方内部内部的实现略有不同,官方提供了一个实现好的多头注意力模块 torch.nn.MultiheadAttention。这个实现做了很多优化,比如对于输入和输出的形状、注意力分数的计算以及参数的处理,都进行了更加简化和高效的实现。官方实现默认要求输入形状为 (seq_len, batch_size, embed_dim)(但是实际可以通过参数batch_first来修改),这是因为官方实现是为实现批量并行化优化的。官方实现直接接受 Query、Key 和 Value 作为三个输入张量。

官方写的很好,但可能不够直观,我这里写的就比较的简洁了,并且由于内部实现有些许不同,我这个无法与其进行比较是否相同。下面是我做测试时候写的草稿,大家将就着看吧。

import torch
import torch.nn as nn
from model import MultiHeadAttention

class OfficialMultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(OfficialMultiheadAttention, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x):
        output, attention_weights = self.multihead_attn(x, x, x)
        # 检查每行的和是否为1
        attention_weights_sum = torch.sum(attention_weights, dim=-1)
        print("官方实现 - 每行的和:", attention_weights_sum)
        print("官方实现 - 每行的和是否为1:", torch.allclose(attention_weights_sum, torch.ones_like(attention_weights_sum)))

        return output


if __name__=="__main__":
    embed_dim = 8  # 输入特征维度
    num_heads = 2  # 注意力头的数量
    batch_size = 2
    seq_len = 5
    x = torch.rand(batch_size, seq_len, embed_dim)
    my_attention_model = MultiHeadAttention(embed_dim, num_heads)
    official_attention_model = OfficialMultiheadAttention(embed_dim, num_heads)
    my_attention_output = my_attention_model(x)
    official_attention_output = official_attention_model(x)
    print("SelfAttention 输出形状:", my_attention_output.shape)  # 应为 (batch_size, seq_len, embed_dim)
    print("Official MultiheadAttention 输出形状:", official_attention_output.shape)  # 应为 (batch_size, seq_len, embed_dim)
    is_same = torch.allclose(my_attention_output, official_attention_output, atol=1e-2)
    print("两个输出是否相同:", is_same)

总结

尽管官方的 MultiheadAttention 模块经过优化,具有更高的效率,但手动实现能够帮助大家更好地理解多头注意力机制的各个计算步骤。通过这些实验,我们不仅深入了解了注意力机制的原理,还能在实际应用中灵活使用这些机制,尤其是在图像任务中,Transformer 的强大能力得到了广泛的应用。

参考文章

详解Transformer中Self-Attention以及Multi-Head Attention_transformer multi head-CSDN博客

第四篇:一文搞懂Transformer架构的三种注意力机制_c3tr 注意力 详解-CSDN博客 

一文搞定自注意力机制(Self-Attention)-CSDN博客 

十分推荐的参考视频:Transformer中Self-Attention以及Multi-Head Attention详解_哔哩哔哩_bilibili


http://www.kler.cn/a/466851.html

相关文章:

  • Android12 App窗口创建流程
  • Vue2中使用Echarts
  • umd格式
  • 基于YOLO11的道路缺陷检测系统
  • 基于springboot的课程作业管理系统(源码+数据库+文档)
  • 【小程序开发】- 小程序版本迭代指南(版本发布教程)
  • web漏洞之文件包含漏洞
  • [网络安全]DVWA之SQL注入—low level解题详析
  • Spring Boot自动装配代码详解
  • python +tkinter绘制彩虹和云朵
  • 2025年股指期货每月什么时候交割?
  • 探索光耦:光耦在风力发电中的应用——保障绿色能源的高效与安全
  • ubuntu16 重启之后lvm信息丢失故障恢复
  • Eureka 介绍与原理详解
  • 记录:导出功能:接收文件流数据进行导出(vue3)
  • Jdk动态代理源码缓存优化比较(JDK17比JDK8)
  • 推荐一些关于C#中LINQ的学习资料
  • Qt窗口获取Tftpd32_svc服务下载信息
  • [redux] ts声明useSelector和useDispatch
  • 嵌入式 Linux LED 驱动开发实验
  • 运维工具汇总
  • 【数据分析实战】24年T4某二手车交易平台数据分析
  • 【机器学习】【朴素贝叶斯分类器】从理论到实践:朴素贝叶斯分类器在垃圾短信过滤中的应用
  • 【数据库】简答题汇总
  • 力扣28找出字符串中第一个匹配项的下标
  • PyTorch中的__init__.pyi文件:作用与C++实现关系解析