当前位置: 首页 > article >正文

【NLP251】Transformer中的Attention机制

我们现在来思考这样一个问题?

问题一:Seq2Seq有什么缺点吗?

问题二:模型具有更强的上下文权重信息会怎样?

问题一答:Seq2Seq模型进行的是跨序列的样本相关性计算,这是说,经典注意力机制考虑的是序列A的样本之于序列B的重要程度。

来看一张经典老图,这张图在上一篇文章中多次提到,其中编码器部分传入的是序列A解码器部分输出的是序列B,而黄色圈C所传递的就是序列A的样本之于序列B的特征信息。

那么问题来了,序列A各token之间的信息对于序列A本身的重要程度就不需要考虑了吗?答案当然是NO!这也是传统注意力机制的缺点,而本节我们要学习的Attention自注意力机制,就是要让A序列对自己保持注意力发现自己序列中各token的重要程度,而不仅仅是序列A对于序列B的重要程度,

问题二答:来啦来啦别催,我们来回答第二个问题, 模型具有更强的上下文权重信息会怎样?当然是有更长的上下文权重信息,模型的理解能力理解程度会更优秀,效率也会更高。这就是自注意力机制的第二个优势,对更早输入的信息保持能力更强,而传统注意力机制随着信息一步步传递权重一步步叠加对于早期信息的持有能力会更若。

OK两个问题回答完了我们步入正题,为什么要引入自注意力机制Attention这个概念呢?因为本节要讲的Transformer就是引入了Attention模块,其实再简单清晰一点就是对于上节我们提到的对于Seq2Seq模块的RNN和LSTM部分进行了替换。就是对包浆老图中的标红部分进行了替换。

Attention模块又具有相对独立性。Transformer只是将其发扬光大,之后有很多网络结构都用到了子注意力机制。如:

CNN + Attention == Transformer
Attention + LSTM == Transformer
CNN + Transformer == ViT

1. 自注意力机制

1.1 Transformer中的自注意力机制运算流程

首先,transformer当中计算的相关性被称之为是注意力分数,该注意力分数是在原始的注意力机制上修改后而获得的全新计算方式,其具体计算公式如下:
 

在这个公式中,首先我们需要将原始特征矩阵转换为Q和K。然后,令Q乘以K的转置,从而获得最基础的相关性分数。在计算出权重之后,我们还需将权重乘在样本上,以构成“上下文的复合表示”。因此,在原始特征矩阵的基础上,我们还需转化出矩阵V,用以表示原始特征所携带的信息值。假设现在我们有4个单词,每个单词被编码成了6列的词向量,那么计算Q、K、V的过程如下所示: 

现在我们已经获得了softmax之后的分数矩阵r,同时我们还有代表原始特征矩阵值的V矩阵二者相乘的结果如下:

1.2 注意力机制的分类

1.2.1 单头子注意力机制

是我们在1.1小节中所提到的子注意力机制。

1.2.2 多头自注意力机制 

Multi-Head Attention机制是在标准的self-attention机制基础上进行的一种扩展。对于输入的embedding矩阵,传统的self-attention仅使用一组权重矩阵WQ、WK和WV来转换得到Query、Keys和Values。而Multi-Head Attention则采用多组WQ、WK和WV矩阵,从而生成多组Query、Keys和Values。这一过程允许模型从不同的表示子空间中捕捉信息。对于每组Query、Keys和Values,模型分别计算得到一个Z矩阵,最终将所有Z矩阵拼接起来,形成最终的输出。值得注意的是,在Transformer的原论文中,作者使用了8组不同的WQ、WK和WV矩阵来实现这一机制。

我们以双头自注意力机制为例进行演示。在这种机制中,每个头独立地计算自注意力,然后将两个头的输出(即每个头的注意力结果)拼接起来。拼接后的向量通过一个线性层进行处理,以整合不同头捕获的信息。这种设计可以灵活地调整输出的维度,以适应不同的任务需求。Transformer模型通过这种方式处理序列数据,这些数据通常以三维张量的形式表示,包括序列长度、特征维度和批次大小。

 1.3 自注意力机制全流程

如上图详细解释单头注意力机制(Scaled Dot-Product Attention)和多头注意力机制(Multi-Head Attention)的执行流程。

1.3.1 Scaled Dot-Product Attention

  1. 输入

    • 三个主要输入:查询(Query,Q)、键(Key,K)和值(Value,V)。

  2. 矩阵乘法

    • 首先,计算查询(Q)和键(K)的矩阵乘法,得到注意力分数(Attention Scores)。

    其中 Q 和 K 是矩阵,T 表示转置。

  3. 缩放

    • 将注意力分数除以 dk​​,其中 dk​ 是键(Key)的维度。这一步是为了控制梯度的尺度,防止梯度消失或梯度爆炸。

  4. 可选的掩码

    • 应用可选的掩码(Mask),以忽略某些位置的注意力分数。这在处理序列数据时特别有用,例如在机器翻译中忽略未来位置的信息。

  5. SoftMax

    • 通过SoftMax函数将缩放后的分数转换为概率分布。

  6. 加权求和

    • 使用注意力权重对值(V)进行加权求和,得到最终的注意力输出。

1.3.2 Multi-Head Attention

多头注意力机制是单头注意力机制的扩展,它通过并行执行多个注意力机制来捕获不同的注意力模式。

  1. 线性变换

    • 对查询(Q)、键(K)和值(V)分别进行线性变换,生成多个头(h个头)的查询、键和值。

     

    其中 Wqh​、Wkh​ 和 Wvh​ 是可学习的权重矩阵。

  2. 并行执行单头注意力

    • 每个头(h个头)并行执行单头注意力机制。

    其中 i 表示第 i 个头。

  3. 拼接

    • 将所有头的注意力输出拼接在一起。

  4. 线性变换

    • 对拼接后的输出进行线性变换,得到最终的多头注意力输出。

    其中 Wo​ 是可学习的权重矩阵。

2.Attention与Transformer的关系

 Transformer的总体架构主要由两大部分构成:编码器(Encoder)和解码器(Decoder)。在Transformer中,编码是解读数据的结构,在NLP的流程中,编码器负责解构自然语言、将自然语言转化为计算机能够理解的信息,并让计算机能够学习数据、理解数据;而解码器是将被解读的信息“还原”回原始数据、或者转化为其他类型数据的结构,它可以让算法处理过的数据还原回“自然语言”,也可以将算法处理过的数据直接输出成某种结果。因此在transformer中,编码器负责接收输入数据、负责提取特征,而解码器负责输出最终的标签。当这个标签是自然语言的时候,解码器负责的是“将被处理后的信息还原回自然语言”,当这个标签是特定的类别或标签的时候,解码器负责的就是“整合信息输出统一结果”。 

下图中:

:表示上述结构(多头注意力和前馈网络,以及它们前后的加法和归一化)会被重复N次。在原始的Transformer模型中,编码器和解码器通常各有6层,所以N=6。

 

 

编码器(Encoder)的结构包含两个子层:首先是一个多头的自注意力(Self-Attention)层,其次是前馈(Feed-Forward)神经网络层。输入数据首先经过自注意力层,该层通过为输入数据中的不同信息分配重要性权重,使模型能够识别哪些信息是关键的。接下来,信息会进入前馈神经网络层,这是一个简单的全连接神经网络,功能是整合由多头注意力机制生成的信息。这两个子层均采用了残差连接(Residual Connection),即每个子层的输出都会与其输入相加,再经过层标准化(Layer Normalization)处理,从而得到最终的输出。在神经网络中,这种多头注意力机制加前馈网络的结构可以堆叠多层,在Transformer的经典架构中,编码器结构重复了6次。

解码器(Decoder)同样由多个子层构成:首先是一个多头的自注意力层(由于解码器的特性,此处的多头注意力层带有掩码),接着是一个普通的多头注意力机制层,最后是前馈神经网络层。自注意力层和前馈神经网络层的结构与编码器中的相应层相似。而注意力层则用于关注编码器的输出。同样,每个子层后都接有残差连接和层标准化。在经典的Transformer结构中,解码器也包含6层。


http://www.kler.cn/a/529155.html

相关文章:

  • 分布式事务组件Seata简介与使用,搭配Nacos统一管理服务端和客户端配置
  • 3.Spring-事务
  • Spring JDBC:简化数据库操作的利器
  • 开源的瓷砖式图像板系统Pinry
  • 需求分析应该从哪些方面来着手做?
  • 探索AI(chatgpt、文心一言、kimi等)提示词的奥秘
  • 【Proteus】NE555纯硬件实现LED呼吸灯效果,附源文件,效果展示
  • 设计心得——平衡和冗余
  • C语言:输入正整数链表并选择删除任意结点
  • ComfyUI安装调用DeepSeek——DeepSeek多模态之图形模型安装问题解决(ComfyUI-Janus-Pro)
  • 一文学会HTML编程之视频+图文详解详析
  • Selenium 使用指南:从入门到精通
  • 17.2 图形绘制8
  • ASP.NET Core与配置系统的集成
  • redex快速体验
  • 力扣动态规划-16【算法学习day.110】
  • 《苍穹外卖》项目学习记录-Day5在Java中操作Redis_Spring Data Redis
  • torch numpy seed使用方法
  • Easy系列PLC尺寸测量功能块(激光微距应用)
  • 2007-2019年各省科学技术支出数据
  • A4988一款常用的步进电机驱动芯片
  • 项目架构调整,切换版本并发布到中央仓库
  • Java篇之继承
  • Flink报错Caused by: java.io.FileNotFoundException: /home/wc.txt
  • Ubuntu16.04编译安装Cartographer 1.0版本
  • NoteGen:记录、写作与AI融合的跨端笔记应用