当前位置: 首页 > article >正文

(即插即用模块-特征处理部分) 二十、(TPAMI 2022) Permute-MLP 置换MLP模块

在这里插入图片描述

文章目录

  • 1、Permute-MLP layer
  • 2、代码实现

paper:Vision Permutator: A Permutable MLP-Like Architecture for Visual Recognition

Code:https://github.com/Andrew-Qibin/VisionPermutator


1、Permute-MLP layer

传统的 MLP-like 模型(如 Mixer 和 ResMLP)在编码图像特征时,首先会将空间维度展平,但这会导致丢失重要的位置信息,从而限制了模型的表达能力。这篇论文提出一中 置换MLP模块(Permute-MLP layer),Permute-MLP 旨在解决这个问题,通过分别对高度和宽度维度进行编码,保留位置信息,从而更好地捕捉图像中的空间关系。

PermuteMLP 的核心是分段排列操作,其能够有效地将空间信息嵌入到特征表示中,并保留位置信息。Permute-MLP 由三个独立的分支组成,分别负责编码高度、宽度和通道维度上的信息。每个分支包含一个全连接层,将输入特征映射到隐藏空间。

对于一个输入X,Permute MLP 的实现过程:

  1. 分段: 将输入特征沿通道维度分割成 S 个片段。
  2. 高度-通道置换: 对每个片段进行高度-通道置换操作。
  3. 通道维度拼接: 将置换后的片段沿通道维度拼接。
  4. 全连接层: 将拼接后的特征输入到一个全连接层,进行特征融合。
  5. 逆置换: 对特征进行逆置换操作,恢复到原始维度。
  6. 重复: 对宽度维度进行类似的操作,得到宽度信息编码结果。
  7. 通道信息编码: 对输入特征进行通道信息编码,得到通道信息编码结果。
  8. 特征融合: 将高度、宽度和通道信息编码结果拼接在一起,并输入到一个全连接层进行特征融合,得到 Permute-MLP 的最终输出。

Permute-MLP layer 结构图:
在这里插入图片描述

2、代码实现

import torch
import torch.nn as nn


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class WeightedPermuteMLP(nn.Module):
    def __init__(self, dim, segment_dim=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.segment_dim = segment_dim

        self.mlp_c = nn.Linear(dim, dim, bias=qkv_bias)
        self.mlp_h = nn.Linear(dim, dim, bias=qkv_bias)
        self.mlp_w = nn.Linear(dim, dim, bias=qkv_bias)

        self.reweight = Mlp(dim, dim // 4, dim * 3)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, H, W, C = x.shape

        S = C // self.segment_dim
        h = x.reshape(B, H, W, self.segment_dim, S).permute(0, 3, 2, 1, 4).reshape(B, self.segment_dim, W, H * S)
        h = self.mlp_h(h).reshape(B, self.segment_dim, W, H, S).permute(0, 3, 2, 1, 4).reshape(B, H, W, C)

        w = x.reshape(B, H, W, self.segment_dim, S).permute(0, 1, 3, 2, 4).reshape(B, H, self.segment_dim, W * S)
        w = self.mlp_w(w).reshape(B, H, self.segment_dim, W, S).permute(0, 1, 3, 2, 4).reshape(B, H, W, C)

        c = self.mlp_c(x)

        a = (h + w + c).permute(0, 3, 1, 2).flatten(2).mean(2)
        a = self.reweight(a).reshape(B, C, 3).permute(2, 0, 1).softmax(dim=0).unsqueeze(2).unsqueeze(2)

        x = h * a[0] + w * a[1] + c * a[2]

        x = self.proj(x)
        x = self.proj_drop(x)

        return x


if __name__ == '__main__':
    x = torch.randn(4, 512, 8, 8).cuda()
    x = x.permute(0,3,2,1)
    model = WeightedPermuteMLP(512).cuda()
    out = model(x)
    out = out.permute(0,3,2,1)
    print(out.shape)


http://www.kler.cn/a/528948.html

相关文章:

  • Redis|前言
  • 简要介绍C++中的 max 和 min 函数以及返回值
  • 【gRPC-gateway】初探grpc网关,插件安装,默认实现,go案例
  • 缩位求和——蓝桥杯
  • C++并行化编程
  • 【linux】Linux 常见目录特性、权限和功能
  • LeetCode题练习与总结:种花问题--605
  • C基础寒假练习(6)
  • 【数据采集】案例01:基于Scrapy采集豆瓣电影Top250的详细数据
  • 解决istoreos无法拉取青龙镜像
  • Java小白入门教程:HashSet
  • ZZNUOJ(C/C++)基础练习1031——1040(详解版)
  • 【JAVA】循环语句
  • 工作中使用到的单词(软件开发)_第一、二、三版汇总
  • TensorFlow 示例摄氏度到华氏度的转换(一)
  • 作者新游戏1.0
  • Linux中 端口被占用如何解决
  • rust跨平台调用动态库
  • 设计模式Python版 组合模式
  • DRM系列六:Drm之KMS
  • 线程的状态转换和调度
  • 深入理解Spring框架:从基础到实践
  • python学opencv|读取图像(五十三)原理探索:使用cv.matchTemplate()函数实现最佳图像匹配
  • 996引擎 -地图-添加安全区
  • 群速度与相速度辨析
  • NIST的 临床质量指标的简介