(即插即用模块-特征处理部分) 二十、(TPAMI 2022) Permute-MLP 置换MLP模块
文章目录
- 1、Permute-MLP layer
- 2、代码实现
paper:Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition
Code:https://github.com/Andrew-Qibin/VisionPermutator
1、Permute-MLP layer
传统的 MLP-like 模型(如 Mixer 和 ResMLP)在编码图像特征时,首先会将空间维度展平,但这会导致丢失重要的位置信息,从而限制了模型的表达能力。这篇论文提出一中 置换MLP模块(Permute-MLP layer),Permute-MLP 旨在解决这个问题,通过分别对高度和宽度维度进行编码,保留位置信息,从而更好地捕捉图像中的空间关系。
PermuteMLP 的核心是分段排列操作,其能够有效地将空间信息嵌入到特征表示中,并保留位置信息。Permute-MLP 由三个独立的分支组成,分别负责编码高度、宽度和通道维度上的信息。每个分支包含一个全连接层,将输入特征映射到隐藏空间。
对于一个输入X,Permute MLP 的实现过程:
- 分段: 将输入特征沿通道维度分割成 S 个片段。
- 高度-通道置换: 对每个片段进行高度-通道置换操作。
- 通道维度拼接: 将置换后的片段沿通道维度拼接。
- 全连接层: 将拼接后的特征输入到一个全连接层,进行特征融合。
- 逆置换: 对特征进行逆置换操作,恢复到原始维度。
- 重复: 对宽度维度进行类似的操作,得到宽度信息编码结果。
- 通道信息编码: 对输入特征进行通道信息编码,得到通道信息编码结果。
- 特征融合: 将高度、宽度和通道信息编码结果拼接在一起,并输入到一个全连接层进行特征融合,得到 Permute-MLP 的最终输出。
Permute-MLP layer 结构图:
2、代码实现
import torch
import torch.nn as nn
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class WeightedPermuteMLP(nn.Module):
def __init__(self, dim, segment_dim=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.segment_dim = segment_dim
self.mlp_c = nn.Linear(dim, dim, bias=qkv_bias)
self.mlp_h = nn.Linear(dim, dim, bias=qkv_bias)
self.mlp_w = nn.Linear(dim, dim, bias=qkv_bias)
self.reweight = Mlp(dim, dim // 4, dim * 3)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, H, W, C = x.shape
S = C // self.segment_dim
h = x.reshape(B, H, W, self.segment_dim, S).permute(0, 3, 2, 1, 4).reshape(B, self.segment_dim, W, H * S)
h = self.mlp_h(h).reshape(B, self.segment_dim, W, H, S).permute(0, 3, 2, 1, 4).reshape(B, H, W, C)
w = x.reshape(B, H, W, self.segment_dim, S).permute(0, 1, 3, 2, 4).reshape(B, H, self.segment_dim, W * S)
w = self.mlp_w(w).reshape(B, H, self.segment_dim, W, S).permute(0, 1, 3, 2, 4).reshape(B, H, W, C)
c = self.mlp_c(x)
a = (h + w + c).permute(0, 3, 1, 2).flatten(2).mean(2)
a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2)
x = h * a[0] + w * a[1] + c * a[2]
x = self.proj(x)
x = self.proj_drop(x)
return x
if __name__ == '__main__':
x = torch.randn(4, 512, 8, 8).cuda()
x = x.permute(0,3,2,1)
model = WeightedPermuteMLP(512).cuda()
out = model(x)
out = out.permute(0,3,2,1)
print(out.shape)