当前位置: 首页 > article >正文

笔记01----Transformer高效语义分割解码器模块DEPICT(即插即用)

学习笔记01----即插即用的解码器模块DEPICT

    • 前言
    • 源码下载
    • DEPICT实现
    • 实验

前言

文 章 标 题:《Rethinking Decoders for Transformer-based Semantic Segmentation: Compression is All You Need》
当前的 Transformer-based 方法(如 DETR 和其变体)取得了显著进展。但这些解码器(decoder)的设计更多是基于经验,缺乏理论解释,难以确定性能瓶颈并进行进一步改进。
该论文将语义分割任务建模为“从主空间到子空间的信息压缩”问题,强调从高维图像特征中提取类别相关的紧凑表示。
提出 DEPICT 解码器:

  • 基于 自注意力(MSSA) 和 交叉注意力(MSCA) 设计简单高效的解码器。
  • MSSA 构建主子空间,去除冗余,优化图像特征。
  • MSCA 动态提取类别相关特征,生成类别嵌入的低维表示。

源码下载

源代码地址:https://github.com/QishuaiWen/DEPICT

DEPICT实现

在这里插入图片描述
DEPICT流程:
1. 图像特征输入: 通过vit的主干网络对图像进行特征提取。这些特征中可能包含很多不重要的信息,比如背景噪声。我们的目标是提取出与分类相关的特征。
2.sa模式—自注意力模块(MSSA): 通过自注意力机制(Multi-head Subspace Self-Attention, MSSA),捕捉图像块之间的全局关系,去掉不相关信息,优化出更加紧凑的主要特征(主子空间)。它的具体操作是将 类别嵌入向量图像特征进行 拼接操作 输入 MSSA模块进行特征优化。
3.ca模式—交叉注意力模块(MSCA):类别嵌入(这是一个可学习的特征向量)作为查询,图像特征作为键和值,通过交叉注意力(Multi-head Subspace Cross-Attention, MSCA)提取每个类别的相关特征,生成类别嵌入的低维表示。它的具体操作是将 类别嵌入向量 作为 查询向量 通过MSCA进行特征优化。
类别嵌入向量是一个可学习的参数,是从 主空间中提取 出的,与类别强相关的特征子集,是图像特征的降维。
4.生成分割掩码:用点积操作比较图像特征和类别嵌入,生成每块图像属于每个类别的概率。

import torch
import torch.nn as nn
from einops import rearrange
from timm.models.layers import trunc_normal_
from dec_blocks import Transformer
from segm.model.utils import init_weights
class MaskTransformer(nn.Module):
    def __init__(
            self,
            n_cls,#类别数量
            patch_size,# 图像分块大小
            n_layers,  # Transformer 的层数
            n_heads,  # 多头注意力中的头数
            d_model,  # 特征的嵌入维度
            dropout,  # dropout 概率
            mode='ca',  # 模式选择:'ca' (交叉注意力) 或 'sa' (自注意力)
    ):
        super().__init__()

        self.patch_size = patch_size
        self.n_cls = n_cls
        self.mode = mode

        # cls_emb 是类别嵌入矩阵,初始化为随机值,形状为 (1, n_cls, d_model)。
        # 在 DEPICT 中,类别嵌入对应于主子空间的基向量 P
        self.cls_emb = nn.Parameter(torch.randn(1, n_cls, d_model))
        
        if mode == 'sa':
            # 提取图像主特征
            self.net = Transformer(d_model, n_layers, n_heads, 100, dropout)
            self.decoder_norm = nn.LayerNorm(d_model)
        elif mode == 'ca':
            # 用于优化图像特征的主特征
            self.snet = Transformer(d_model, n_layers, n_heads, 100, dropout)
            # 用于进一步提取类别嵌入
            self.cnet = Transformer(d_model, 3, n_heads, 50, dropout)
            self.snorm = nn.LayerNorm(d_model)
            self.cnorm = nn.LayerNorm(d_model)
        else:
            raise ValueError(f"Provided mode: {mode} is not valid.")
            
        self.mask_norm = nn.LayerNorm(n_cls)

        self.apply(init_weights)
        trunc_normal_(self.cls_emb, std=0.02)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"cls_emb"}

    def forward(self, x, im_size=None):
        H, W = im_size

        GS = H // self.patch_size

        # 扩张维度从(1, n_cls, d_model)到(batch_size,n_cls,d_model)
        cls_emb = self.cls_emb.expand(x.size(0), -1, -1)
        
        if self.mode == 'sa':
            # 拼接图像特征和类别嵌入
            # (batch_size,num_patches,d_model)
            x = torch.cat((x, cls_emb), 1)
            # 通过 Transformer 网络
            x = self.net(x)
            # 归一化处理
            x = self.decoder_norm(x)
            # patches优化后的图像特征。
            # cls_seg_feat:更新后的类别嵌入
            patches, cls_seg_feat = x[:, :-self.n_cls], x[:, -self.n_cls:]
        else:
            # 优化图像特征
            x = self.snet(x)
            # 归一化处理
            x = self.snorm(x)
            # 通过交叉注意力提取类别嵌入
            cls_emb = self.cnet(x, query=cls_emb)
            # 归一化
            cls_emb = self.cnorm(cls_emb)
            # patches优化后的图像特征。
            # cls_seg_feat:更新后的类别嵌入
            patches, cls_seg_feat = x, cls_emb

        #  向量标准化
        patches = patches / patches.norm(dim=-1, keepdim=True)
        cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True)

        # 点积操作:生成掩码
        # patches:形状为 (batch_size, num_patches, d_model)。
        # cls_seg_feat:形状为 (batch_size, n_cls, d_model)
        # 转为 (batch_size, d_model, n_cls),方便点积运算。
        # 输出 masks 的形状为 (batch_size, num_patches, n_cls),表示每个 patch 属于每个类别的得分。
        masks = patches @ cls_seg_feat.transpose(1, 2)
        # 标准化为了简化训练
        masks = self.mask_norm(masks)
        # 重排掩码形状
        masks = rearrange(masks, "b (h w) n -> b n h w", h=int(GS))

        return masks

调用测试代码

def main():
    # 配置参数
    n_cls = 10           # 类别数,例如分割任务有 10 个类别
    patch_size = 16       # 图像分块大小
    n_layers = 4          # Transformer 层数
    n_heads = 8           # 多头注意力头数
    d_model = 128         # 特征嵌入维度
    dropout = 0.1         # dropout 比例
    mode = 'ca'           # 模式选择:'ca' 或 'sa'
    # 初始化 MaskTransformer
    model = MaskTransformer(
        n_cls=n_cls,
        patch_size=patch_size,
        n_layers=n_layers,
        n_heads=n_heads,
        d_model=d_model,
        dropout=dropout,
        mode=mode
    )
    # 测试输入
    batch_size = 2        # 批次大小
    image_size = 128      # 图像尺寸(假设输入为 128x128)
    num_patches = (image_size // patch_size) ** 2  # 分块后有多少个 patch
    # 生成随机的图像特征输入 (batch_size, num_patches, d_model)
    x = torch.randn(batch_size, num_patches, d_model)
    # 设置 im_size
    im_size = (image_size, image_size)
    # 运行模型
    masks = model(x, im_size=im_size)
    # 输出形状
    print("Output masks shape:", masks.shape)

实验

ADE20KcityscapePascalContext数据集
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述


http://www.kler.cn/a/399127.html

相关文章:

  • Qt 每日面试题 -10
  • 微服务即时通讯系统的实现(客户端)----(3)
  • C++之内存管理
  • 活着就好20241118
  • 【windows笔记】08-Windows中的各种快捷方式、符号链接、目录联接、硬链接的区别和使用方法
  • VSCode插件
  • 【配置后的基本使用】CMake基础知识
  • opc da 服务器数据 转 IEC61850项目案例
  • 人工智能+辅助诊疗
  • 雨晨 Hotpatch 24H2 Windows 11 iotltsc2024 极简版 26100.2240
  • 十五届蓝桥杯赛题-c/c++ 大学b组
  • R语言机器学习与临床预测模型77--机器学习预测常用R语言包
  • 基于STM32的智能家居系统:MQTT、AT指令、TCP\HTTP、IIC技术
  • CentOS中的Firewalld:全面介绍与实战应用
  • 《C++设计模式:重塑游戏角色系统类结构的秘籍》
  • GCP Cloud Storage 的lock retention policy是什么
  • pytorch tensor在CPU和GPU之间转换,numpy之间的转换
  • C++初级入门(1)
  • Istio分布式链路监控搭建:Jaeger与Zipkin
  • 在VMware虚拟机环境下识别U盘
  • 25-深入理解 JavaScript 异步生成器的实现
  • 基于Java的旅游类小程序开发与优化
  • Qt桌面应用开发 第四天(对话框 界面布局)
  • 【项目开发】理解SSL延迟:为何HTTPS比HTTP慢?
  • MoneyPrinterTurbo - AI自动生成高清短视频
  • 学习大数据DAY62 指标计算