当前位置: 首页 > article >正文

【Attention】SKAttention

SKAttention选择核注意力

标题:SKAttention

期刊:IEEE2019

代码: https://github.com/implus/SKNet

简介:

  • 动机:增大感受野来提升性能、多尺度信息聚合方式
  • 解决的问题:自适应调整感受野大小
  • 创新性:提出选择性内核(SK)卷积softmax来进行自适应选择

模型结构

在这里插入图片描述

模型代码

import numpy as np
import torch
from torch import nn
from torch.nn import init
from collections import OrderedDict

# Selective Kernel Attention
class SKAttention(nn.Module):

    def __init__(self, channel=512, kernels=[1, 3, 5, 7], reduction=16, group=1, L=32):
        super().__init__()
        # 中间维度d的计算
        self.d = max(L, channel // reduction)
        # 多分支卷积层(使用不同尺寸的卷积核)
        self.convs = nn.ModuleList([])
        for k in kernels:
            self.convs.append(
                nn.Sequential(OrderedDict([
                    # 分组卷积(输入输出通道数相同,保持维度)
                    ('conv', nn.Conv2d(channel, channel, kernel_size=k, padding=k // 2, groups=group)),
                    # 批归一化(保持维度)  
                    ('bn', nn.BatchNorm2d(channel)),
                    # ReLU激活函数
                    ('relu', nn.ReLU())
                ]))
            )
        # # 通道压缩层(全连接层)
        self.fc = nn.Linear(channel, self.d)
        # 多分支注意力权重生成层
        self.fcs = nn.ModuleList([])
        for i in range(len(kernels)):
            self.fcs.append(nn.Linear(self.d, channel))
        # 注意力权重归一化(沿分支维度softmax)
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x):# 输入x形状: [B, C, H, W]
        bs, c, _, _ = x.size() # 获取输入的batch_size, 通道数, 高度, 宽度
        conv_outs = []
        ### Split阶段:多分支特征提取
        for conv in self.convs:
            conv_outs.append(conv(x)) # 每个分支输出: [B, C, H, W]
        feats = torch.stack(conv_outs, 0)  # 堆叠后形状: [K, B, C, H, W](K是kernel数量)

        ### Fuse阶段:特征融合
        U = sum(conv_outs) # 逐元素相加 → [B, C, H, W]

        ### Channel Reduction:通道压缩
        S = U.mean(-1).mean(-1)  # 空间全局平均池化 → [B, C,1,1]
        Z = self.fc(S)   # 全连接层降维 → [B, d](d=self.d)

        ### 计算注意力权重
        weights = []
        for fc in self.fcs: #  每个kernel对应一个全连接层
            weight = fc(Z) # 全连接层输出 → [B, C]
            weights.append(weight.view(bs, c, 1, 1))  # 调整形状 → [B, C, 1, 1]
        attention_weughts = torch.stack(weights, 0)   # 堆叠 → [K, B, C, 1, 1]
        attention_weughts = self.softmax(attention_weughts)  # 沿K维度softmax归一化

        ### fuse
        V = (attention_weughts * feats).sum(0) # 加权求和 → [B, C, H, W]
        return V


if __name__ == '__main__':
    input = torch.rand(1,64,256,256).cuda()
    model = SKAttention(channel=64, reduction=8).cuda()
    output = model (input)
    print('input_size:', input.size())
    print('output_size:', output.size())
    print("最大内存占用:", torch.cuda.max_memory_allocated() // 1024 // 1024, "MB")
    

http://www.kler.cn/a/596357.html

相关文章:

  • MySQL密码修改的全部方式一篇详解
  • vue学习九
  • 红宝书第十一讲:超易懂版「ES6类与继承」零基础教程:用现实例子+图解实现
  • 生物信息复习笔记(3)——GEO数据库
  • CPU架构和微架构
  • Redis 知识点梳理
  • 如何快速定位高 CPU 使用率的进程
  • git_version_control_proper_practice
  • Linux:基础IO---文件描述符
  • cmakelist中添加opencv
  • 【风信】邮件系统的介绍和使用。
  • Stable Diffusion lora训练(一)
  • 如何防御大模型中的 Prompt 攻击?
  • [蓝桥杯 2023 省 B] 子串简写
  • 深入理解 Spring 框架中的 IOC 容器
  • 六种开源智能体通信协议对比:MCP、ANP、Agora、agents.json、LMOS、AITP
  • 第十六届蓝桥杯模拟二
  • C++面试准备一(常考)
  • JVM垃圾回收笔记01
  • 冒排排序相关