当前位置: 首页 > article >正文

【AI学习】Mamba学习(十五):关于S4模型计算优化的再学习

前面理解了S4模型,但是对于具体的优化算法,还是没有完全理解透彻。现在补充学习。

S4 = SSM + HiPPO + Structured Matrices

具体方案:Structured State Spaces (S4)

简单总结:用HiPPO-LegS的矩阵形式初始化SSM,解决训练的稳定性问题。但是,基于卷积的并行化训练,依然处理复杂度很高,需要简化。

对角化

如何简化计算?最直接的思维就是矩阵A的对角化。前面的公式5的卷积形式,矩阵A的L次幂,如果是对角矩阵,计算就可以大大简化。
文章3.1节,给出了这种对角化动机的目的。

在这里插入图片描述

但是文章接着说,不幸的是,由于数值问题,对角化的简单应用在实践中不可行。尽管有其他方法,数值上也不稳定。
为什么数值不稳定,简单看一下,以N=3为例,对角化的结果是:

在这里插入图片描述

“HiPPO 矩阵的对角化涉及到的矩阵元素在状态大小 N 增大时会呈指数级增长,这使得对角化在数值上变得不稳定和不可行”。

那怎么办,3.2节给出方法:将HiPPO 矩阵A分解为正态矩阵和低秩矩阵的和。这样处理后获得了一个反对称矩阵,“重点来了,反对称矩阵不单单一定可以对角化,它一定可以被正交矩阵(复数域叫做酉矩阵)对角化!酉矩阵一般数值稳定性都非常好”。

S4 Parameterization: Normal Plus Low-Rank

前面的讨论意味着我们应该进行共轭计算通过条件良好的矩阵 V 。理想的情况是,当矩阵 A 可由完全条件(即酉)矩阵对角化时。根据线性代数的谱定理,这正是normal矩阵的一类。然而,这类矩阵具有限制性;特别是它不包括HiPPO 矩阵(2)。
我们观察到,尽管 HiPPO 矩阵不是normal矩阵,但它可以被分解为normal和低秩矩阵的和。然而,这本身仍然没有用:与对角矩阵不同,对这个幂进行累加(在(5)中)仍然很慢,也不容易优化。我们通过同时应用三种新技术克服了这一瓶颈。
在这里插入图片描述
上面的描述还是有点抽象,分步骤来看:

矩阵A分解为Normal Plus Low-Rank (NPLR)形式

在这里插入图片描述
上面说的公式(2)就是下面的LegS形式,对于的低秩r=1
在这里插入图片描述
论文中在附录C.1NPLR Representations of HiPPO Matrices进行了说明。

已知 HiPPO 矩阵A可以表示为:
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
“重点来了,可以看到这是一个反对称矩阵,所以它一定可以(在复数域中)对角化!于是我们就将A分解为了可对角化矩阵与低秩矩阵之和!可能有读者质疑,原本 A就 一定是可对角化矩阵,但还是有数值稳定性问题,难道这个反对称矩阵的对角化不用担心数值稳定性问题吗?
重点的重点来了,反对称矩阵不单单一定可以对角化,它一定可以被正交矩阵(复数域叫做酉矩阵)对角化!酉矩阵一般数值稳定性都非常好,所以不用担心这个问题,这也就是为什么我们不直接对角化 ,而绕一圈来构建反对称矩阵的原因。”
在这里插入图片描述
这样就将矩阵A转换为了正规矩阵+低秩矩阵的形式。

但是,论文中指出:然而,这本身仍然没有用:与对角矩阵不同,对这个幂进行累加(在(5)中)仍然很慢,也不容易优化。
公式5见下面:

在这里插入图片描述

生成函数

论文中指出,利用截断的生成函数!

这里的生成函数如何理解?熟悉卷积运算的就知道,卷积运算计算量大,可以先做FFT,在频域变成乘法,然后IFFT。这是利用FFT的简化卷积运算经常使用的方法。只不过,这里傅立叶变换所需要的实际是“截断生成函数”,将无限长度截断为L。
在这里插入图片描述
直观地说,生成函数基本上将SSM卷积滤波器从时域转换为频域。重要的是,它保留了相同的信息,并且可以从其生成函数的评估中恢复所需的SSM卷积滤波器。
这样 就把矩阵的幂的问题转化为矩阵求逆。
在这里插入图片描述

Woodbury Correction

然后,通过Woodbury恒等式来解决低秩问题。
虽然DPLR矩阵由于低秩项而不能有效地求幂,但它们可以通过众所周知的Woodbury恒等式有效地反转。
在这里插入图片描述
有了这个Woodbury恒等式,我们可以将DPLR矩阵A上的SSM生成函数转换为仅在其对角分量上的生成函数。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

Cauchy Kernel

在这里插入图片描述
柯西矩阵计算是数值分析中一个研究得非常好的问题,有快速算法和基于著名的快速多极方法(FMM)的快速数值算法。

在这里插入图片描述
下面的卷积核的计算,第2、3步骤就是计算Lemma C.3中的K(z)
然后又因为截断为L长度的生成函数,需要下面的第4、5步骤。
在这里插入图片描述
具体第三章涉及的公式推导,可以参见苏神在文章《重温状态空间模型SSM:HiPPO的高效计算(S4)》中的详细推导。

S4模型的处理示意图

在S5的论文中,给出了S4模型的处理示意图。
在这里插入图片描述


http://www.kler.cn/a/389520.html

相关文章:

  • Yaml的使用
  • 【深度学习项目】语义分割-FCN网络(原理、网络架构、基于Pytorch实现FCN网络)
  • 高频词汇背诵 Day1
  • Python绘制数据地图-MovingPandas
  • idea中远程调试中配置的参数说明
  • 【LC】2239. 找到最接近 0 的数字
  • 【刷题列表-更新中】蓝桥杯和洛谷平台刷题列表
  • PyTorch 2.0: 新特性与升级指南
  • SwiftUI开发教程系列 - 第2章:基础布局与视图
  • 微服务之多机部署,负载均衡-LoadBalance
  • 卷积神经网络基础
  • 前缀和 so easy! 力扣.128 最长连续序列 leetcode longest-consecutive-sequence
  • 【动手学电机驱动】 STM32-FOC(2)STM32 导入和创建项目
  • 中兴光猫修改SN,MAC,修改地区,异地注册,改桥接,路由拨号
  • 今日 AI 简报|苹果推出的新框架,智源开源千万级多模态数据集,字节推出图像编辑模型,开源大语言模型和实时对话系统等
  • 24/11/7 算法笔记 PCA主成分分析
  • 【前端】JavaScript 方法速查大全-函数、正则、格式化、转换、进制、 XSS 转义(四)
  • ArkTS--应用状态
  • Linux服务器使用ps和top命令查看进程
  • 加载与存储指令及算数指令
  • HarmonyOS Next 实战卡片开发 01
  • Android CCodec Codec2 (二十)C2Buffer与Codec2Buffer
  • 深度学习中的 Dropout:原理、公式与实现解析
  • [Linux] 共享内存
  • 使用 IDEA 创建 Java 项目(二)
  • Hive:UDTF 函数