论文阅读笔记:VMamba: Visual State Space Model
论文阅读笔记:VMamba: Visual State Space Model
- 1 背景
- 2 创新点
- 3 方法
- 4 模块
- 4.1 2D选择性扫描模块(SS2D)
- 4.2 加速VMamba
- 5 效果
- 5.1 和SOTA方法对比
- 5.2 SS2D和自注意力
- 5.3 有效感受野
- 5.4 扫描模式
论文:https://arxiv.org/pdf/2401.10166
代码:https://github.com/MzeroMiko/VMamba
1 背景
视觉表征学习作为计算机视觉领域的一个基础研究方向,在深度学习时代取得了令人瞩目的进展。为了表示视觉数据中的复杂模式,卷积神经网络和视觉Transformer这两类主干网络被提出广泛应用于各种视觉任务中。相比于卷积神经网络,ViT由于融合了自注意力机制,在大规模数据集通常表现出更强大的学习能力。然而,自注意力的二次复杂度在设计大空间分辨率的下游任务重带来了巨大的计算开销。
为了应对这一挑战,大部分工作都在提高注意力计算的效率。然而,现有的方法要么限制了有效感受野的大小,要么在各种任务上都有显著的性能下降。这促使需要开发出一种新颖的视觉数据架构,同时保持普通自注意力机制的固有优势,即全局感受野和动态加权参数。
最近,在自然语言处理领域,Mamba提出了一种新的状态空间模型SSM,为具有线性复杂度的长序列建模提供了一种有前途的方法。从这一工作中,本文引入了VMamba,一个集成了基于SSM快的视觉主干,以实现高效的视觉表示学习。然而,Mamba的核心算法——并行化的选择性扫描操作,本质上是为处理一维序列数据而设计的。这在将其用于处理视觉数据是提出了挑战。缺乏视觉组件的固有顺序安排。为了解决这个问题,本文提出了2D Selective Scan(SS2D),一种用于空间域遍历的四路扫描机制。与注意力机制相比,SS2D保证了每个图像块经过沿其对应扫描路径设计的压缩隐藏状态来获取上下文知识,从而将时间复杂度从二次降低到线性。
在提出的视觉状态空间(VSS)块的基础上,本文开发了一系列VMamba结构(Tiny/Small/Base),并通过架构改进和实现优化来提高他们的性能。
2 创新点
-
提出了VMamba,一个基于SSM的视觉主干,用于具有线性时间复杂度的视觉表示学习。为了提高VMamba的推理速度,对VMamba进行了一系列的结构和实现上的改进。
-
引入2D选择性扫描( SS2D )来桥接1D阵列扫描和2D平面遍历,从而扩展选择性SSM来处理视觉数据。
-
VMamba在图像分类、目标检测和语义分割等视觉任务中取得了良好的性能。它对输入序列的长度也表现出显著的适应性,表现为计算复杂度的线性增长。
3 方法
本文在三个尺度上开发VMamba:Tiny,Small和Base。在图3(a)中对VMamba-T的结构进行了概述(附录中提供了详细的配置)。输入图像 I ∈ R H × W × 3 I∈R^{H×W×3} I∈RH×W×3 首先被划分成 patch,从而得到空间维度为 H 4 × W 4 \frac{H}{4}×\frac{W}{4} 4H×4W 的2D特征图,在不引入额外位置嵌入的情况下,使用多个网络阶段创建分辨率为 H 8 × W 8 \frac{H}{8}×\frac{W}{8} 8H×8W, H 16 × W 16 \frac{H}{16}×\frac{W}{16} 16H×16W , H 32 × W 32 \frac{H}{32}×\frac{W}{32} 32H×32W 的表示。具体来说,每个阶段包括一个下采样层(除第一阶段外),后面是一个堆叠的视觉状态空间(VSS)块。
VSS块作为Mamba块的视觉对应模块用于表示学习。通过替换S6模块制定了VSS模块的初始架构(图3©的普通VSS块)。S6是Mamba的核心,实现了全局感受野,动态权重(即选择性)和线性复杂度。本文提出用新的2D选择性扫描(SS2D)模块来代替它。为了进一步提高计算效率,作者去掉了整个乘法分支,因为门控机制的作用已经通过SS2D的选择性来实现。因此,改进后的VSS块(如图3(d))由单个网络分支和两个残差模块组成,模仿了普通Transformer模块的架构。本文所有的结构均是在该架构下使用VSS块构建的VMamba模型得到的。
4 模块
4.1 2D选择性扫描模块(SS2D)
虽然S6中扫描操作的顺序性与设计时间数据的NLP任务很好的吻合,但应用于视觉数据是一个重大的挑战,因为图像本质上是非序列的,它包含了空间信息(如局部纹理和全局结构)。为了解决这个问题,本文提出了SS2D模块,以适应视觉数据而不损害其优势。
如图2所示,SS2D中的数据前向过程包括3个步骤:
-
交叉扫描。SS2D先沿着四条不同的遍历路径将收入块展开成序列。
-
使用S6块对每个patch序列进行并行处理。
-
交叉合并。对结果序列重新整理和合并,形成输出图(从代码看就是将对应patch直接相加)。
通过使用互补的1D遍历路径,SS2D允许图像中的每个像素在不同的方向上集成来自所有其他像素的信息。这种整合有利于2D空间中建立全局感受野。
4.2 加速VMamba
如图3(e)所示,采用普通VSS快的VMamba-T模型(包含22.9M个参数和5.6G FLOPs)实现了每秒426帧图像的吞吐率。尽管取得了82.2%的最高分类准确率(比Swin-T高出0.9%),但低吞吐量和高内存开销给VMamba的实际部署带来了巨大的挑战。
在这一部分,作者通过实现细节和架构设计两方面的改进提升其推理速度:
-
Step a(+0.0%, +41 img/s):通过在Triton中重新实现交叉扫描和交叉合并。
-
Step b(+0.0%, -3 img/s):通过调整CUDA实现选择性扫描,以适应float16和float32输出,这显著的提高了训练效率,尽管在测试时速度略有波动。
-
Step c(+0.0%, +174 img/s):将选择性扫描中相对慢的 einsum 用线性变换(即 torch.nn.functional.linear)代替。同时还采用(B,C,H,W)的张量步距,以消除不必要的数据置换。
-
Step d(+0.0%, +175 img/s):在VMamba中引入MLP。丢弃DWConv深度可分离卷积层,并将层配置从 [ 2 , 2 , 9 , 2 ] [2,2,9,2] [2,2,9,2] 改为 [ 2 , 2 , 2 , 2 ] [2,2,2,2] [2,2,2,2] 以降低FLOPs。
-
Step e(+0.6%, +366 img/s):将参数 ssm-ratio(特征扩展因子)从2.0减少到1.0,将层数增加到 [ 2 , 2 , 5 , 2 ] [2,2,5,2] [2,2,5,2],并丢弃图3© 中的整个乘法分支。
-
Step f(+0.3%, +161 img/s):通过引入DWConv层,并将参数 d_state(SSM状态维数)从16.0降低到1.0,同时将 ssm-ratio 提高到2.0。
-
Step g(+0.1%, +346 img/s):将 ssm-ratio减少到1.0,同时将层数配置从 [ 2 , 2 , 5 , 2 ] [2,2,5,2] [2,2,5,2] 改为 [ 2 , 2 , 8 , 2 ] [2,2,8,2] [2,2,8,2]。
5 效果
5.1 和SOTA方法对比
分类任务上的效果对比。
检测任务上的效果对比。
对下游任务(4a)和输入尺寸(4b)的泛化性。
随着输入尺寸增加的计算量、吞吐量和内存占用对比。
5.2 SS2D和自注意力
为了表示长度为
T
T
T 的时间区间
[
a
,
b
]
[a,b]
[a,b] 内的响应
Y
Y
Y,作者将相应的SSM相关变量 $u_i⊙\Delta_i∈R{1×D_v},B_i∈R{1×D_k} $ 和
C
i
∈
R
T
×
D
v
C_i∈R^{T×D_v}
Ci∈RT×Dv 分别视为
V
∈
R
T
×
D
v
,
K
∈
R
T
×
D
k
V∈R^{T×D_v},K∈R^{T×D_k}
V∈RT×Dv,K∈RT×Dk 和
Q
∈
R
T
×
D
k
Q∈R^{T×D_k}
Q∈RT×Dk。那么,沿着
y
b
y_b
yb 的
D
v
D_v
Dv 维度方向的第
j
j
j 个切片,记为
y
b
(
j
)
∈
R
y_b^{(j)}∈R
yb(j)∈R 可以表示为:
其中 h a ∈ R D k h_a∈R^{D_k} ha∈RDk 是第a步的隐藏状态, ⊙ ⊙ ⊙ 表示元素乘积, V i ( j ) V_i^{(j)} Vi(j) 是一个标量, w : = [ w 1 ; … ; w T ] ∈ R T × D K × D v w:=[w_1;…;w_T]∈R^{T×D_K×D_v} w:=[w1;…;wT]∈RT×DK×Dv,其中 w i ∈ R D k × D v w_i∈R^{D_k×D_v} wi∈RDk×Dv中的每个元素的表达式可以写成 w i = ∏ j = 1 i e A Δ a − 1 + j T w_i=\prod_{j=1}^ie^{A\Delta^T_{a-1+j}} wi=∏j=1ieAΔa−1+jT,表示沿着扫描路径在第 i i i 步计算的累积注意力权重。
因此,
Y
Y
Y 的第
j
j
j 个维度,即
Y
(
j
)
∈
R
T
×
1
Y^{(j)}∈R^{T×1}
Y(j)∈RT×1 可以表示为:
其中 M M M 表示 T × T T×T T×T 的时间掩码矩阵,下三角部分为1,其余部分为0。
作者可视化了SS2D中的
Q
K
T
QK^T
QKT 和
(
Q
⊙
w
)
(
K
/
w
)
T
(Q⊙w)(K/w)^T
(Q⊙w)(K/w)T。如图6 ( b )所示,
Q
K
T
QK^T
QKT 的激活图证明了SS2D在捕获和保留遍历信息方面的有效性,所有先前扫描到前景区域的token都被激活。此外,
w
w
w 的加入使得激活图更加集中在查询块(图6 ( c )) )的邻域内,这与
w
w
w 的提法所固有的时间加权效应是一致的。尽管如此,选择性扫描机制允许VMamba沿着扫描路径积累历史,有利于建立图像块之间的长期依赖关系。这在红色方框(图6 ( d )) )包围的子图中表现得很明显,在该子图中,距离左侧(在早期的步骤中扫描)较远的绵羊patch仍然处于激活状态。
将6(d)中每一行的结果加起来可以得到6©中每一行的结果。
5.3 有效感受野
有效感受野ERF是指输入空间中有助于激活特定输出单元的区域。作者对比了训练前后各视觉backbone上的ERF,如图7,在所有研究的模型中,只有DeiT,HiViT,Vim和VMamba显示了全局ERF,其他模型虽然具有全局覆盖的理论潜力,但也只显示了局部ERF,同时与DeiT和HiViT相比,VMamba的线性复杂度提高了其计算效率。与Vim相比,虽然两者均是基于Mamba架构,但VMamba的ERF比Vim更均匀。
5.4 扫描模式
作者对比了提出的交叉扫描和单项扫描,双向扫描以及级联扫描(依次扫描行和列)的效果,如图8所示。