当前位置: 首页 > article >正文

xLSTM模型学习笔记

笔记来源:bilibili

LSTM 回顾

原始的 LSTM 是为了解决 RNN 时序反向传播中梯度消失和爆炸问题而提出的。

其所谓的门控机制,其实就是一种时序上的注意力机制,相当于把不同时间进行"掺和",是对时序信息的一种选择性控制。从这个视角看,它与 Transformer 和 mamba 都异曲同工之妙。核心思想都是选择性控制信息流动,更好地处理时序数据或序列信息。门控机制通过固定的结构和参数来控制信息流,而注意力机制通过动态计算权重来控制信息流。因此,门控机制可以看作是一种特定形式的时序注意力机制,对不同时间步的信息进行选择性控制和"掺和"。可以认为是一种约束版或者简化版的注意力机制。

其缺点

LSTM 有三个主要局限性:

  1. 在处理长序列时效率低
  2. 记忆容易有限
  3. 不可以并行处理数据

而 transformer 借助网格模块堆叠,参数规模扩充和 GPU 并行处理拼算力,在一定的程序上解决了以上的问题,所以 transformer 实现了赶超。

初级版:sLSTM 改进注意力机制

改进的点:

  1. 输入门和遗忘门的激活函数从 sigmoid 改成了指数函数(红色部分)
  2. 引入了一个归一化状态 n t n_t nt,相应的隐层 h t h_t ht 的计算方式亦了,改成了 c t / n t c_t/n_t ct/nt(公式 10)
  3. 引入了一个额外状态 m t m_t mt 来进一步稳定门控

改进的原因如下:

  1. 指数函数相比于 sigmoid 函数,具有更大的输出范围和更大的梯度,可以减轻梯度消失的问题,使得梯度在反向传播过程中不会迅速减小,从而使得模型在训练时能够更有效地更新权重。
  2. 指数函数的增长速度比 sigmoid 函数快,对输入变化更加敏感,因此,可以更迅速地强烈的调整输入和遗忘门的输出,使得模型能够更快地捕捉到输入信息的变化,更加选择性地记住或忘记信息,从而提高模型的记忆和遗忘能力。
  3. 强烈的选择性可以让模型能够更准确地保留重要信息和丢弃不重要的。在特定任务(如长序列的最近邻搜索或稀有事件预测)中表现得尤为显著,能够显著提升模型的性能。

引入归一化和状态 m t m_t mt 都是为了稳定,因为指数激活函数可能导致数值过大而溢出,前者相当于搞了一个大分母。后者通过下面的公式进行:

第一个式子中,使用了 log 的作用就是防止输入门和遗忘门都不要太猛。然后根据 m t m_t mt 来调整输入了门与遗忘门,这样就实现了对输入门与遗忘门的调节。

在原论文中,还证明了在瘦身传播中使用 f t ′ f'_t ft i t ′ i'_t it 替换 f t f_t ft i t i_t it 不会改变整个网络的输出,也不会改变参数损失的导数。

增加了这些公式相当于增加了新的记忆单元,它们之间通过连接从长短期记忆状态,借助门控(阀门)i, f, 0 进行记忆混合。门控就是选择,也是一种时序注意力机制的体现。

中级版:mLSTM 改进内存处理

解决了敏感度,某种程序上也是长序列处理效率问题,为了增强 LSTM 的存储能力,文章将 LSTM 的记忆单元从一个标量 c 增加到矩阵 C。而且在这里引入了 transformer 键值对的概念,更新规则如下:
C t = C t − 1 + v t k t T C_t=C_{t-1}+v_tk_t^T Ct=Ct1+vtktT
在将输入投影到键和值之前,mLSTM 进行层归一化,使得均值为零。同时,将协方差更新规则,也就是优化器整合到 LSTM 架构中,遗忘门对应于衰减率,输入门对应于学习率,而输出门则缩放检索到的向量,最终形成了下面的迭代公式:

与之前的 sLSTM 对比,最大的区别之一就是状态和权重参数都变成了矩阵的形式,对应的运算变成了向量矩阵简洁和哈达玛积。区别之二是增加了 q t q_t qt k t k_t kt, v t v_t vt 这种键值对的计算公式,优化了自注意力机制,多了好几个权重矩阵增强了模型表达能力。其他的公式基本没变。相当于,记忆单元没变,只是每个单元扩容了记忆的容量。

此外,这种框架可以使用多头模型,头与头之间没有记忆混合,因此可以充分并行,可以提升并行能力。

高级版:xLSTM 大模型

Cover 定理

Cover 定理:它及衍生的高维空间中非线性映射理论是现代大模型设计的重要理论依据之一。尤其是在深度学习和大规模神经网络的设计中,直到了关键的作用。

大模型中,激活函数通过非线性变换将数据映射到高维空间,使得模型可以捕捉复杂的模式和特征,增强模型的表达能力。深度网络在权重矩阵和激活函数共同作用,将输入数据逐步映射到越来越高的维度。这使得在低维空间难以分离的模式在高维空间国变得线性可分。Transformer 模型就是通过多头注意力机制在高维空间中进行并行处理,使得不同位置的特征可以相互影响和结合,从而提高了模型的性能。

Cover 定理为这些设计提供了理论支持,解释了为什么通过高维空间国的非线性映射可以提高模型的性能。

核心模块和工作原理

它做了以下的事:

  1. 非线性总结(压缩信息):通过残差块在高维空间中对历史信息进行非线性总结,使得不同的历史或上下文信息更容易分离。
  2. 线性映射回原始空间:完成高维空间中的处理后,再将数据线性映射回原始空间。这一过程利用了高维空间中的优势,使得模型可以更好地分离和记忆历史信息。

而具体的长维,其结构如下:

左侧可以看成 sLSTM,右侧则可以看成:mLSTM。其输入方向为,从下往上输入。

左边是先在原始空间中总结信息(sLSTM),然后映射到高维空间,再返回原始空间。可以看到,有一个倒梯形矩阵用于升维,处理后再降维。而右边是先映射到高维空间,总结信息后再返回原始空间。输入直接上投影,再使用 mLSTM 处理,然后再降维。

关于为啥,左边使用 sLSTM,而右边使用 mLSTM:高维空间中的记忆容量更大,所以使用矩阵化记忆单元的 mLSTM 更合适,而在低维空间处理 sLSTM 更合适。

以下是两个模块的详细设计:

  • PF=3/4 和 PF=4/3:投影因子,用于将输入维度缩小或扩大为原来的 PF 倍。
  • GN:组归一化。在每一组内进行归一化,有助于加速训练和提高模型稳定性,特别是在小批量(batch)训练时。
  • Swish:一种平滑的非线性激活函数,可以帮助模型学习到更复杂的模式。
  • Conv 4: 卷积层,卷积核大小为 4,提取局部特征。
  • LN:层归一化,帮助稳定和加速训练过程。
  • NH=4:表示有 4 个头。此外,将输入块,使用块对角线结构进行线性变换,有助于捕捉局部相关性。

  • PF=1/2 和 PF=2:投影因子。前者将输入维度缩小一半,后者将输入维度扩大 2 倍。
  • LSkip:类似于残差连接,可以帮助梯度更好地传递,防止梯度消失和爆炸。这里相当于有两种跳线残差。
  • q,k,v:从输入中生成,用于计算注意力权重和进行信息检索。
  • BS=4:块大小为 4 的块对角投影矩阵。

整体上都是充分利用了残差堆叠结构,层归一化技术等稳定网络,通过升降维度实现空间变换,激活函数非线性变换,然后利用 LSTM 进行记忆混合,或者说时序上的选择性自力机制计算,采用多头和块对角模式实现并行处理。

与 Transformer 的对比

有了这两种基本构建模块,通过堆叠增加模型的深度,可以逐层提取更高层次的特征。最终,整个堆叠结构作为一个端到端的模型进行训练。

同时,Transformer 能干的,xLSTM 也可以干,但是 xLSTM 有更加明确的逻辑结构,有数据公式的严密推导,效率更高。

与 Transformer 不同,xLSTM 在计算复杂度与内存复杂序上随着序列长度呈再发关系。由于 xLSTM 有记忆压缩性,很适合在工业应用和边缘设备上实现。

适用的场景

  • sLSTM(无法并行化):需要高精度和复杂特征提取的任务,计算资源充足且不需要并行化的应用,对延迟敏感但不受并行化限制的场景。
  • mLSTM(可以并行化):图像识别,视频处理等需要高效并行计算的任务,计算资源有限且需要高效利用内存的应用;需要在工业环境或边缘设备上部署的任务。

小结

xLSTM 的原理:借助指数门控混合记忆和新内存结构,LSTM 增强为 sLSTM 和 mLSTM。二者的结合构成了 xLSTM,进一步堆叠可以实现大模型化。


http://www.kler.cn/a/301637.html

相关文章:

  • 图像处理|腐蚀操作
  • 人工智能-数据分析及特征提取思路
  • 计算机网络(三)——局域网和广域网
  • xml简介
  • 基于SpringBoot的洗浴管理系统
  • Mac中配置vscode(第一期:python开发)
  • 高性能计算机A100会带有BMC功能 ;BMC;SSH
  • Thinkphp5实现一周签到打卡功能
  • 前端算法(持续更新)
  • Linux_kernel移植rootfs10
  • 普发Pfeiffer TCP600TCP5000手侧
  • Python——贪吃蛇
  • JAVA基础:抽象类,接口,instanceof,类关系,克隆
  • 在深度学习计算机视觉的语义分割中,Boundary和Edge的区别是?
  • NX—UI界面生成的文件在VS上的设置
  • 【网络】DNS
  • 备忘录模式memento
  • C语言初识编译和链接
  • 01 Docker概念和部署
  • 【论文速读】| SEAS:大语言模型的自进化对抗性安全优化
  • Arm GIC-v3中断原理及验证(通过kvm-unit-tests)
  • Go语言练习——语法实践
  • Leetcode 3281. Maximize Score of Numbers in Ranges
  • 双端队列--java--黑马
  • Python 学习笔记(二)
  • Java后端开发(十六)-- JavaBean对象拷贝工具类:运用反射机制,实现对象的深拷贝