当前位置: 首页 > article >正文

(icml2024)SLAattention,基于原文时序模型进行改进

#代码: https://github.com/xinghaochen/SLAB
#论文:https://arxiv.org/pdf/2405.11582

相关工作

1. 高效Transformer架构

背景: Transformer从最初的自然语言处理扩展到计算机视觉领域(例如ViT),但由于计算复杂度高,尤其是Attention机制的二次复杂度,难以部署在资源受限的设备上。

Transformer在自然语言处理和计算机视觉领域得到了广泛应用,但其计算复杂度较高,尤其是Attention模块的二次复杂度,成为部署在资源受限设备上的主要瓶颈。为此,研究者提出了多种优化方案:

降低Attention计算开销:通过限制Token的交互范围(如局部窗口计算)或引入近似线性Attention的策略,将计算复杂度从二次降低到线性。
优化Token处理模式:通过下采样、稀疏Attention模式等方法提高效率,同时尽量减少对性能的影响。

方法对比:
在这里插入图片描述

2. Transformer中的归一化

Transformer中归一化层对于稳定训练和提升性能至关重要。LayerNorm是Transformer中默认使用的归一化方法,但其在推理阶段的统计计算开销较高,限制了模型的效率。研究者尝试将BatchNorm等效率更高的归一化技术引入Transformer,但因BatchNorm对训练的敏感性和稳定性问题,其效果往往不如LayerNorm。

为解决这一问题,部分研究提出在训练过程中调整归一化策略,或结合多种归一化方法,以兼顾训练稳定性和推理效率。然而,在实际应用中仍存在性能下降或训练不稳定的问题
现有方法对比
在这里插入图片描述

模型结构

本文的方法是针对于cv进行的操作,所以用于nlp 的话会有点匹配不上,但是原文中还是说过可以把这个注意力用到时序任务上,
所以基于原文的思路,本人进行了一些改进。

原文注意力结构

使用ReLU代替传统的Softmax作为Attention权重的核函数:
在这里插入图片描述
在这里插入图片描述

不再需要复杂的Softmax操作,计算复杂度从 O(N²C)降低到O(NC²)
Depth-wise卷积增强特征:

在这里插入图片描述

在计算Attention输出后,加入Depth-wise卷积提升局部特征表达能力:
好吧揭穿一下就是Conv2d这个是dwc卷积hhhhhh。原文代码如下

class SimplifiedLinearAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.1,
                 focusing_factor=3, kernel_size=5):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads."
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Depthwise convolution for local feature enhancement  Depth-wise卷积
        self.dwc = nn.Conv1d(head_dim, head_dim, kernel_size=kernel_size, groups=head_dim, padding=kernel_size // 2)
        self.focusing_factor = focusing_factor

    def forward(self, x):
        B, N, C = x.shape  # Batch size, sequence length, feature dimension
        q = self.q(x)  # Query projection
        kv = self.kv(x).reshape(B, N, 2, C).permute(2, 0, 1, 3)  # Key and Value projection
        k, v = kv[0], kv[1]

        # Apply ReLU as the kernel function
        kernel_function = nn.ReLU()
        q = kernel_function(q)
        k = kernel_function(k)
        q, k, v = (rearrange(t, "b n (h c) -> b h n c", h=self.num_heads) for t in [q, k, v])

        # Linear attention mechanism
        z = 1 / (torch.einsum("b h n c, b h c -> b h n", q, k.sum(dim=2)) + 1e-6)
        kv_product = torch.einsum("b h m c, b h m d -> b h c d", k, v)
        x = torch.einsum("b h n c, b h c d, b h n -> b h n d", q, kv_product, z)

        # Combine heads and apply depthwise convolution
        x = rearrange(x, "b h n c -> (b h) c n")  # Reshape for depthwise convolution
        x = self.dwc(x) * self.focusing_factor  # Apply depthwise convolution
        x = rearrange(x, "(b h) c n -> b n (h c)", h=self.num_heads)  # Combine heads back

        # Final projection
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

个人改进

我个人改进是基于时序任务的,这个是没有在原文中进行体现的一方面,并且原文代码中只用于图像,我想自己尝试一下能不能用到时序的地方。
4.1 改进优势
局部特征建模增强:
在原始全局建模的基础上,加入了对局部模式的感知能力,使模型更适用于局部显著性任务。
例如,在时间序列中捕获局部峰值模式,在图像任务中对局部区域的特征建模。

a.保留原始复杂度:
深度卷积的引入并未显著增加计算复杂度,仍然保持原文的线性复杂度。

b.增强灵活性:
用户可以通过 kernel_size 和 focusing_factor 调整局部增强的强度,适应不同任务需求。

4.2 潜在问题
可能偏离原文目标的纯粹性:

原文章的目标是简化注意力机制,而改进版本引入了深度卷积,可能略微增加模块复杂度。
如果应用场景仅需要全局特征建模,局部卷积可能是多余的。
额外参数:
深度卷积引入了额外的超参数(如卷积核大小、卷积强度等),可能需要更多的超参调优。

在这里插入图片描述

代码会更新到群里


http://www.kler.cn/a/454878.html

相关文章:

  • 编译笔记:vs 中 正在从以下位置***加载符号 C# 中捕获C/C++抛出的异常
  • .NET平台用C#通过字节流动态操作Excel文件
  • wordpress调用指定ID分类下浏览最多的内容
  • Nmap基础入门及常用命令汇总
  • C++ —— 模板类与函数
  • PH热榜 | 2024-12-26
  • 【AIGC篇】AIGC 引擎:点燃创作自动化的未来之火
  • 项目报 OutOfMemoryError 、GC overhead limit exceeded 问题排查以及解决思路实战
  • LeetCode 热题 100_二叉树的中序遍历(36_94_简单_C++)(二叉树;递归(中序遍历);迭代)
  • 如何在 Ubuntu 22.04 上安装 Ansible 教程
  • OpenStack系列第三篇:CentOS7 上部署 OpenStack(Train版)集群教程 Ⅲ Nova Neutron 服务部署
  • Go语言反射从入门到进阶
  • js 生成二维码(qrcodejs2-fix)
  • Intel AMD Hygon CPU缓存
  • 分阶段总结:建材制造业“数字化转型”总体架构与实现路径
  • 06 - Django 视图view
  • 拉链表,流⽔表以及快照表的含义和特点
  • vscode remote-ssh 免密登录不生效的问题
  • vue2 通过url ‘URLScheme‘实现直接呼起小程序
  • 社区版Dify+Ollama+llama3.2-vision 实现多模态聊天
  • 设计模式-创建型-工厂方法模式
  • 上位机开发 的字符串处理
  • 【206】图书管理系统
  • 实现类似gpt 打字效果
  • 提示词工程教程(七):小样本和上下文学习
  • Stability AI 新一代AI绘画模型:StableCascade 本地部署教程