当前位置: 首页 > article >正文

Scale Decoupled Distillation 论文中SPP发生了什么

SPP 

import torch
import torch.nn as nn

# 定义SPP类
class SPP(nn.Module):
    def __init__(self, M=None):
        super(SPP, self).__init__()
        self.pooling_4x4 = nn.AdaptiveAvgPool2d((4, 4))
        self.pooling_2x2 = nn.AdaptiveAvgPool2d((2, 2))
        self.pooling_1x1 = nn.AdaptiveAvgPool2d((1, 1))

        self.M = M
        print(f"初始化 M:{self.M}")

    def forward(self, x):
        print(f"输入 x 的形状: {x.shape}")
        
        # 进行4x4池化
        x_4x4 = self.pooling_4x4(x)
        print(f"4x4池化后 x_4x4 的形状: {x_4x4.shape}")
        
        # 进行2x2池化
        x_2x2 = self.pooling_2x2(x_4x4)
        print(f"2x2池化后 x_2x2 的形状: {x_2x2.shape}")
        
        # 进行1x1池化
        x_1x1 = self.pooling_1x1(x_4x4)
        print(f"1x1池化后 x_1x1 的形状: {x_1x1.shape}")

        # 展平特征图
        x_4x4_flatten = torch.flatten(x_4x4, start_dim=2, end_dim=3)
        print(f"4x4池化展平后 x_4x4_flatten 的形状: {x_4x4_flatten.shape}")
        
        x_2x2_flatten = torch.flatten(x_2x2, start_dim=2, end_dim=3)
        print(f"2x2池化展平后 x_2x2_flatten 的形状: {x_2x2_flatten.shape}")
        
        x_1x1_flatten = torch.flatten(x_1x1, start_dim=2, end_dim=3)
        print(f"1x1池化展平后 x_1x1_flatten 的形状: {x_1x1_flatten.shape}")

        # 根据 M 值拼接特征
        if self.M == '[1,2,4]':
            x_feature = torch.cat((x_1x1_flatten, x_2x2_flatten, x_4x4_flatten), dim=2)
            print(f"特征拼接后 (M=[1,2,4]) x_feature 的形状: {x_feature.shape}")
        elif self.M == '[1,2]':
            x_feature = torch.cat((x_1x1_flatten, x_2x2_flatten), dim=2)
            print(f"特征拼接后 (M=[1,2]) x_feature 的形状: {x_feature.shape}")
        else:
            raise NotImplementedError('ERROR M')

        # 计算特征强度
        x_strength = x_feature.permute((2, 0, 1))
        print(f"特征强度计算前 x_strength 形状: {x_strength.shape}")
        
        x_strength = torch.mean(x_strength, dim=2)
        print(f"特征强度计算后 x_strength 的形状: {x_strength.shape}")

        return x_feature, x_strength


# 创建一个SPP模块实例
M = '[1,2,4]'  # 设置M为'[1,2,4]',拼接所有三个尺度
spp = SPP(M=M)

# 输入一个示例张量,形状为 [batch_size, channels, height, width]
input_tensor = torch.randn(2, 3, 16, 16)  # 假设 batch_size=2, channels=3, height=16, width=16

# 前向传播并打印每一步结果
x_feature, x_strength = spp(input_tensor)

运行结果 

输入 x 的形状: torch.Size([2, 3, 16, 16])
4x4池化后 x_4x4 的形状: torch.Size([2, 3, 4, 4])
2x2池化后 x_2x2 的形状: torch.Size([2, 3, 2, 2])
1x1池化后 x_1x1 的形状: torch.Size([2, 3, 1, 1])
4x4池化展平后 x_4x4_flatten 的形状: torch.Size([2, 3, 16])
2x2池化展平后 x_2x2_flatten 的形状: torch.Size([2, 3, 4])
1x1池化展平后 x_1x1_flatten 的形状: torch.Size([2, 3, 1])
特征拼接后 (M=[1,2,4]) x_feature 的形状: torch.Size([2, 3, 21])
特征强度计算前 x_strength 形状: torch.Size([21, 2, 3])
特征强度计算后 x_strength 的形状: torch.Size([21, 2])

 x_strength = torch.mean(x_strength, dim=2)

x_strength = torch.mean(x_strength, dim=2) 这行代码的作用是对张量 x_strength 的第三个维度(即 dim=2)进行平均操作。具体来说,它是在指定的维度上计算每个元素的平均值,从而减少该维度的大小。

详细解释:

  1. x_strength 的形状

    • 在这行代码之前,x_strength 的形状为 [feature_num, batch_size, channels],也就是通过 permute 操作后将原来的 [batch_size, channels, feature_num] 重新排列为 [feature_num, batch_size, channels]
    • 这意味着,x_strength 的第0个维度表示特征数量(来自多尺度池化的特征块),第1个维度表示批量大小,第2个维度表示通道数。
  2. torch.mean 操作

    • torch.mean(x_strength, dim=2) 表示在第2个维度(即 channels)上计算均值,这意味着对每个特征块在不同通道上的值取平均。
    • 计算后的结果不再包含通道维度,因此它会将形状从 [feature_num, batch_size, channels] 变为 [feature_num, batch_size]
    • 换句话说,对于每个特征块(feature_num),在 batch_size 上的每个样本中,通道上的值被压缩为一个平均值。

举例:

假设 x_strength 的形状为 [3, 2, 4],其中:

  • feature_num = 3(表示有3个特征块),
  • batch_size = 2(表示有2个样本),
  • channels = 4(表示每个特征块有4个通道)。
x_strength = torch.tensor([
    [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],  # 第一个特征块
    [[2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]],  # 第二个特征块
    [[3.0, 4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0]]  # 第三个特征块
])

# 对dim=2(通道)取平均
x_strength_mean = torch.mean(x_strength, dim=2)

在这个例子中,x_strength_mean 将对每个特征块在4个通道上的值取平均:

x_strength_mean = [
    [2.5, 6.5],  # 第一个特征块的通道均值
    [3.5, 7.5],  # 第二个特征块的通道均值
    [4.5, 8.5]   # 第三个特征块的通道均值
]

最终结果:

  • 经过 torch.mean(x_strength, dim=2) 后,x_strength 的形状从 [feature_num, batch_size, channels] 变为 [feature_num, batch_size],即每个特征块对每个样本生成了一个强度值,这个值是该特征块在所有通道上的平均强度。

可不可以说特征强度计算后的x_strength的3通道变成了1通道

是的,可以这样理解。

在计算特征强度的过程中,torch.mean(x_strength, dim=2) 对第3个维度(即通道维度)进行求平均操作,将原本具有多个通道(如3通道或其他通道数)的特征信息压缩为1个值。这个过程相当于将多个通道的信息整合成一个值,表示这些通道的平均特征强度。

因此,可以说计算后的 x_strength 从多通道(例如3通道)变成了1通道,因为它只剩下一个平均强度值来代表原来的多个通道信息。

解释:

  1. 计算前:原始的 x_strength 形状为 [feature_num, batch_size, channels],表示有 channels 个通道。
  2. 计算后:经过 torch.mean(x_strength, dim=2) 操作后,channels 维度被压缩,输出形状为 [feature_num, batch_size],即每个特征块只有一个通道,这个通道的值是原来多个通道的平均值。

因此,通道数从多个变为1个,这个1个通道的值是对所有通道的强度的平均表示。

SPP后

代码

import torch
import torch.nn as nn

# 定义SPP类
class SPP(nn.Module):
    def __init__(self, M=None):
        super(SPP, self).__init__()
        self.pooling_4x4 = nn.AdaptiveAvgPool2d((4, 4))
        self.pooling_2x2 = nn.AdaptiveAvgPool2d((2, 2))
        self.pooling_1x1 = nn.AdaptiveAvgPool2d((1, 1))

        self.M = M
        print(f"初始化 M:{self.M}")

    def forward(self, x):
        print(f"输入 x 的形状: {x.shape}")
        
        x_4x4 = self.pooling_4x4(x)
        x_2x2 = self.pooling_2x2(x_4x4)
        x_1x1 = self.pooling_1x1(x_4x4)

        x_4x4_flatten = torch.flatten(x_4x4, start_dim=2, end_dim=3)
        x_2x2_flatten = torch.flatten(x_2x2, start_dim=2, end_dim=3)
        x_1x1_flatten = torch.flatten(x_1x1, start_dim=2, end_dim=3)

        if self.M == '[1,2,4]':
            x_feature = torch.cat((x_1x1_flatten, x_2x2_flatten, x_4x4_flatten), dim=2)
        elif self.M == '[1,2]':
            x_feature = torch.cat((x_1x1_flatten, x_2x2_flatten), dim=2)
        else:
            raise NotImplementedError('ERROR M')

        x_strength = x_feature.permute((2, 0, 1))
        x_strength = torch.mean(x_strength, dim=2)

        return x_feature, x_strength


# 定义主网络,包含SPP和全连接层
class NetWithSPP(nn.Module):
    def __init__(self, M=None, num_classes=1000):
        super(NetWithSPP, self).__init__()
        self.spp = SPP(M)
        self.fc = nn.Linear(3, num_classes)  # 将输入维度修改为3,匹配通道数

    def forward(self, feat4):
        # 从 SPP 获取多尺度特征
        x_spp, x_strength = self.spp(feat4)
        print(f"x_spp 的形状: {x_spp.shape}")
        
        # 调整 x_spp 的维度
        x_spp = x_spp.permute((2, 0, 1))
        print(f"维度转换后 x_spp 的形状: {x_spp.shape}")
        
        # 获取维度大小
        m, b, c = x_spp.shape[0], x_spp.shape[1], x_spp.shape[2]
        print(f"m (feature_num): {m}, b (batch_size): {b}, c (channels): {c}")
        
        # 展平 x_spp 以便输入全连接层
        x_spp = torch.reshape(x_spp, (m * b, c))
        print(f"展平后的 x_spp 的形状: {x_spp.shape}")
        
        # 通过全连接层生成 patch_score
        patch_score = self.fc(x_spp)
        print(f"通过全连接层后的 patch_score 的形状: {patch_score.shape}")
        
        # 将 patch_score 重新调整形状
        patch_score = torch.reshape(patch_score, (m, b, 1000))
        print(f"重新调整形状后的 patch_score 的形状: {patch_score.shape}")
        
        # 最后 permute,恢复到 [batch_size, 1000, feature_num]
        patch_score = patch_score.permute((1, 2, 0))
        print(f"最后 permute 后 patch_score 的形状: {patch_score.shape}")

        return patch_score


# 创建网络实例
M = '[1,2,4]'  # 设置 M 为 '[1,2,4]'
net = NetWithSPP(M=M, num_classes=1000)

# 输入一个示例张量,形状为 [batch_size, channels, height, width]
input_tensor = torch.randn(2, 3, 16, 16)  # 假设 batch_size=2, channels=3, height=16, width=16

# 前向传播并打印每一步结果
patch_score = net(input_tensor)

运行结果 

输入 x 的形状: torch.Size([2, 3, 16, 16])
x_spp 的形状: torch.Size([2, 3, 21])
维度转换后 x_spp 的形状: torch.Size([21, 2, 3])
m (feature_num): 21, b (batch_size): 2, c (channels): 3
展平后的 x_spp 的形状: torch.Size([42, 3])
通过全连接层后的 patch_score 的形状: torch.Size([42, 1000])
重新调整形状后的 patch_score 的形状: torch.Size([21, 2, 1000])
最后 permute 后 patch_score 的形状: torch.Size([2, 1000, 21])


http://www.kler.cn/a/354691.html

相关文章:

  • 优化提示词改善答疑机器人回答质量
  • 【Leetcode 热题 100】20. 有效的括号
  • eNSP之家——路由器--入门实例详解
  • 接口测试-postman(使用postman测试接口笔记)
  • Improving Language Understanding by Generative Pre-Training GPT-1详细讲解
  • 【Nginx】设置https和http同时使用同一个端口访问
  • 一款AutoXJS现代化美观的日志模块AxpLogger
  • k8s-配置网络策略 NetworkPolicy
  • docker/docker-compose里面Command和entrypoint的关系
  • 股票Tick数据如何获取做量化交易
  • springboot如何接入阿里云短信
  • Vue 3 中的状态管理:深入探讨 Vuex 和 Pinia 的比较与最佳实践
  • 初识git · 有关模型
  • 【C语言】数据类型
  • 实用篇:如何让Win11右键默认显示更多呢
  • STM32 独立看门狗和窗口看门狗区别
  • Python进阶知识
  • 智能平台或系统中的归因、根因分析案例集锦
  • 使用python实现图书管理系统
  • Unity动画系统
  • 外包干了3周,技术退步太明显了。。。。。
  • 使用React Router实现前端的权限访问控制
  • 【Flutter】Dart:异步
  • docker容器里的时间不对,linux解决方案
  • 机器学习——向量化
  • 学习第三十六行