当前位置: 首页 > article >正文

Transformer 模型介绍(四)——编码器 Encoder 和解码器 Decoder

上篇中讲完了自注意力机制 Self-Attention 和多头注意力机制 Multi-Head Attention,这是 Transformer 核心组成部分之一,在此基础上,进一步展开讲一下编码器-解码器结构(Encoder-Decoder Architecture)

Transformer 模型由以下两个主要部分组成:

  • 编码器(Encoder):负责处理输入句子,将其转化为一个上下文丰富的表示
  • 解码器(Decoder):根据编码器生成的上下文向量来生成目标语言句子

目录

1 编码器 Encoder

1.1 整体结构

1.2 多头注意力机制 Multi-Head Attention

1.3 残差连接与归一化 Add & Norm

1.4 前馈神经网络 FFN

1.5 重复堆叠

2 解码器 Decoder

2.1 整体结构

2.2 掩码多头注意力 Masked Multi-Head Attention

2.3 编码器-解码器注意力 Encoder-Decoder Attention

2.4 前馈神经网络 FFN

2.5 堆叠与生成

3 小结


1 编码器 Encoder

1.1 整体结构

在Transformer模型中,编码器(Encoder)部分是至关重要的组件,它负责接收输入序列并对其进行逐层处理,生成高质量的序列表示。与传统的序列到序列模型(如RNN、LSTM)不同,Transformer的编码器不依赖于递归结构,而是通过自注意力机制实现对序列中各个单词的动态关联。具体而言:

  • 编码器由多个相同的层堆叠而成,每个层由两个主要部分组成:多头自注意力(Multi-Head Self-Attention)机制和前馈神经网络(Feed-Forward Neural Network, FFNN)
  • 每个子层后都会接一个归一化操作,旨在稳定训练过程,确保模型在深度训练时不会出现梯度爆炸或消失的情况
  • 多头注意力机制由多头注意力层和归一化处理相连接,接着是一个全连接的前馈网络,共同构成了编码器的核心结构
  • 多头注意力层对输入句子中的特定单词向量计算注意力分数,将所有单词向量的注意力分数编码为一个新的隐藏状态向量,发送到前馈神经网络,进行线性映射
  • 多头注意力层根据输入句子中不同的单词得出不同的注意力分数,因此其权重参数不同,而前馈神经网络对输入句子中不同的单词应用完全相同的权重参数

1.2 多头注意力机制 Multi-Head Attention

上一篇中已详细解释了多头注意力机制,此处不再赘述

1.3 残差连接与归一化 Add & Norm

为了避免深度网络中的梯度消失问题,残差连接被引入到每个子层中。具体来说,每个子层(包括多头自注意力和前馈神经网络)之后都有一个残差连接,将该子层的输入与输出相加,从而保留输入的信息。这种设计有助于缓解深度神经网络训练时梯度消失或爆炸的风险,使得模型能够稳定地训练

在残差连接之后,应用层归一化(Layer Normalization)对数据进行规范化,确保每一层的输出数据分布保持稳定。层归一化的作用是将每一层的输出数据重新调整,使得其均值为0,方差为1,从而加快训练过程,并提高训练的稳定性。归一化操作有助于减少不同训练阶段可能引起的数据分布偏差,提升模型的泛化能力

1.4 前馈神经网络 FFN

每个编码器层的第二个子层是一个全连接的前馈神经网络(FFN)。前馈神经网络的作用是引入非线性转换,以提升模型的表达能力,学习更加复杂的特征

在标准的 FFN 中,通常包含两个线性变换层,中间夹着一个非线性激活函数(如 ReLU)。FFN 的结构通过对输入信息进行两次线性变换和一次非线性映射,能够进一步丰富输入特征的表示

前馈神经网络的具体计算过程如下:

\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2

其中,W1​ W2 是学习得到的权重矩阵,b1​ 和 b2​ 是偏置项。通过这种非线性映射,FFN 能够学习到更复杂的模式和特征,使得模型具有更强的拟合能力

1.5 重复堆叠

上述结构(多头自注意力、前馈神经网络、归一化和残差连接)会在整个Transformer编码器中重复多次(通常为6次),每次迭代都会对输入的序列表示进行更加深入的处理。通过这种逐层堆叠的方式,模型能够不断地提取更加高层次的特征表示,从而构建出对输入序列的深刻理解

每一层的输出会作为下一层的输入,逐层传递,通过多次迭代,模型逐步提升对输入序列的表示能力

2 解码器 Decoder

2.1 整体结构

Transformer 模型的解码器(Decoder)负责生成输出序列,其设计具有独特性,旨在确保输出的顺序性并有效利用编码器产生的上下文信息

解码器与编码器类似,也由多个相同的层堆叠而成,每一层包含三个关键模块

编码器与解码器的最大不同之处是,解码器使用了红色箭头所指的掩码多头注意力机制

2.2 掩码多头注意力 Masked Multi-Head Attention

解码器的第一部分是一个特殊的多头自注意力(Masked Multi-Head Attention)层,旨在引入“未来遮蔽”(Future Masking)机制。这意味着在计算当前单词(或Token)的注意力分数时,模型不会访问未来的词,从而保证了解码过程中的时序性

具体来说,掩码多头注意力机制的工作原理是:在解码时,模型只能使用当前和过去的单词信息来预测下一个单词,而不能提前看到未来的单词。通过将当前单词的查询向量(Query)与未来单词的键向量(Key)遮蔽(即将其置为负无穷大),模型只能基于已生成的单词来计算注意力分数

这种设计确保了生成过程是顺序的,避免了信息泄露,确保了每一步的预测只能基于先前生成的内容。这种“未来遮蔽”机制是Transformer解码器与编码器的一个关键区别

具体来说,使用掩码矩阵(Mask Matrix)用于阻止模型在计算注意力分数时访问未来的位置。掩码矩阵的维度与输入矩阵相同,在该矩阵中,未来位置的值被设置为负无穷(-inf)。通过这种方式,模型在进行 softmax 操作时,未来位置的注意力分数会趋近于零,确保解码器不会利用未来的信息。对于计算得到的注意力分数矩阵 Q \cdot K^T,我们将其与掩码矩阵按位相乘,得到掩码后的注意力分数矩阵 \text{Mask}(Q \cdot K^T)

\text{Mask}(Q \cdot K^T) = (Q \cdot K^T) \cdot M

其中,M 是掩码矩阵,在掩码矩阵中,未来位置的值为0,已生成位置的值为1。通过按位相乘,未来的位置的注意力分数会被置为负无穷

2.3 编码器-解码器注意力 Encoder-Decoder Attention

解码器的第二部分是编码器-解码器注意力(Encoder-Decoder Attention),这是 Transformer 解码器的重要组成部分。在这个层中,查询(Query)矩阵来自解码器的前一层输出,而键(Key)和值(Value)矩阵则直接来自编码器的最终输出矩阵 C

  • 这种设置使得解码器能够根据当前的解码状态,从编码器生成的全局上下文中提取相关信息
  • 解码器能够在生成每个新的单词时,有选择地从编码器的输出中提取上下文信息,以帮助生成更准具有全局一致性的输出序列

2.4 前馈神经网络 FFN

解码器的最后一个模块与编码器相同,是一个全连接的前馈网络

和编码器一样,解码器的每层之间也通过跨层方法,如残差连接(Residual Connections)和层归一化(Layer Normalization)相连,以促进梯度流动并保持输出的稳定性

2.5 堆叠与生成

Transformer解码器与编码器一样,由多个相同的解码器层堆叠而成(通常为6层)。每一层的输出作为下一层的输入,在每一层中,解码器会逐步地生成更丰富的序列表示

通过多层堆叠和反复处理,解码器能够将前面生成的单词信息与编码器提供的上下文信息进行结合,逐步生成一个完整的输出序列。在生成过程中,每一层都能够精细调整输出序列,确保生成的内容在语法和语义上与目标序列保持一致

3 小结

最终的 Transformer 模型由6层网络结构堆叠而成

从整体上理解,Transformer的 架构设计简洁且高效,主要由以下几个模块组成:将多个 self-attention 堆成 Multi-Head Attention,加上 Add & Norm 就构成了 Encoder。经过掩码操作后的Masked Multi-Head Attention 加上 Encoder 同款结构,就构成了 Decoder


http://www.kler.cn/a/548776.html

相关文章:

  • redis cluster测试
  • 基于Istio Ambient Mesh的无边车架构:实现零侵入式服务网格的云原生革命
  • Android remount failed: Permission denied 失败解决方法
  • 【前端框架】Vue 3组件生命周期钩子的使用场景
  • 3.5 企业级AI Agent运维体系构建:从容器化部署到智能监控的工业级实践指南
  • 基于单片机的日程管理系统设计
  • 报错 - 你不能打开应用程序“Docker.app”,因为它没有响应
  • 用Python构建Mad Libs经典文字游戏
  • Android 13 媒体权限适配指南
  • CMake无法生成可执行文件,一直生成库文件
  • Qt QDateTimeEdit总结
  • Android:播放Rtsp视频流的两种方式
  • 在 Go 项目中实现 JWT 用户认证与续期机制
  • 总结前端常用数据结构 之 数组篇【JavaScript -包含常用数组方法】
  • easyCode代码模板配置
  • Mybatisplus自定义sql
  • 双指针-三数之和
  • 机器视觉--switch语句
  • 海尔小红书年度规划方案拆解
  • 使用 Ansys Fluent 进行电池热滥用失控传播仿真