当前位置: 首页 > article >正文

SPP/SPPF/Focal Module

一、在图像的分类任务重,卷积神经网络(CNN)一般含有5层:

  1. 输入层
  2. 卷积层
  3. 激活层
  4. 池化层
  5. 全连接层

·  全连接层通常要求输入为一维向量。在CNN中,卷积层和池化层的输出特征图会被展平(flatten)为一维向量,然后作为全连接层的输入。因此,全连接层对输入的尺寸有严格要求。

·  第一个全连接层的输入维度必须与前一层展平后的特征向量的长度一致,这就要求前面的卷积和池化层的输出特征图尺寸必须经过计算或预定义,以确保能够正确展平。

全连接层要求每个输入神经元与每个输出神经元完全连接。因此,全连接层的输入必须是一个固定长度的一维向量。这意味着输入的特征图的尺寸必须是固定的,以确保展平后的向量长度是确定的。如果输入特征图的尺寸发生变化,那么展平后的向量长度也会改变,这将导致全连接层无法正确处理这些输入数据。

二、在yolov3中引入了SPP,在yolov5及以后使用SPPF

1.什么是SPP?为什么引入SPP?

卷积神经网络(CNN)由卷积层和全连接层组成,其中卷积层对于输入数据的大小并没有要求,唯一对数据大小有要求的则是第一个全连接层

SPP的提出就是为了解决CNN输入图像大小必须固定的问题。

SPP的作用:

  1. 处理不同输入尺寸

 SPP 模块通过在不同尺度上进行池化操作,将特征图转换为固定长度的输出特征。具体来说,SPP 模块在特征图上应用多个不同大小的池化窗口(例如 1×1、2×2、4×4),将这些池化操作的结果拼接在一起,从而获得一个固定长度的特征向量。

  1. 保留空间信息

SPP 模块在不同尺度上进行池化操作时,可以保留输入图像的空间信息。不同尺度的池化操作捕捉了特征图中的不同层次的空间信息,从而保留了图像的局部和全局特征。

SPPF比SPP更快:

SPP 是使用了3个kernel size不一样大的pooling 并行运算。SPPF是将kernel size为5的 pooling串行运算,这样的运算的效果和SPP相同,但是运算速度加快。因为SPPF减少了重复的运算,每一次的pooling 运算都是在上一次运算的基础上进行的。

SPP结构图:

SPPF结构图:

使用FocalNet替代SPPF

一、SPP、SPPF

1.SPP模块:SPP的提出就是为了解决CNN输入图像大小必须固定的问题

       主要目的是通过多个不同尺度的池化操作,提取输入特征图中的多尺度特征。这种操作有助于模型更好地理解不同尺度的目标,并增强其对目标的检测能力。

FocalNet的亮点在于其专为关注重要信息而设计的Focal Module。它利用空间金字塔池化(Spatial Pyramid Pooling)和动态卷积(Dynamic Convolution)来捕捉不同尺度的目标,并有效地抑制背景干扰。这种设计在保持模型小规模的同时,提高了模型对复杂场景的理解能力。

FocalModulation模型通过以下步骤实现:

1. 焦点上下文化: 用深度卷积层堆叠来编码不同范围的视觉上下文。

焦点上下文化是指在不同的焦点级别(spatial scales)(也就是使用不同大小的卷积核)上聚合上下文信息。它通过一系列的卷积层对输入特征进行不同大小的卷积操作,从而捕捉不同范围的上下文信息。


2. 门控聚合: 通过门控机制,选择性地将上下文信息聚合到每个查询令牌的调制器中。

通过门控机制对不同焦点级别的上下文信息进行加权和合并的过程。其核心作用是根据每个焦点级别的信息重要性来调整其贡献,以实现更有效的上下文信息融合。


3. 逐元素仿射变换: 将聚合后的调制器通过仿射变换注入到每个查询令牌中。

意思就是将聚合后的上下文特征(ctx_all)注入到每个查询令牌(q)中,具体表现为与查询特征逐元素相乘。这一步骤将上下文信息直接作用于查询特征中,实现特征调制。

深度可分离卷积

=深度卷积+点对点卷积

演示分组,深度,深度可分离卷积|3D卷积神经网络_哔哩哔哩_bilibili

import torch
import torch.nn as nn

__all__ = ['FocalModulation']


class FocalModulation(nn.Module):
    def __init__(self, dim, focal_window=3, focal_level=2, focal_factor=2, bias=True, proj_drop=0.,
                 use_postln_in_modulation=False, normalize_modulator=False):
        super().__init__()

        self.dim = dim
        self.focal_window = focal_window
        self.focal_level = focal_level
        self.focal_factor = focal_factor
        self.use_postln_in_modulation = use_postln_in_modulation
        self.normalize_modulator = normalize_modulator

        #f_linear 是1x1卷积核,用于线性投影,将输入特征映射到更高的维度上 来生成查询、上下文、门控
        self.f_linear = nn.Conv2d(dim, 2 * dim + (self.focal_level + 1), kernel_size=1, bias=bias)
        #self.h  是1x1卷积核   用于后续的焦点调制
        self.h = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=bias)  #stride步长  bias布尔值 表示是否使用偏置项
        # self.act 定义了GELU激活函数
        self.act = nn.GELU()
        # self.proj 是1x1卷积,用于投影调制后的特征。
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)
        # self.proj_drop 是一个dropout层,用于正则化
        self.proj_drop = nn.Dropout(proj_drop)
        # self.focal_layers 是一个模块列表,存储了不同焦点级别的卷积层
        self.focal_layers = nn.ModuleList()

        self.kernel_sizes = []
        for k in range(self.focal_level):
            kernel_size = self.focal_factor * k + self.focal_window     #focal_factor 是放大因子 用来控制卷积核大小的增伤速率
            # self.focal_window 是基准卷积核的大小
            #随着k的增大,卷积核不断地增大
            self.focal_layers.append(  #用于存储所有的焦点卷积层
                nn.Sequential(   #nn.Sequential用于将多个神经网络层组合在一起  包含卷积层和激活函数
                    nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, #stride=1表示卷积步长为1。
                              # groups=dim启用了深度可分离卷积,即每个输入通道都有自己的卷积核,这大大减少了计算量
                              groups=dim, padding=kernel_size // 2, bias=False),
                    nn.GELU(), #一个GELU激活函数,常用于增加非线性,使网络能更好地学习复杂的模式
                )
            )
            self.kernel_sizes.append(kernel_size) #用于存储每个焦点级别的卷积核大小。
        if self.use_postln_in_modulation:
            self.ln = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x: input features with shape of (B, H, W, C)
        """
        C = x.shape[1]

        # pre linear projection
        x = self.f_linear(x).contiguous() #1x1的卷积操作,作用是将输入特征x进行通道上的线性投影
        q, ctx, gates = torch.split(x, (C, C, self.focal_level + 1), 1)  #将投影后的张量x沿通道维度(维度1)分成三部分:
        # q:查询(Query),大小为(B, C, H, W)
        # ctx:上下文(Context),大小为(B, C, H, W)
        # gates:门控(Gates),大小为(B, self.focal_level + 1, H, W),用于对不同焦点级别的上下文进行加权

        # context aggreation
        #上下文聚合
        ctx_all = 0.0
        #ctx_all = 0.0初始化一个变量,用于累积不同焦点级别的上下文。
        for l in range(self.focal_level):
            #逐级聚合上下文
            ctx = self.focal_layers[l](ctx)
            # 使用门控值gates对当前焦点级别的上下文ctx进行加权,然后累加到ctx_all
            ctx_all = ctx_all + ctx * gates[:, l:l + 1]
        #ctx.mean(2, keepdim=True).mean(3, keepdim=True):对上下文ctx进行全局平均池化,得到一个全局上下文特征。
        ctx_global = self.act(ctx.mean(2, keepdim=True).mean(3, keepdim=True)) #self.act()  对全局上下文特征应用GELU激活函数
        # 对全局上下文特征ctx_global应用最后一个门控值的加权,然后加到ctx_all
        ctx_all = ctx_all + ctx_global * gates[:, self.focal_level:]

        # normalize context    上下文归一化
        if self.normalize_modulator:
            ctx_all = ctx_all / (self.focal_level + 1)

        # focal modulation
        x_out = q * self.h(ctx_all)  # 卷积后的上下文特征与查询q逐元素相乘(对应于焦点调制机制),得到输出特征x_out。
        x_out = x_out.contiguous()  # 确保x_out在内存中的连续性

        # post linear porjection  后线性投影:对调制后的特征进行投影和Dropout,得到最终输出
        x_out = self.proj(x_out)
        x_out = self.proj_drop(x_out)
        return x_out


http://www.kler.cn/news/290113.html

相关文章:

  • 无人机 PX4 飞控 | ROS应用层开发:offboard 模式切换详细总结
  • Linux系统的字体管理
  • 12、Django Admin在列表视图页面上显示计算字段
  • Codeforces Round 968 (Div. 2)
  • 代码随想录算法训练营第36天|1049. 最后一块石头的重量、494. 目标和、474.一和零
  • 注册中心 Eureka Nacos
  • 重塑视频监控体验:WebRTC技术如何赋能智慧工厂视频高效管理场景
  • 负载均衡--资源申请说明(三)
  • Android随记
  • 坑——fastjson将字符串转到带枚举的java对象
  • ElasticSearch-数据建模
  • Go语言 Go程基础
  • rust feature 简介
  • shell 学习笔记:向脚本传递参数
  • Android Camera系列(三):GLSurfaceView+Camera
  • 分类预测|基于灰狼GWO优化BP神经网络的数据分类预测Matlab程序GWO-BP|基于鲸鱼WOA优化BP神经网络的数据分类预测Matlab程序WOA-BP
  • 智能提醒助理系列-基础设施准备
  • getLocation:fail, the permission value is offline verifying
  • Flutter--- 常规知识点
  • Redis从入门到入门(上)
  • springboot党员之家服务系统小程序论文源码调试讲解
  • Python知识点:如何使用Python实现强化学习机器人
  • 单片机与人工智能:融合创新的未来之路
  • 【LVGL- 组 lv_group_t】
  • 代码随想录算法训练营第五十六天 | 图论part06
  • 基于STM32的猫狗宠物喂养系统设计(微信小程序)(215)
  • k8s使用报错
  • JavaScript 作用链
  • [C++11#46](三) 详解lambda | 可变参数模板 | emplace_back | 默认的移动构造
  • RESTful基本要求