当前位置: 首页 > article >正文

基于 Transformer 的语言模型

基于 Transformer 的语言模型

Transformer 是一类基于注意力机制(Attention)的模块化构建的神经网络结构。给定一个序列,Transformer 将一定数量的历史状态和当前状态同时输入,然后进行加权相加。对历史状态和当前状态进行“通盘考虑”,然后对未来状态进行预测。

基于 Transformer 的语言模型,以词序列作为输入,基于一定长度的上文和当前词来预测下一个词出现的概率。

Transformer

Transformer 模型是由这两种核心模块构建的模块化网络结构,它们共同构成了模型的主体。下面我将详细介绍这两种模块:

1. 注意力(Attention)模块

注意力模块是 Transformer 中的核心,它允许模型在序列的不同位置之间动态地分配不同的注意力权重,从而捕捉序列内部的依赖关系。

  • 自注意力层(Self-Attention Layer)

  • 残差连接(Residual Connections)

  • 层正则化(Layer Normalization)

2. 全连接前馈(Fully-connected Feedforward)模块

全连接前馈模块对自注意力层的输出进行进一步的处理。

  • 全连接前馈层

  • 残差连接

  • 层正则化

Transformer 模型通过堆叠多个这样的注意力模块和全连接前馈模块来构建深层网络,每个模块都可以并行处理序列中的所有位置,这使得 Transformer 模型在处理序列数据时非常高效。此外,Transformer 模型不依赖于循环或卷积结构,这使得它在处理长距离依赖问题时比传统的RNN更加有效。

注意力模块与全连接前馈模块在这里插入图片描述
注意力机制示意图
在这里插入图片描述

自注意力层的工作原理

  1. 输入编码
    • 输入序列被编码为 Query(Q)、Key(K)和 Value(V)三部分。这些编码通常是通过与权重矩阵相乘得到的。
  2. 计算注意力权重
    • 通过计算 Query(Q)和 Key(K)之间的相似度来确定注意力权重 α \alpha α。这个相似度通常使用点积(dot product)来计算,即 sim ( q t , k i ) = q t ⋅ k i \text{sim}(q_t, k_i) = q_t \cdot k_i sim(qt,ki)=qtki
  3. 应用 softmax 函数
    • 计算得到的相似度分数通过 softmax 函数转换为概率分布,确保所有权重的和为1。这允许模型在不同位置之间动态地分配注意力。
  4. 加权求和
    • 使用得到的注意力权重 α \alpha α 对 Value(V)进行加权求和,得到最终的输出。
自注意力层的计算公式

自注意力层的计算可以表示为:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
其中:

  • Q Q Q 是 Query 矩阵。
  • K K K 是 Key 矩阵。
  • V V V 是 Value 矩阵。
  • d k d_k dk 是 Key 向量的维度。
  • d k \sqrt{d_k} dk 是为了稳定梯度,避免点积结果过大。
具体步骤
  1. 计算相似度
    sim ( q t , k i ) = q t ⋅ k i \text{sim}(q_t, k_i) = q_t \cdot k_i sim(qt,ki)=qtki
  2. 应用 softmax
    α t , i = exp ⁡ ( sim ( q t , k i ) ) ∑ i = 1 t exp ⁡ ( sim ( q t , k i ) ) \alpha_{t,i} = \frac{\exp(\text{sim}(q_t, k_i))}{\sum_{i=1}^{t} \exp(\text{sim}(q_t, k_i))} αt,i=i=1texp(sim(qt,ki))exp(sim(qt,ki))
  3. 加权求和
    Attention ( x t ) = ∑ i = 1 t α t , i v i \text{Attention}(x_t) = \sum_{i=1}^{t} \alpha_{t,i} v_i Attention(xt)=i=1tαt,ivi

全连接前馈层的工作原理

全连接前馈层通常包含两个线性变换,中间夹着一个非线性激活函数。这种结构使得模型能够学习输入数据的非线性表示。

  1. 第一层线性变换
    • 输入向量首先通过一个线性变换,通常表示为 W 1 v + b 1 W_1v + b_1 W1v+b1,其中 W 1 W_1 W1是权重矩阵, b 1 b_1 b1是偏置项。
  2. 非线性激活函数
    • 第一层线性变换的输出通过一个非线性激活函数,常用的激活函数是ReLU(Rectified Linear Unit),即 max ⁡ ( 0 , x ) \max(0, x) max(0,x)。这使得模型能够引入非线性,增加模型的表达能力。
  3. 第二层线性变换
    • 经过激活函数处理的输出再通过另一个线性变换,通常表示为 W 2 W_2 W2 乘以激活函数的输出加上偏置项 b 2 b_2 b2
全连接前馈层的计算公式

全连接前馈层的计算可以表示为:
FFN ( v ) = W 2 max ⁡ ( 0 , W 1 v + b 1 ) + b 2 \text{FFN}(v) = W_2 \max(0, W_1v + b_1) + b_2 FFN(v)=W2max(0,W1v+b1)+b2
其中:
- v v v 是输入向量。
- W 1 W_1 W1 W 2 W_2 W2 是权重矩阵。
- b 1 b_1 b1 b 2 b_2 b2 是偏置项。
- max ⁡ ( 0 , x ) \max(0, x) max(0,x) 是ReLU激活函数,它将所有负值置为0。

具体步骤
  1. 第一层线性变换
    z 1 = W 1 v + b 1 z_1 = W_1v + b_1 z1=W1v+b1
  2. 应用ReLU激活函数
    a 1 = max ⁡ ( 0 , z 1 ) a_1 = \max(0, z_1) a1=max(0,z1)
  3. 第二层线性变换
    FFN ( v ) = W 2 a 1 + b 2 \text{FFN}(v) = W_2a_1 + b_2 FFN(v)=W2a1+b2
    这种全连接前馈层的设计使得 Transformer 模型能够对序列中的每个位置进

层正则化的工作原理

层正则化通过对每个子层的输出进行归一化处理,使得每个样本的隐藏状态具有相同的均值和标准差,从而减少不同层之间的差异,使得训练更加稳定。

  1. 计算均值和标准差
    • 对于每个子层的输出,计算其均值 μ \mu μ 和标准差 δ \delta δ
  2. 归一化处理
    • 使用均值和标准差对输出进行归一化处理,使得归一化后的输出具有零均值和单位标准差。
  3. 缩放和平移
    • 通过两个可学习的参数 α \alpha α β \beta β 对归一化后的输出进行缩放和平移,以恢复模型的表达能力。
层正则化的计算公式

层正则化的计算可以表示为:
LN ( v i ) = α ( v i − μ δ ) + β \text{LN}(v_i) = \alpha \left( \frac{v_i - \mu}{\delta} \right) + \beta LN(vi)=α(δviμ)+β
其中:
- v i v_i vi 是子层的输出。
- μ \mu μ v i v_i vi 的均值。
- δ \delta δ v i v_i vi 的标准差。
- α \alpha α β \beta β 是可学习的参数,用于缩放和平移归一化后的输出。

具体步骤
  1. 计算均值和标准差
    μ = 1 N ∑ i = 1 N v i \mu = \frac{1}{N} \sum_{i=1}^{N} v_i μ=N1i=1Nvi
    δ = 1 N ∑ i = 1 N ( v i − μ ) 2 + ϵ \delta = \sqrt{\frac{1}{N} \sum_{i=1}^{N} (v_i - \mu)^2 + \epsilon} δ=N1i=1N(viμ)2+ϵ
    其中 N N N 是样本数量,( \epsilon$ 是一个很小的常数,用于防止除以零。
  2. 归一化处理
    v i ^ = v i − μ δ \hat{v_i} = \frac{v_i - \mu}{\delta} vi^=δviμ
  3. 缩放和平移
    LN ( v i ) = α v i ^ + β \text{LN}(v_i) = \alpha \hat{v_i} + \beta LN(vi)=αvi^+β

残差连接的工作原理

残差连接(Residual Connection)是深度学习中用于解决梯度消失问题的一种技术,它通过将每个子层的输入直接添加到该子层的输出上来实现。

这种结构允许梯度在网络中更直接地流动,从而减轻了梯度消失的问题,并且有助于训练更深的网络。

在 Transformer 模型中,残差连接被广泛用于自注意力层和全连接前馈层。

  1. 子层输入和输出

    • 每个子层(例如自注意力层或全连接前馈层)接收输入 x x x 并产生输出 y y y
  2. 残差添加

    • 子层的输出 y y y 与输入 x x x 相加,形成残差连接的中间结果: x + y x + y x+y
  3. 层正则化

    • 将残差连接的结果 x + y x + y x+y 通过层正则化(Layer Normalization),以进一步稳定训练过程。

残差连接的计算公式

残差连接的计算可以表示为:
Output = LN ( x + Sublayer ( x ) ) \text{Output} = \text{LN}(x + \text{Sublayer}(x)) Output=LN(x+Sublayer(x))
其中:
- x x x 是子层的输入。
- Sublayer ( x ) \text{Sublayer}(x) Sublayer(x) 是子层的输出,例如自注意力层或全连接前馈层的输出。
- LN \text{LN} LN 表示层正则化操作。

具体步骤
  1. 子层计算
    • 计算子层的输出: y = Sublayer ( x ) y = \text{Sublayer}(x) y=Sublayer(x)
  2. 残差添加
    • 将子层的输入和输出相加: x + y x + y x+y
  3. 层正则化
    • 对残差连接的结果进行层正则化: Output = LN ( x + y ) \text{Output} = \text{LN}(x + y) Output=LN(x+y)
优点
  • 梯度流动:残差连接允许梯度更直接地从网络的末端流向开始,有助于解决梯度消失问题。
  • 训练深度网络:它使得训练更深的网络成为可能,因为梯度可以更有效地传递。
  • 网络稳定性:层正则化进一步稳定了训练过程,提高了模型的泛化能力。

Transfomer 结构示意图

在这里插入图片描述
原始的 Transformer 采用 Encoder-Decoder 架构,其包含 Encoder 和 Decoder 两部分。这两部分都是由自注意力模块和全连接前馈模块重复连接构建而成。

其中,Encoder 部分由六个级联的 encoder layer 组成,每个encoder layer 包含一个注意力模块和一个全连接前馈模块。其中的注意力模块为自注意力模块(query,key,value 的输入是相同的)。

Decoder 部分由六个级联的decoder layer 组成,每个 decoder layer 包含两个注意力模块和一个全连接前馈模块。其中,第一个注意力模块为自注意力模块,第二个注意力模块为交叉注意力模块(query,key,value 的输入不同)。

Decoder 中第一个 decoder layer 的自注意力模块的输入为模型的输出。**其后的 decoder layer 的自注意力模块的输入为上一个 decoderlayer 的输出。**Decoder 交叉注意力模块的输入分别是自注意力模块的输出(query)和最后一个 encoder layer 的输出(key,value)。

基于 Transformer 的语言模型

预训练任务和模型类型

  1. Encoder-Only 模型
    • 如 BERT(Bidirectional Encoder Representations from Transformers),它使用 Transformer 的 Encoder 部分,通过掩词补全(Masked Language Model, MLM)等任务进行预训练。
  2. Encoder-Decoder 模型
    • 如 T5(Text-to-Text Transfer Transformer),它结合了 Transformer 的 Encoder 和 Decoder 部分,通过截断补全、顺序恢复等多个有监督和自监督任务进行预训练。
  3. Decoder-Only 模型
    • 如 GPT-3(Generative Pre-trained Transformer 3),它使用 Transformer 的 Decoder 部分,通过下一词预测任务进行预训练。

训练流程和损失函数

  1. 下一词预测

    • 基于 Transformer 的语言模型根据当前和历史的词序列 { w 1 , w 2 , . . . , w i } \{w_1, w_2, ..., w_i\} {w1,w2,...,wi} 来预测下一个词 w i + 1 w_{i+1} wi+1 的概率。
  2. 输出表示

    • 模型的输出是一个概率分布向量,每一维代表词典中一个词的概率。
  3. 序列概率

    • 整个词序列 { w 1 , w 2 , . . . , w N } \{w_1, w_2, ..., w_N\} {w1,w2,...,wN} 出现的概率是序列中每个词条件概率的乘积:
      P ( w 1 : N ) = ∏ i = 1 N o i [ w i + 1 ] P(w_1:N) = \prod_{i=1}^{N} o_i[w_{i+1}] P(w1:N)=i=1Noi[wi+1]
  4. 交叉熵损失函数

    • 用于衡量模型预测的概率分布与真实词的概率分布之间的差异:
      l C E ( o i ) = − ∑ d = 1 ∣ D ∣ I ( w ^ d = w i + 1 ) log ⁡ o i [ w ^ d ] l_{CE}(o_i) = -\sum_{d=1}^{|D|} I(\hat{w}_d = w_{i+1}) \log o_i[\hat{w}_d] lCE(oi)=d=1DI(w^d=wi+1)logoi[w^d]
    • 其中 I ( ⋅ ) I(\cdot) I() 是指示函数,当 w ^ d = w i + 1 \hat{w}_d = w_{i+1} w^d=wi+1 时为1,否则为0。
  5. 总损失

    • 训练集 S S S 的总损失是所有样本损失的平均值:
      L ( S , W ) = 1 ∣ S ∣ ∑ s = 1 ∣ S ∣ ∑ i = 1 N l C E ( o i , s ) L(S, W) = \frac{1}{|S|} \sum_{s=1}^{|S|} \sum_{i=1}^{N} l_{CE}(o_{i,s}) L(S,W)=S1s=1Si=1NlCE(oi,s)

文本生成和自回归

  1. 自回归文本生成
    • 在自回归过程中,模型通过迭代预测下一个词来生成文本。
  2. Teacher Forcing
    • 在训练过程中,使用真实的下一个词作为输入,而不是模型预测的词,以提高训练效率和效果。

并行计算和长序列挑战

  1. 并行计算
    • Transformer 的并行输入特性使其能够高效地进行并行计算,这与 RNN 的串行计算形成对比。
  2. 长序列挑战
    • Transformer 的模型规模随着输入序列长度的增长而平方次增长,这为处理长序列带来了挑战。

Transformer 模型的这些特性使其在处理长序列时比 RNN 更高效,但也带来了计算资源的需求增加。为了解决长序列的问题,研究者们提出了一些策略,如使用更长的上下文窗口、改进的注意力机制(如局部注意力或稀疏注意力),以及更高效的训练技术。这些方法有助于扩展 Transformer 模型的应用范围,并提高其在处理长序列数据时的性能和效率。


http://www.kler.cn/a/379990.html

相关文章:

  • 【Docker】docker compose 安装 Redis Stack
  • VS2015 + OpenCV + OnnxRuntime-Cpp + YOLOv8 部署
  • GO随记:不使用主键id 如何分表与mysql大表
  • G1原理—2.G1是如何提升分配对象效率
  • Fastapi + vue3 自动化测试平台(1)--开篇
  • Qt QDockWidget详解以及例程
  • 【BUG分析】clickhouse表final成功,但存在数据未合并
  • 十四届蓝桥杯STEMA考试Python真题试卷第二套第一题
  • 贝尔不等式的验证
  • Es 基础操作 增删改查
  • 一些常用的react hooks以及各自的作用
  • 【漏洞复现】泛微OA E-Office group_xml.php SQL注入漏洞
  • Vue项目与IE浏览器的兼容性分析(Vue|ElementUI)
  • Web大学生网页作业成品——和平精英网页设计与实现(HTML+CSS+JS)(4个页面)
  • MATLAB——矩阵操作
  • CSS基础学习篇——选择器
  • ThreeJS创建一个3D物体的基本流程
  • Github 2024-11-01 开源项目月报 Top19
  • 信息学科平台开发:Spring Boot核心技巧与实践
  • 银行金融知识竞赛活动策划方案
  • 回归预测 | MATLAB实现基于RF-Adaboost随机森林结合AdaBoost多输入单输出回归预测
  • 上云管理之Git/GitHub/GitLab 详解(一)
  • 中汽测评观察 亲子出行健康为先,汽车健康用材成重要考量
  • PHP常量
  • Unity 生命周期的事件顺序
  • 32.Redis高级数据结构HyperLogLog