(icml2024)SLAattention,基于原文时序模型进行改进
#代码: https://github.com/xinghaochen/SLAB
#论文:https://arxiv.org/pdf/2405.11582
相关工作
1. 高效Transformer架构
背景: Transformer从最初的自然语言处理扩展到计算机视觉领域(例如ViT),但由于计算复杂度高,尤其是Attention机制的二次复杂度,难以部署在资源受限的设备上。
Transformer在自然语言处理和计算机视觉领域得到了广泛应用,但其计算复杂度较高,尤其是Attention模块的二次复杂度,成为部署在资源受限设备上的主要瓶颈。为此,研究者提出了多种优化方案:
降低Attention计算开销:通过限制Token的交互范围(如局部窗口计算)或引入近似线性Attention的策略,将计算复杂度从二次降低到线性。
优化Token处理模式:通过下采样、稀疏Attention模式等方法提高效率,同时尽量减少对性能的影响。
方法对比:
2. Transformer中的归一化
Transformer中归一化层对于稳定训练和提升性能至关重要。LayerNorm是Transformer中默认使用的归一化方法,但其在推理阶段的统计计算开销较高,限制了模型的效率。研究者尝试将BatchNorm等效率更高的归一化技术引入Transformer,但因BatchNorm对训练的敏感性和稳定性问题,其效果往往不如LayerNorm。
为解决这一问题,部分研究提出在训练过程中调整归一化策略,或结合多种归一化方法,以兼顾训练稳定性和推理效率。然而,在实际应用中仍存在性能下降或训练不稳定的问题
现有方法对比
模型结构
本文的方法是针对于cv进行的操作,所以用于nlp 的话会有点匹配不上,但是原文中还是说过可以把这个注意力用到时序任务上,
所以基于原文的思路,本人进行了一些改进。
原文注意力结构
使用ReLU代替传统的Softmax作为Attention权重的核函数:
不再需要复杂的Softmax操作,计算复杂度从 O(N²C)降低到O(NC²)
Depth-wise卷积增强特征:
在计算Attention输出后,加入Depth-wise卷积提升局部特征表达能力:
好吧揭穿一下就是Conv2d这个是dwc卷积hhhhhh。原文代码如下
class SimplifiedLinearAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.1,
focusing_factor=3, kernel_size=5):
super().__init__()
assert dim % num_heads == 0, "dim must be divisible by num_heads."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
# Depthwise convolution for local feature enhancement Depth-wise卷积
self.dwc = nn.Conv1d(head_dim, head_dim, kernel_size=kernel_size, groups=head_dim, padding=kernel_size // 2)
self.focusing_factor = focusing_factor
def forward(self, x):
B, N, C = x.shape # Batch size, sequence length, feature dimension
q = self.q(x) # Query projection
kv = self.kv(x).reshape(B, N, 2, C).permute(2, 0, 1, 3) # Key and Value projection
k, v = kv[0], kv[1]
# Apply ReLU as the kernel function
kernel_function = nn.ReLU()
q = kernel_function(q)
k = kernel_function(k)
q, k, v = (rearrange(t, "b n (h c) -> b h n c", h=self.num_heads) for t in [q, k, v])
# Linear attention mechanism
z = 1 / (torch.einsum("b h n c, b h c -> b h n", q, k.sum(dim=2)) + 1e-6)
kv_product = torch.einsum("b h m c, b h m d -> b h c d", k, v)
x = torch.einsum("b h n c, b h c d, b h n -> b h n d", q, kv_product, z)
# Combine heads and apply depthwise convolution
x = rearrange(x, "b h n c -> (b h) c n") # Reshape for depthwise convolution
x = self.dwc(x) * self.focusing_factor # Apply depthwise convolution
x = rearrange(x, "(b h) c n -> b n (h c)", h=self.num_heads) # Combine heads back
# Final projection
x = self.proj(x)
x = self.proj_drop(x)
return x
个人改进
我个人改进是基于时序任务的,这个是没有在原文中进行体现的一方面,并且原文代码中只用于图像,我想自己尝试一下能不能用到时序的地方。
4.1 改进优势
局部特征建模增强:
在原始全局建模的基础上,加入了对局部模式的感知能力,使模型更适用于局部显著性任务。
例如,在时间序列中捕获局部峰值模式,在图像任务中对局部区域的特征建模。
a.保留原始复杂度:
深度卷积的引入并未显著增加计算复杂度,仍然保持原文的线性复杂度。
b.增强灵活性:
用户可以通过 kernel_size 和 focusing_factor 调整局部增强的强度,适应不同任务需求。
4.2 潜在问题
可能偏离原文目标的纯粹性:
原文章的目标是简化注意力机制,而改进版本引入了深度卷积,可能略微增加模块复杂度。
如果应用场景仅需要全局特征建模,局部卷积可能是多余的。
额外参数:
深度卷积引入了额外的超参数(如卷积核大小、卷积强度等),可能需要更多的超参调优。
代码会更新到群里