当前位置: 首页 > article >正文

YOLOv6-4.0部分代码阅读笔记-dbb_transforms.py

dbb_transforms.py

yolov6\layers\dbb_transforms.py

目录

dbb_transforms.py

1.所需的库和模块

2.def transI_fusebn(kernel, bn): 

3.def transII_addbranch(kernels, biases): 

4.def transIII_1x1_kxk(k1, b1, k2, b2, groups): 

5.def transIV_depthconcat(kernels, biases): 

6.def transV_avg(channels, kernel_size, groups): 

7.def transVI_multiscale(kernel, target_kernel_size): 


1.所需的库和模块

import torch
import numpy as np
import torch.nn.functional as F

2.def transI_fusebn(kernel, bn): 

# 它用于将一个卷积核与一个批量归一化(Batch Normalization)层融合。这种融合通常用于在模型部署时简化网络结构,减少计算量和内存消耗。
# 1.kernel :卷积核的权重张量。
# 2.bn :批量归一化层,预期是一个 nn.BatchNorm2d 的实例。
def transI_fusebn(kernel, bn):
    # 获取批量归一化层的缩放因子(也称为权重)。
    gamma = bn.weight
    # 计算批量归一化层的标准差,其中 running_var 是运行方差, eps 是用于数值稳定性的小常数。
    std = (bn.running_var + bn.eps).sqrt()
    # kernel * ((gamma / std).reshape(-1, 1, 1, 1)) :将卷积核与缩放因子除以标准差的值相乘,并重塑为 (-1, 1, 1, 1) 的形状,以匹配卷积核的维度。这样,每个卷积核通道都被乘以相应的缩放因子。
    # bn.bias - bn.running_mean * gamma / std :计算去除均值偏移后的偏置。这里, running_mean 是批量归一化层的运行均值, gamma 是缩放因子, std 是标准差。
    # 返回融合后的卷积核和偏置。
    return kernel * ((gamma / std).reshape(-1, 1, 1, 1)), bn.bias - bn.running_mean * gamma / std

3.def transII_addbranch(kernels, biases): 

# 它用于将多个卷积核和偏置相加,通常用于在模型部署时合并具有相同功能的多个分支。
# 1.kernels :一个包含多个卷积核权重张量的列表或元组。
# 2.biases :一个包含多个卷积核偏置张量的列表或元组。
def transII_addbranch(kernels, biases):
    # 将 kernels 列表中的所有卷积核权重张量相加,得到总的卷积核权重张量。
    # 将 biases 列表中的所有偏置张量相加,得到总的偏置张量。
    # 返回合并后的卷积核和偏置。
    return sum(kernels), sum(biases)
# 这个方法可以在需要将多个卷积分支合并为一个单一卷积层的场景中使用,尤其是在模型部署时。

4.def transIII_1x1_kxk(k1, b1, k2, b2, groups): 

# 它用于将两个卷积核(一个1x1和一个kxk)以及它们的偏置进行融合。这种融合通常用于在模型部署时简化网络结构,特别是在处理分组卷积时。
# 1.k1 :第一个卷积核(1x1卷积核)的权重张量。
# 2.b1 :第一个卷积核的偏置张量。
# 3.k2 :第二个卷积核(kxk卷积核)的权重张量。
# 4.b2 :第二个卷积核的偏置张量。
# 5.groups :分组卷积的组数。
def transIII_1x1_kxk(k1, b1, k2, b2, groups):
    # 当 groups 参数为 1 时,表示没有使用分组卷积,或者所有卷积核都在同一组中。
    if groups == 1:
        # 使用 PyTorch 的 F.conv2d 函数将 kxk 卷积核 k2 与转置后的 1x1 卷积核 k1 进行卷积操作。
        # k1.permute(1, 0, 2, 3) 将 k1 的维度从 (out_channels, in_channels, height, width) 转换为 (in_channels, out_channels, height, width) ,以匹配 F.conv2d 函数的要求。
        k = F.conv2d(k2, k1.permute(1, 0, 2, 3))
        # 将 kxk 卷积核 k2 与 1x1 卷积核的偏置 b1 相乘。
        # b1.reshape(1, -1, 1, 1) 将 b1 从一维张量转换为四维张量,以便与 k2 进行元素乘法。然后,通过 .sum((1, 2, 3)) 在空间维度(高度和宽度)上对结果进行求和,得到新的偏置 b_hat 。
        b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3))
    # 多组情况,即当卷积核被分成多个组进行处理时。
    else:
        # 初始化两个列表,用于存储每个分组计算得到的卷积核和偏置。
        k_slices = []
        b_slices = []
        # 将1x1卷积核 k1 转置,以适应 F.conv2d 的输入要求。
        k1_T = k1.permute(1, 0, 2, 3)
        # 计算1x1卷积核每组的通道数。
        k1_group_width = k1.size(0) // groups
        # 计算kxk卷积核每组的通道数。
        k2_group_width = k2.size(0) // groups
        # 遍历每个分组,对每个分组的卷积核和偏置进行操作。
        for g in range(groups):
            # 提取1x1卷积核的当前分组切片。
            k1_T_slice = k1_T[:, g*k1_group_width:(g+1)*k1_group_width, :, :]
            # 提取kxk卷积核的当前分组切片。
            k2_slice = k2[g*k2_group_width:(g+1)*k2_group_width, :, :, :]
            # 对当前分组的kxk卷积核和1x1卷积核进行卷积操作,并将结果添加到 k_slices 列表中。
            k_slices.append(F.conv2d(k2_slice, k1_T_slice))
            # 计算当前分组的偏置,并将结果添加到 b_slices 列表中。
            b_slices.append((k2_slice * b1[g * k1_group_width:(g+1) * k1_group_width].reshape(1, -1, 1, 1)).sum((1, 2, 3)))
        # def transIV_depthconcat(kernels, biases): -> 它用于将多个卷积核和偏置在深度维度上进行拼接(concatenation)。返回拼接后的卷积核和偏置。 -> return torch.cat(kernels, dim=0), torch.cat(biases)
        # 使用 transIV_depthconcat 函数将所有分组的卷积核和偏置在深度维度上进行拼接,得到最终的卷积核 k 和偏置 b_hat 。
        k, b_hat = transIV_depthconcat(k_slices, b_slices)
    # 返回 融合后的卷积核 k 和 偏置 b_hat + b2 。
    return k, b_hat + b2

5.def transIV_depthconcat(kernels, biases): 

# 它用于将多个卷积核和偏置在深度维度上进行拼接(concatenation)。这种操作通常用于在模型中合并来自不同分支的特征,特别是在深度学习模型中处理多通道数据时。
# 1.kernels :一个包含多个卷积核权重张量的列表或元组。.
# 2.biases :一个包含多个卷积核偏置张量的列表或元组。
def transIV_depthconcat(kernels, biases):
    # torch.cat(kernels, dim=0) :在深度维度(即通道维度)上拼接所有卷积核权重张量。这里的 dim=0 表示在第一个维度上进行拼接,通常对应于通道数。
    # torch.cat(biases) :将所有偏置张量拼接在一起。由于偏置通常是一维张量,因此不需要指定维度。
    # 返回拼接后的卷积核和偏置。
    return torch.cat(kernels, dim=0), torch.cat(biases)

6.def transV_avg(channels, kernel_size, groups): 

# 它用于创建一个平均池化操作的等效卷积核。
# 1.channels :输入和输出的通道数。
# 2.kernel_size :卷积核的大小,也就是平均池化的窗口大小。
# 3.groups :分组数,用于确定卷积核的形状。
def transV_avg(channels, kernel_size, groups):
    # 计算每个组的输入通道数。
    input_dim = channels // groups
    # 创建一个形状为 (channels, input_dim, kernel_size, kernel_size) 的零张量,用于存储卷积核的值。
    k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
    # np.tile(A,reps)
    # 函数可对输入的数组,元组或列表进行重复构造,其输出是数组。
    # 该函数有两个参数:
    # A :输入的数组,元组或列表。
    # reps :重复构造的形状,可为数组,元组或列表。

    # 为卷积核的每个通道设置值。这里使用 np.arange(channels) 生成通道索引, np.tile(np.arange(input_dim), groups) 生成每个组内的通道索引,然后为每个位置设置值为 1.0 / kernel_size ** 2 。
    # 这样,卷积核的每个元素都是平均池化窗口大小的倒数,实现了平均池化的效果。
    k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
    # 返回创建的平均池化等效卷积核 k 。
    return k
# 这个方法可以在需要将平均池化层转换为卷积层的场景中使用,尤其是在模型部署时。

7.def transVI_multiscale(kernel, target_kernel_size): 

# 它用于将一个卷积核调整到目标尺寸,通常是通过填充(padding)来实现。这个函数可以帮助在不同尺寸的卷积核之间进行转换,特别是在需要将卷积核适配到特定大小时。
# 尚未使用非方形内核( kernel.size(2) != kernel.size(3) )或偶数大小内核进行测试
#   This has not been tested with non-square kernels (kernel.size(2) != kernel.size(3)) nor even-size kernels
# 1.kernel :原始的卷积核权重张量。
# 2.target_kernel_size :目标卷积核的大小,通常是一个整数,表示期望的卷积核尺寸。
def transVI_multiscale(kernel, target_kernel_size):
    # 计算高度方向需要填充的像素数。 kernel.size(2) 是卷积核的高度。
    H_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2
    # 计算宽度方向需要填充的像素数。 kernel.size(3) 是卷积核的宽度。
    W_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2
    # torch.nn.functional.pad(input, pad, mode='constant', value=0)
    # F.pad 是 PyTorch 中的一个函数,用于对张量进行填充(padding)。这个函数可以对多维张量进行操作,按照指定的模式在张量的边界或内部添加值(通常是0)。
    # 参数 :
    # input :要进行填充的输入张量。
    # pad :一个元组,指定了每个边界的填充量。对于二维张量(例如图像), pad 元组的形式为 (left, right, top, bottom) 。对于三维张量,形式为 (left, right, top, bottom, front, back) 。
    # mode :填充模式,可选值为 'constant' 、 'reflect' 、 'replicate' 或 'circular' 。默认为 'constant' 。
    # 'constant' :填充常数值(由 value 参数指定)。
    # 'reflect' :反射填充,即张量的边界值会反射到另一边。
    # 'replicate' :复制填充,即张量的边界值会被复制。
    # 'circular' :循环填充,即张量的最后一个值会填充到开始位置,反之亦然。
    # value :用于 'constant' 模式下的填充值,默认为0。
    # 返回值 :
    # 返回填充后的张量。

    # 使用 PyTorch 的 F.pad 函数对卷积核进行填充。填充的参数 [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad] 分别对应于左、右、上、下的填充量。
    return F.pad(kernel, [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad])


# 在PyTorch框架中, nn.ReLU 和 F.relu 都用于实现 ReLU (Rectified Linear Unit,修正线性单元)激活函数,但它们在用法和场景上存在明显的区别。
# 以下是两者的主要区别:
# 1.本质差异
# nn.ReLU :这是一个类,继承自 torch.nn.Module 。它作为一个网络层(layer)存在,具有内部状态,适合在构建神经网络模型时使用。在定义模型时, nn.ReLU 需要被实例化并添加到模型中,以作为模型的一部分。
# F.relu :这是一个函数,属于 torch.nn.functional 模块。它作为一个函数接口存在,无内部状态,适用于在模型的前向传播(forward pass)中直接调用。
# 2.使用方式
# nn.ReLU :使用前需要先实例化,然后作为模型的一部分进行调用。例如,在模型的 forward 方法中,可以通过 self.relu(input) 的方式调用。
# F.relu :直接通过 torch.nn.functional.relu(input) 或简写为 F.relu(input) 的方式在 forward 方法中调用,无需实例化。
# 3.场景差异
# nn.ReLU :更适合在定义模型结构时使用,因为它可以被视为模型的一部分,具有更明确的层次结构。
# F.relu :更适用于快速原型开发或在前向传播中直接调用,因为它不需要实例化,使用起来更为简洁。
# 4.打印网络结构
# 当使用 print(model) 来打印模型结构时, nn.ReLU 会被视为模型的一层并显示出来,而 F.relu 则不会出现在模型结构的打印输出中,因为它只是作为函数调用的一部分。
# 5.性能与效果
# 两者在功能和效果上是等价的,都是实现 ReLU 激活函数。但是,由于 nn.ReLU 作为模型的一部分,可能在某些优化或自动微分方面与模型的其他部分有更好的集成。

# 总结
# 选择 nn.ReLU 还是 F.relu 主要取决于使用场景和个人偏好。在定义模型结构时,推荐使用 nn.ReLU 以保持模型的清晰性和一致性。而在需要快速实现或调试时, F.relu 的简洁性可能更加吸引人。
# 两者在功能上是等价的,因此选择哪一种方式并不会对模型的性能产生显著影响。

 


http://www.kler.cn/a/375855.html

相关文章:

  • RabbitMQ最全教程-Part1(基础使用)
  • TVM前端研究--Pass
  • 【数据结构】宜宾大学-计院-实验六
  • 在工作中常用到的 Linux 命令总结
  • 【C语言】宏封装的实用总结
  • 商务英语学习柯桥学外语到泓畅-老外说“go easy on me”是什么意思?
  • ESP-IDF 配置 SimpleFOC 项目
  • 企业如何通过架构蓝图实现数字化转型
  • Unity 游戏性能优化实践:内存管理与帧率提升技巧
  • Android和iOS有什么区别?
  • redis详细教程(4.GEO,bitfield,Stream)
  • 自动驾驶上市潮中,会诞生下一个“英伟达”吗?
  • 基于深度学习的机器人智能控制算法 笔记
  • 【Linux】编辑器vim 与 编译器gcc/g++
  • OpenCV Python 版使用教程(二)摄像头调用
  • 二叉树选择题
  • 11.01学习
  • Linux云计算 |【第五阶段】CLOUD-DAY7
  • Shell 编程-Shell三剑客 Grep 学习
  • K8s pod 调度策略
  • 数据库相关概念
  • leaflet 地图基础应用篇
  • ssh和ssl的区别在哪些方面?
  • Facebook群控策略详解
  • 基于微信小程序的公务员考试信息查询系统+LW示例参考
  • 农作物病害图像分割系统:深度学习检测