
ResNeSt (2020) Notes

Source:

[2004.08955] ResNeSt: Split-Attention Networks

Related work:

#CNN_Architectures #Multi-path_and_featuremap_Attention #Neural_Architecture_Search

Key ideas:

[Three figures from the original post; images not available.]

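Contributions:

  • Proposes a new Split-Attention block that applies feature-map attention across different feature-map groups.

  • Introduces a new radix hyperparameter that increases the number of feature-map splits and improves the model's representational capacity.

  • Provides an efficient radix-major implementation, so the Split-Attention block can be accelerated with standard CNN operators.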
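The computation behind the Split-Attention block can be sketched as follows. This is a reading of the paper's formulation, not a quotation: $K$ is the cardinality, $R$ the radix, $U_j$ the $j$-th feature-map split, and $G^c_i$ stands for the small two-layer mapping (`fc1` and `fc2` in the code below) that produces the attention logit for split $i$ and channel $c$. The exact index conventions are an assumption made for illustration.

```latex
% Fuse the R splits of cardinal group k, then pool globally
\hat{U}^k = \sum_{i=1}^{R} U_{R(k-1)+i}, \qquad
s^k_c = \frac{1}{H W} \sum_{h=1}^{H} \sum_{w=1}^{W} \hat{U}^k_c(h, w)

% Attention weight of split i for channel c (softmax over splits, sigmoid if R = 1)
a^k_i(c) =
\begin{cases}
\dfrac{\exp\big(G^c_i(s^k)\big)}{\sum_{j=1}^{R} \exp\big(G^c_j(s^k)\big)} & R > 1 \\[2ex]
\dfrac{1}{1 + \exp\big(-G^c_i(s^k)\big)} & R = 1
\end{cases}

% Output of cardinal group k: attention-weighted sum of its splits
V^k_c = \sum_{i=1}^{R} a^k_i(c)\, U_{R(k-1)+i,\,c}
```

In the listing below, `x.sum(dim=1)` plays the role of $\hat{U}^k$, the spatial mean gives $s^k$, `RadixSoftmax` computes $a^k_i(c)$, and the final weighted sum yields $V^k$.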

Code:

# ---------------------------------------
# Paper: ResNeSt: Split-Attention Networks (arXiv 2020)
# ---------------------------------------
import torch  
from torch import nn  
import torch.nn.functional as F  
  
  
def make_divisible(v, divisor=8, min_value=None, round_limit=.9):  
    min_value = min_value or divisor  
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)  
    # Make sure that round down does not go down by more than 10%.  
    if new_v < round_limit * v:  
        new_v += divisor  
    return new_v  
  
  
class RadixSoftmax(nn.Module):  
    def __init__(self, radix, cardinality):  
        super(RadixSoftmax, self).__init__()  
        self.radix = radix  
        self.cardinality = cardinality  
  
    def forward(self, x):  
        batch = x.size(0)  
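        if self.radix > 1:
            # (B, cardinality, radix, C) -> (B, radix, cardinality, C)
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
            # Normalize across the radix splits
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            # With a single split, the attention reduces to a sigmoid gate
            x = x.sigmoid()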
        return x  
  
  
class SplitAttn(nn.Module):  
    """Split-Attention (aka Splat)  
    """  
    def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,  
                 dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,  
                 act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs):  
        super(SplitAttn, self).__init__()  
        out_channels = out_channels or in_channels  
        self.radix = radix  
        self.drop_block = drop_block  
        mid_chs = out_channels * radix  
        if rd_channels is None:  
            attn_chs = make_divisible(  
                in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)  
        else:  
            attn_chs = rd_channels * radix  
  
        padding = kernel_size // 2 if padding is None else padding  
        self.conv = nn.Conv2d(  
            in_channels, mid_chs, kernel_size, stride, padding, dilation,  
            groups=groups * radix, bias=bias, **kwargs)  
        self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()  
        self.act0 = act_layer()  
        self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)  
        self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()  
        self.act1 = act_layer()  
        self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)  
        self.rsoftmax = RadixSoftmax(radix, groups)  
  
    def forward(self, x):  
        x = self.conv(x)  
        x = self.bn0(x)  
        if self.drop_block is not None:  
            x = self.drop_block(x)  
        x = self.act0(x)  
  
        B, RC, H, W = x.shape
        if self.radix > 1:
            # Split the channels into `radix` parts and fuse them by summation
            x = x.reshape((B, self.radix, RC // self.radix, H, W))
            x_gap = x.sum(dim=1)
        else:
            x_gap = x
        # Global average pooling over the spatial dimensions
        x_gap = x_gap.mean((2, 3), keepdim=True)
        x_gap = self.fc1(x_gap)
        x_gap = self.bn1(x_gap)
        x_gap = self.act1(x_gap)
        x_attn = self.fc2(x_gap)

        # Per-channel attention weights, normalized across the radix splits
        x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
        if self.radix > 1:
            # Attention-weighted sum of the splits
            out = (x * x_attn.reshape((B, self.radix,
                                       RC // self.radix, 1, 1))).sum(dim=1)
        else:
            out = x * x_attn
        return out
  
  
# Input: N C H W; output: N C H W
if __name__ == '__main__':
    block = SplitAttn(64)
    x = torch.rand(3, 64, 32, 32)
    out = block(x)
    print(x.size(), out.size())
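To make the reshape in `RadixSoftmax` easier to follow, here is a dependency-free sketch of the same normalization for a single sample. It is an illustration only: the function name `radix_softmax` and the flat-list representation are assumptions made here, not part of the listing above. It assumes the input logits are laid out as (cardinality, radix, channels), as the `view` call implies, and it returns the weights in radix-major order (radix, cardinality, channels).

```python
import math


def radix_softmax(logits, radix, cardinality):
    """Normalize attention logits across the radix splits for one sample.

    `logits` is a flat list laid out as (cardinality, radix, channels).
    The result is a flat list laid out as (radix, cardinality, channels).
    """
    if radix == 1:
        # A single split: independent sigmoid gate per channel
        return [1.0 / (1.0 + math.exp(-v)) for v in logits]

    channels = len(logits) // (cardinality * radix)
    out = [0.0] * len(logits)
    for k in range(cardinality):
        for c in range(channels):
            # Logits of all splits that compete for channel c of group k
            vals = [logits[(k * radix + r) * channels + c] for r in range(radix)]
            peak = max(vals)  # subtract the max for numerical stability
            exps = [math.exp(v - peak) for v in vals]
            total = sum(exps)
            for r in range(radix):
                # Write the weight back in radix-major order
                out[(r * cardinality + k) * channels + c] = exps[r] / total
    return out


# Two splits, one cardinal group, two channels per split
weights = radix_softmax([0.0, 1.0, 0.0, -1.0], radix=2, cardinality=1)
print(weights)
```

For each channel, the weights of the two splits sum to 1. Channel 0 has equal logits and is split evenly, while channel 1 favors the first split.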

