(即插即用模块-Attention部分) 三十、(ICCV 2023) EAA 有效附加注意力
文章目录
- 1、Efficient Additive Attention
- 2、SwiftFormer
- 3、代码实现
paper:SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications
Code:https://github.com/Amshaker/SwiftFormer
1、Efficient Additive Attention
在现有的研究中,传统的 Multi-Head Self-Attention (MHSA) 计算 复杂度高,难以在移动设备上实时运行。而现有的 Additive Attention 则需要计算 key 和 value 之间的显式交互,这限制了其效率和灵活性。所以,这篇论文进一步提出一种 有效附加注意力(Efficient Additive Attention) 。旨在解决 Transformer 模型在移动设备上部署时遇到的效率问题。
EAA 的基本思想主要有三点: 首先,EAA 通过将矩阵乘法替换为元素级乘法,显著降低了计算复杂度。然后,EAA 通过消除 key-value 交互,并使用线性变换来学习 token 之间的关系,从而简化了计算过程。最后,EAA 具有与 token 长度线性相关的计算复杂度,可以应用于网络的各个阶段,从而在整个网络中实现一致的上下文信息学习。
对于输入X,EAA 的实现过程:
-
前置处理:将图像或序列数据输入网络,并经过卷积或线性层转换为嵌入向量。将嵌入向量分为 query 和 key 部分。
-
Query 聚合:将 query 嵌入向量与可学习的权重向量相乘,得到每个 query 的权重。使用 Softmax 函数将权重归一化,并使用池化操作将所有 query 的加权嵌入向量聚合为一个全局 query 向量。
-
全局上下文编码:将全局 query 向量与 key 嵌入向量进行元素级乘法,得到全局上下文表示。
-
线性变换:对全局上下文表示进行线性变换,得到最终的输出。
-
后置处理:将最终输出进行池化或平均等操作,得到最终的预测结果。
Efficient Additive Attention 结构图:
2、SwiftFormer
在 EAA 的基础上,论文还提出一种新框架 SwiftFormer ,SwiftFormer是一种轻量级的视觉 Transformer 架构,旨在在移动设备上实现高效的实时视觉应用。
SwiftFormer 主要由以下几个部分组成:
- Patch Embedding: 将输入图像分割成多个 patch,并将每个 patch 转换为嵌入向量。
- Conv Encoder: 使用深度可分离卷积提取局部特征,并进行特征混合。
- SwiftFormer Encoder: 使用 Efficient Additive Attention 机制学习局部和全局特征,并进行特征融合。
- Downsampling: 在每个阶段之间使用下采样操作,逐渐降低特征图的分辨率,并增加特征维度。
- Classification Head: 对最终的特征图进行池化或平均,并进行线性变换,得到最终的分类结果。
SwiftFormer 结构图:
3、代码实现
import torch
import torch.nn as nn
import einops
class EfficientAdditiveAttnetion(nn.Module):
def __init__(self, in_dims=512, token_dim=256, num_heads=2):
super().__init__()
self.to_query = nn.Linear(in_dims, token_dim * num_heads)
self.to_key = nn.Linear(in_dims, token_dim * num_heads)
self.w_g = nn.Parameter(torch.randn(token_dim * num_heads, 1))
self.scale_factor = token_dim ** -0.5
self.Proj = nn.Linear(token_dim * num_heads, token_dim * num_heads)
self.final = nn.Linear(token_dim * num_heads, token_dim)
def forward(self, x):
query = self.to_query(x)
key = self.to_key(x)
query = torch.nn.functional.normalize(query, dim=-1) #BxNxD
key = torch.nn.functional.normalize(key, dim=-1) #BxNxD
query_weight = query @ self.w_g # BxNx1 (BxNxD @ Dx1)
A = query_weight * self.scale_factor # BxNx1
A = torch.nn.functional.normalize(A, dim=1) # BxNx1
G = torch.sum(A * query, dim=1) # BxD
G = einops.repeat(
G, "b d -> b repeat d", repeat=key.shape[1]
) # BxNxD
out = self.Proj(G * key) + query #BxNxD
out = self.final(out) # BxNxD
return out
if __name__ == '__main__':
x = torch.randn(4, 10, 512).cuda()
model = EfficientAdditiveAttnetion(512, 512).cuda()
output = model(x)
print(output.shape)