当前位置: 首页 > article >正文

(即插即用模块-Attention部分) 三十、(ICCV 2023) EAA 有效附加注意力

在这里插入图片描述

文章目录

  • 1、Efficient Additive Attention
  • 2、SwiftFormer
  • 3、代码实现

paper:SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications

Code:https://github.com/Amshaker/SwiftFormer


1、Efficient Additive Attention

在现有的研究中,传统的 Multi-Head Self-Attention (MHSA) 计算 复杂度高,难以在移动设备上实时运行。而现有的 Additive Attention 则需要计算 key 和 value 之间的显式交互,这限制了其效率和灵活性。所以,这篇论文进一步提出一种 有效附加注意力(Efficient Additive Attention) 。旨在解决 Transformer 模型在移动设备上部署时遇到的效率问题。

EAA 的基本思想主要有三点: 首先,EAA 通过将矩阵乘法替换为元素级乘法,显著降低了计算复杂度。然后,EAA 通过消除 key-value 交互,并使用线性变换来学习 token 之间的关系,从而简化了计算过程。最后,EAA 具有与 token 长度线性相关的计算复杂度,可以应用于网络的各个阶段,从而在整个网络中实现一致的上下文信息学习。

对于输入X,EAA 的实现过程:

  1. 前置处理:将图像或序列数据输入网络,并经过卷积或线性层转换为嵌入向量。将嵌入向量分为 query 和 key 部分。

  2. Query 聚合:将 query 嵌入向量与可学习的权重向量相乘,得到每个 query 的权重。使用 Softmax 函数将权重归一化,并使用池化操作将所有 query 的加权嵌入向量聚合为一个全局 query 向量。

  3. 全局上下文编码:将全局 query 向量与 key 嵌入向量进行元素级乘法,得到全局上下文表示。

  4. 线性变换:对全局上下文表示进行线性变换,得到最终的输出。

  5. 后置处理:将最终输出进行池化或平均等操作,得到最终的预测结果。


Efficient Additive Attention 结构图:
在这里插入图片描述


2、SwiftFormer

在 EAA 的基础上,论文还提出一种新框架 SwiftFormer ,SwiftFormer是一种轻量级的视觉 Transformer 架构,旨在在移动设备上实现高效的实时视觉应用。

SwiftFormer 主要由以下几个部分组成:

  • Patch Embedding: 将输入图像分割成多个 patch,并将每个 patch 转换为嵌入向量。
  • Conv Encoder: 使用深度可分离卷积提取局部特征,并进行特征混合。
  • SwiftFormer Encoder: 使用 Efficient Additive Attention 机制学习局部和全局特征,并进行特征融合。
  • Downsampling: 在每个阶段之间使用下采样操作,逐渐降低特征图的分辨率,并增加特征维度。
  • Classification Head: 对最终的特征图进行池化或平均,并进行线性变换,得到最终的分类结果。

SwiftFormer 结构图:
在这里插入图片描述

3、代码实现

import torch
import torch.nn as nn
import einops


class EfficientAdditiveAttnetion(nn.Module):
    def __init__(self, in_dims=512, token_dim=256, num_heads=2):
        super().__init__()

        self.to_query = nn.Linear(in_dims, token_dim * num_heads)
        self.to_key = nn.Linear(in_dims, token_dim * num_heads)

        self.w_g = nn.Parameter(torch.randn(token_dim * num_heads, 1))
        self.scale_factor = token_dim ** -0.5
        self.Proj = nn.Linear(token_dim * num_heads, token_dim * num_heads)
        self.final = nn.Linear(token_dim * num_heads, token_dim)

    def forward(self, x):
        query = self.to_query(x)
        key = self.to_key(x)

        query = torch.nn.functional.normalize(query, dim=-1) #BxNxD
        key = torch.nn.functional.normalize(key, dim=-1) #BxNxD

        query_weight = query @ self.w_g # BxNx1 (BxNxD @ Dx1)
        A = query_weight * self.scale_factor # BxNx1

        A = torch.nn.functional.normalize(A, dim=1) # BxNx1

        G = torch.sum(A * query, dim=1) # BxD

        G = einops.repeat(
            G, "b d -> b repeat d", repeat=key.shape[1]
        ) # BxNxD

        out = self.Proj(G * key) + query #BxNxD

        out = self.final(out) # BxNxD

        return out


if __name__ == '__main__':
    x = torch.randn(4, 10, 512).cuda()
    model = EfficientAdditiveAttnetion(512, 512).cuda()
    output = model(x)
    print(output.shape)

http://www.kler.cn/a/461711.html

相关文章:

  • 启航数据结构算法之雅舟,悠游C++智慧之旅——线性艺术:顺序表之细腻探索
  • HTML——73.button按钮
  • seata分布式事务详解(AT)
  • 第R3周:RNN-心脏病预测
  • 简单使用linux
  • 超高分辨率 图像 分割处理
  • Redis下载与安装
  • Python-MNE-源空间和正模型04:头模型和前向计算
  • 计算机毕业设计Python动漫推荐系统 漫画推荐系统 动漫视频推荐系统 机器学习 bilibili动漫爬虫 数据可视化 数据分析 大数据毕业设计
  • vue2 如何刷新页面
  • 【每日学点鸿蒙知识】上拉加载下拉刷新、napi调试报错、安装验证包、子线程播放音视频文件、OCR等
  • 【Vim Masterclass 笔记04】S03L12:Vim 文本删除同步练习课 + S03L13:练习课点评
  • redis是如何保证数据安全的?
  • LoRA微调系列笔记
  • UML类图的六大关系:依赖,泛化,实现,关联,聚合,组合
  • 使用Python实现基因组数据分析:探索生命的奥秘
  • 免押租赁系统助力共享经济发展新模式
  • 【JAVA】神经网络的基本结构和前向传播算法
  • WebAssembly 学习笔记
  • 网络安全 | 5G网络安全:未来无线通信的风险与对策
  • OpenVPN 被 Windows 升级破坏
  • Linux命令——3.网络与用户
  • SQL常用语句(基础)大全
  • C++算法20例
  • Listwise 模型时间线梳理
  • Flask是什么?深入解析 Flask 的设计与应用实践