当前位置: 首页 > article >正文

MambaTab:表格数据处理的新利器

——基于结构化状态空间模型的特征增量学习框架

摘要
本文提出MambaTab,一种基于结构化状态空间模型(SSM)的表格数据处理框架。通过创新的嵌入稳定化设计与轻量化SSM架构,MambaTab在普通监督学习和特征增量学习场景中均表现优异,参数规模仅为Transformer基线的1%,同时实现更高的预测精度与更强的特征扩展性。实验验证其在8个公开数据集上的平均AUROC超越SOTA模型1.2-5.7个百分点,为工业级表格数据处理提供高效解决方案。

关键词:表格数据;结构化状态空间模型;特征增量学习;轻量级架构;自监督预训练

一、范式革新:从固定特征到动态扩展

1.1 表格数据的现实挑战

在金融风控(如Credit-g)、医疗诊断(如Ionosphere)等领域,表格数据常面临特征动态增长(如新增传感器指标、合规字段)与计算资源受限(如边缘设备部署)的双重挑战。传统方法(XGBoost、TabNet)依赖固定特征集,Transformer类模型(TabTransformer)参数量爆炸(如FT-Trans在Ionosphere需193K参数),均难以满足实时扩展与轻量化需求。

1.2 增量学习的破局思路

MambaTab提出维度稳定嵌入(Dimension-Stable Embedding)机制,通过全连接层+层归一化,确保新增特征时嵌入维度不变(如图1左侧)。结合结构化状态空间模型的上下文选择性,首次实现无需架构调整的特征动态扩展,突破"特征增删=模型重构"的传统范式。
在这里插入图片描述

二、架构解析:从预处理到预测的极简链路

2.1 标准化预处理流水线(图1左)

  • 特征统一:二值/类别特征序数编码,数值特征众数填充+Min-Max归一化(公式1)
  • 维度对齐:全连接层将异构特征映射为32维稳定嵌入(默认),层归一化替代批归一化保障增量稳定性

2.2 核心模块:上下文感知SSM(图1右)

d h ( t ) d t = A h ( t ) + B u ( t ) , x ( t ) = C h ( t ) \frac{dh(t)}{dt} = A h(t) + B u(t), \quad x(t) = C h(t) dtdh(t)=Ah(t)+Bu(t),x(t)=Ch(t)
通过离散化采样(公式2),引入输入依赖的时变矩阵 B ( t ) 、 C ( t ) B(t)、C(t) B(t)C(t),使SSM具备内容选择性遗忘能力。对比Transformer的O(n²)复杂度,SSM的线性复杂度使其在长序列特征(如100+字段)中保持高效。

2.3 轻量化预测头

Mamba输出经残差连接后,单FC层映射至二分类概率(sigmoid激活),总参数量仅13K(默认配置),约为TabTrans的0.3%。

三、MambaTab方法详解

3.1 数据预处理

对于表格数据集 { F i , y i } i = 1 m \{F_{i}, y_{i}\}_{i = 1}^{m} {Fi,yi}i=1m,其中 F i = { v i , j } j = 1 n F_{i}=\{v_{i, j}\}_{j = 1}^{n} Fi={vi,j}j=1n表示第 i i i个样本的特征, y i ∈ { 0 , 1 } y_{i} \in \{0,1\} yi{0,1}是相应标签, v i , j v_{i, j} vi,j可以是分类、二进制或数值型数据。MambaTab将二进制和分类特征都视为分类特征,使用序数编码器进行编码。对于数值特征保持不变,并通过填充众数处理缺失值。在将数据输入模型前,利用最小 - 最大缩放将 v i , j v_{i, j} vi,j的值归一化到 [ 0 , 1 ] [0,1] [0,1],公式为:
v i , j ′ = v i , j − min ⁡ i , j = 1 i = n , j = m ( v i , j ) max ⁡ i , j = 1 i = n , j = m ( v i , j ) − min ⁡ i , j = 1 i = n , j = m ( v i , j ) v_{i, j}'=\frac{v_{i, j}-\min_{i, j = 1}^{i = n, j = m}(v_{i, j})}{\max_{i, j = 1}^{i = n, j = m}(v_{i, j})-\min_{i, j = 1}^{i = n, j = m}(v_{i, j})} vi,j=maxi,j=1i=n,j=m(vi,j)mini,j=1i=n,j=m(vi,j)vi,jmini,j=1i=n,j=m(vi,j)

3.2 嵌入表示学习

预处理后的数据通过全连接层学习嵌入表示。这一步骤至关重要,它能够提供更有意义的表示作为后续模型的输入。同时,嵌入表示学习器可以使模型直接从特征中学习多维表示,避免依赖强加的顺序。此外,在特征增量学习过程中,它能确保下游Mamba模块在训练和测试时的输入特征维度一致。为了保持表示的稳定性,使用层归一化而非批归一化。
在这里插入图片描述
特征增量学习设置说明。尽管大多数现有的方法仅能够从一组固定的特征中进行学习,但MambaTab以及现有的TransTab方法却能够在增量特征设置下进行学习。在此,特征集i(其中i = 1, 2, 3)是逐步添加了特征的。特征集X代表测试数据的特征集。

3.3 MambaTab模型构建

经过层归一化的嵌入表示,先经过ReLU激活函数,然后输入到Mamba模块。Mamba模块内部,两个分支的全连接层分别计算线性投影 L P 1 LP_{1} LP1 L P 2 LP_{2} LP2 L P 1 LP_{1} LP1的输出经过一维因果卷积和SiLU激活函数后,进入结构化状态空间模型(SSM)。

连续时间的SSM是一个一阶常微分方程组:
d h ( t ) d t = A h ( t ) + B u ( t ) , x ( t ) = C h ( t ) \frac{dh(t)}{dt}=Ah(t)+Bu(t), x(t)=Ch(t) dtdh(t)=Ah(t)+Bu(t),x(t)=Ch(t)
其中, h ( t ) h(t) h(t) N N N维的潜在状态( N N N为状态扩展因子), u ( t ) u(t) u(t) D D D维的输入( D D D为维度因子或通道数), x ( t ) x(t) x(t)通常取一维, A A A B B B C C C是相应的系数矩阵。

通过时间采样,得到离散版本的SSM:
h k = A ‾ h k − 1 + B ‾ u k , x k = C h k h_{k}=\overline{A}h_{k - 1}+\overline{B}u_{k}, x_{k}=Ch_{k} hk=Ahk1+Buk,xk=Chk
其中, h k h_{k} hk u k u_{k} uk x k x_{k} xk分别是 h ( t ) h(t) h(t) u ( t ) u(t) u(t) x ( t ) x(t) x(t)在时间 k Δ k\Delta kΔ的样本, A ‾ = exp ⁡ ( Δ A ) \overline{A}=\exp(\Delta A) A=exp(ΔA) B ‾ = ( Δ A ) − 1 ( exp ⁡ ( Δ A ) − I ) Δ B \overline{B}=(\Delta A)^{-1}(\exp(\Delta A)-I)\Delta B B=(ΔA)1(exp(ΔA)I)ΔB

Mamba中, B B B C C C Δ \Delta Δ是依赖于输入的线性时变函数。这种设计使得SSM具有上下文和输入选择性,帮助Mamba模块在长输入标记序列中选择性地传播或遗忘信息。最后,SSM的输出与 S ( L P 2 ) S(LP_{2}) S(LP2)进行乘法调制,再经过全连接投影,从而实现基于内容的特征提取和长程依赖及特征交互的推理。

3.4 输出预测

Mamba模块的输出经过连接后,通过全连接层投影到“批量×1”的维度,得到预测对数几率 y i ′ y_{i}' yi。再通过sigmoid激活函数:
s i g m o i d ( y i ′ ) = 1 1 + exp ⁡ ( − y i ′ ) sigmoid(y_{i}')=\frac{1}{1+\exp(-y_{i}')} sigmoid(yi)=1+exp(yi)1
得到预测概率分数,用于计算受试者工作特征曲线下面积(AUROC)和二元交叉熵损失。

四、实验设置与结果

4.1 实验设置

为了全面评估MambaTab的性能,使用了8个不同的公共数据集,包括Credit - g(CG)、Credit - approval(CA)等。将这些数据集按照70%训练集、10%验证集、20%测试集的比例进行划分。

在实现细节上,遵循简单的预处理方法,避免手动干预和调优。训练时,使用Adam优化器和余弦退火学习率调度器,初始学习率为 1 e − 4 1e^{-4} 1e4,训练1000个epoch,提前停止的耐心值设为5。MambaTab的默认超参数设置为:嵌入表示大小(长度) = 32,SSM状态扩展因子( N N N) = 32,局部卷积宽度( d c o n v d_{conv} dconv) = 4,SSM模块扩展因子( M M M) = 1。

选择了多种方法作为基线模型,包括LR、XGBoost、MLP等经典方法,以及TabNet、DCN、AutoInt等深度学习模型。在对比实验中,严格遵循这些基线模型在TransTab中的架构和实现细节。

4.2 普通监督学习性能

在普通监督学习设置下,按照[Wang和Sun, 2022]的协议进行实验。为了减少随机因素的影响,在每个数据集上进行10次不同随机种子的运行,并报告平均结果。

从实验结果(见表1)可以看出,MambaTab - D在3个数据集(CG、CA、BL)上的性能优于基线模型,在其他数据集上与基于Transformer的基线模型性能相当。例如,MambaTab - D在8个数据集中的5个(CG、CA、DS、CB、BL)上优于TransTab。经过超参数调整后的MambaTab - T性能更优,在6个数据集上优于所有基线模型,在另外2个数据集上排名第二。

MethodsDatasets
CGCADSADCBBLIOIC
LR0.7200.8360.5570.8510.7480.8010.7690.860
XGBoost0.7260.8950.5870.9120.8920.8210.7580.925
MLP0.6430.8320.5680.9040.6130.8320.7790.893
SNN0.6410.8800.5400.9020.6210.8340.7940.892
TabNet0.5850.8000.4780.9040.6800.8190.7420.896
DCN0.7390.8700.6740.9130.8480.8400.7680.915
AutoInt0.7440.8660.6720.9130.8080.8440.7620.916
TabTrans0.7180.8600.6480.9140.8550.8200.7940.882
FT - Trans0.7390.8590.6570.9130.8620.8410.7930.915
VIME0.7350.8520.4850.9120.7690.8370.7860.908
SCARF0.7330.8610.6630.9110.7190.8330.7580.905
TransTab0.7680.8810.6430.9070.8510.8450.8220.919
MambaTab - D0.7710.9540.6430.9060.8620.8520.7850.906
MambaTab - T0.8010.9630.6810.9140.8960.8540.8120.920

4.3 特征增量学习性能

在特征增量学习设置中,将每个数据集的特征集 F F F划分为三个不重叠的子集 s 1 s_1 s1 s 2 s_2 s2 s 3 s_3 s3。大多数基线模型只能从固定的特征集(如丢弃增量特征的 s 1 s_1 s1或丢弃旧数据的 s 1 s_1 s1 s 2 s_2 s2 s 3 s_3 s3)学习,而MambaTab和TransTab可以从 s 1 s_1 s1逐步学习到 s 1 s_1 s1 s 2 s_2 s2,再到 s 1 s_1 s1 s 2 s_2 s2 s 3 s_3 s3

MambaTab通过改变输入特征的基数 n ( s e t i ) n(set_{i}) n(seti),同时保持架构不变来实现特征增量学习。这得益于Mamba强大的内容和上下文选择性,以及固定的表示空间维度。即使使用默认超参数,MambaTab - D在特征增量学习上的性能也优于所有基线模型(见表2)。

在8个数据集上,MambaTab-T(调优版)在6个数据集登顶:

  • Credit-g:AUROC 0.801(+3.3% vs TransTab)
  • Credit-approval:0.963(+8.2% vs XGBoost)
  • 平均参数量:50K(TransTab的1.2%),内存占用降低2个数量级
MethodsDatasets
CGCADSADCBBLIOIC
LR0.6700.7730.4750.8320.7270.8060.6550.825
XGBoost0.6080.8170.5270.8910.7780.8160.6920.898
MLP0.5860.6760.5160.8900.6310.8250.6260.885
SNN0.5830.7380.4420.8880.6440.8180.6430.881
TabNet0.5730.6890.4190.8860.5710.8370.6800.882
DCN0.6740.8350.5780.8930.7780.8400.6600.891
AutoInt0.6710.8250.5630.8930.7690.8360.6760.887
TabTrans0.6530.7320.5840.8560.7840.7920.6740.828
FT - Trans0.6620.8240.6260.8920.7680.8400.6450.889
VIME0.6210.6970.5710.8920.7690.8030.6830.881
SCARF0.6510.7530.5560.8910.7030.8290.6800.887
TransTab0.7410.8790.6650.8940.7910.8410.7390.897
MambaTab - D0.7870.9610.6690.9040.8600.8530.7830.908

4.4 可学习参数比较

MambaTab不仅性能优越,在可学习参数大小方面也具有显著优势。与基于Transformer的方法相比(见表3),MambaTab(包括MambaTab - D和MambaTab - T)通常仅使用不到TransTab 1%的可学习参数,就能取得与之相当甚至更好的性能。这表明MambaTab在内存和空间利用上更加高效。

MethodsDatasets
CGCADSADCBBLIOIC
TabTrans2.7M1.2M2.0M1.2M6.5M3.4M87.0M1.0M
FT-Trans176K176K179K178K203K176K193K177K
TransTab4.2M4.2M4.2M4.2M4.2M4.2M4.2M4.2M
MambaTab-D13K13K13K13K14K13K15K13K
MambaTab-T50K38K5K255K30K11K13K10K

4.5 超参数调整

MambaTab对重要超参数进行调整,通过验证损失来确定最优参数,测试集不参与调优过程。调优后的MambaTab(MambaTab-T)在10次不同随机种子运行中的平均测试结果显示出了更好的性能(见表1)。有趣的是,MambaTab-T在某些数据集(如DS、BL、IO和IC)上的参数消耗甚至比MambaTab-D更少。表4展示了MambaTab-T的关键超参数调整值。

HyperparametersDatasets
CGCADSADCBBLIOIC
Embedding Representation Space6432166432161632
State Expansion Factor1664326484864
Block Expansion Factor3421071091

五、超参数敏感性分析和消融研究

5.1 模块扩展因子

对模块扩展因子(内核大小)在{1, 2, …, 10}范围内进行实验,其他超参数保持默认。结果表明,MambaTab的性能随模块扩展因子变化仅有轻微波动,无明显趋势。受[Gu和Dao, 2023]启发,将默认值设为2,但进一步调整该参数可能在部分数据集上提升性能。

5.2 状态扩展因子

使用{4, 8, 16, 32, 64, 128}中的值探究状态扩展因子(N)的影响。实验发现,随着N增大,MambaTab在AUROC指标上的性能提升,但较大的N会消耗更多内存。为平衡性能与内存消耗,选择32作为默认值。

5.3 嵌入表示大小

对嵌入表示长度在{4, 8, 16, 32, 64, 128}范围内进行敏感性分析。结果显示,MambaTab性能随嵌入大小增加而提升,但会增加参数数量和计算资源需求。综合考虑,将默认嵌入长度设为32。

5.4 层归一化的消融

通过在CG和CB数据集的普通监督学习实验中,对比有无层归一化的情况,验证其在模型中的作用。结果表明,有层归一化时MambaTab性能更优,证明了层归一化在模型中的有效性。

5.5 批量大小的影响

研究批量大小在{60, 80, 100, 120, 140}范围内变化对模型性能的影响。发现不同批量大小下,MambaTab在CG和CB数据集上的性能变化较小,展示了其在批量大小方面良好的泛化能力,因此将默认批量大小设为100。

5.6 扩展Mamba模块

研究通过残差连接扩展Mamba模块的效果,从M=2到100进行堆叠实验。结果显示,随着Mamba模块数量增加,MambaTab性能保持稳定,可学习参数线性增加。表明少数Mamba模块即可获得良好性能,因此默认使用M=1。

六、未来展望与结论

MambaTab通过维度稳定嵌入+上下文SSM的创新组合,突破表格数据处理的三大瓶颈:特征增量难、参数量大、预处理复杂。未来将探索:

  1. 回归任务扩展(如房价预测)
  2. 多模态表格融合(结合文本/图像特征)
  3. 在线学习适配(增量样本+增量特征联合优化)

这一轻量级框架为表格数据的实时智能处理开辟新路径,尤其适用于需快速迭代的工业场景。


http://www.kler.cn/a/587634.html

相关文章:

  • 基于Asp.net的汽车租赁管理系统
  • 机器学习中的激活函数是什么起什么作用
  • 百年匠心焕新居:约克VRF中央空调以科技赋能健康理想家
  • [特殊字符] C语言经典案例整理 | 附完整运行效果
  • Matlab实现RIME-CNN-LSTM-Multihead-Attention多变量多步时序预测
  • 芯科科技推出的BG29超小型低功耗蓝牙®无线SoC,是蓝牙应用的理想之选
  • 渗透测试环境搭建,包含常用命令(AndroidIOS)
  • 【机器学习】非结构化数据革命:机器学习中的文本、图像与音频
  • 运维未来发展趋势
  • 3ds Max 快捷键分类指南(按功能划分)
  • 股指期货基差怎么计算?公式介绍
  • SpringBoot(2)——SpringBoot入门:微服务
  • 正点原子[第三期]Arm(iMX6U)Linux移植学习笔记-4 uboot目录分析
  • 从开发者视角找寻Postman的替代工具
  • golang字符串常用的系统函数
  • GobiNet 驱动移植调试
  • 如何利用Python爬虫获取微店商品详情数据:实战指南
  • DeepSeek大模型在政务服务领域的应用
  • 使用 PaddlePaddle 官方提供的 Docker 镜像
  • 超声重建,3D重建 超声三维重建,三维可视化平台 UR 3D Reconstruction