当前位置: 首页 > article >正文

探索卷积层参数量与计算量

1 问题

  1. 了解VGG网络并利用PyTorch实现VGG

  2. 探索1x1卷积的作用

  3. 探索卷积层参数量、计算量的计算方法

2 方法

  1. 了解VGG网络并利用PyTorch实现VGG
    1、VGG是Oxford的Visual Geometry Group的组提出的,VGG的缩写也来自于这个组的名字。

    VGG网络探索了提升网络的深度对最终的图像识别准确率的重要性,同时在VGG中尝试使用小的卷积核来构建深层的卷积网络。


    2、利用pytorch实现VGG代码:

    class ResidualBlock(nn.Module):
       def __init__(self, in_planes=64):
           super().__init__()
           self.conv1 = nn.Conv2d(
               in_channels=in_planes, #! Residual模块的输入通道数
               out_channels=64,
               kernel_size=3,
               padding=1,
           )
           self.relu2 = nn.ReLU()
           self.conv3 = nn.Conv2d(
               in_channels=64,
               out_channels=in_planes,
               kernel_size=3,
               padding=1,
           )
           self.relu4 = nn.ReLU()
       def forward(self, x):
           identity = x  
           x = self.conv1(x)
           x = self.relu2(x)
           x = self.conv3(x)
           x = x + identity #! 重点
           x = self.relu4(x)
           return x
    class Network(nn.Module):
       def __init__(self):
           super().__init__()
           self.conv1 = nn.Conv2d(
               in_channels=1,
               out_channels=32,
               kernel_size=3,
               padding=1,
               stride=2,
           )
           self.res_block2 = ResidualBlock(in_planes=32)
           self.conv3 = nn.Conv2d(
               in_channels=32,
               out_channels=64,
               kernel_size=3,
               padding=1,
               stride=2,
           )
           self.res_block4 = ResidualBlock(in_planes=64)
           self.res_block5 = ResidualBlock(in_planes=64)
           self.flatten6 = nn.Flatten()
           self.fc7 = nn.Linear(
               in_features=64*7*7,
               out_features=512,
           )
           self.fc8 = nn.Linear(
               in_features=512,
               out_features=10,
           )
       def forward(self, x):
           x = self.conv1(x)
           x = self.res_block2(x)
           x = self.conv3(x)
           x = self.res_block4(x)
           x = self.res_block5(x)
           x = self.flatten6(x)
           x = self.fc7(x)
           x = self.fc8(x)
           return x
    if __name__ == '__main__':
       x = torch.rand(size=(1, 1, 28, 28))
       net = Network()
       out = net(x)
       print(out.shape)
  2. 探索1x1卷积的作用

1、降维/升维

1x1卷积核可以通过控制卷积核数量实现降维或升维,卷积后的特征图通道数与卷积核的个数是相同的。所以,如果想要升维或降维,只需要通过修改卷积核的个数即可。

2、 跨通道信息交互(通道的变换)

使用1x1卷积核,实现降维和升维的操作其实就是 channel 间信息的线性组合变化。

比如:在尺寸 3x3,64通道个数的卷积核后面添加一个尺寸1x1,28通道个数的卷积核,就变成了尺寸3x3,28尺寸的卷积核。原来的64个通道就可以理解为跨通道线性组合变成了28通道,这就是通道间的信息交互。

3、增加网络深度(增加非线性)

每使用 1x1卷积核,及增加一层卷积层,所以网络深度得以增加。而使用 1x1卷积核后,可以保持特征图大小与输入尺寸相同,卷积层卷积过程会包含一个激活函数,从而增加了非线性。

在输入尺寸不发生改变的情况下而增加了非线性,所以会增加整个网络的表达能力

三、探索卷积层参数量、计算量的计算方法

1、卷积层的参数:

filter,若一个3*3的卷积核,其参数为 9个,再加上通道数,其参数就是:通道数*9

(1)普通卷积层参数计算:

有一张通道为3大小为224∗224的图片,卷积核大小为:2∗2,输出通道数为32,则一个卷积核的参数2∗2∗3=12;则32个卷积核的参数为32∗(2∗2∗3)+32=416,其中最后的32为偏置项参数。

  1. 池化层参数计算:
    池化(Pooling):也称为欠采样或下采样;主要用于特征降维,压缩数据和参数的数量,减小过拟合,同时提高模型的容错性;pooling在不同的 depth 上是分开执行的,池化操作是分开应用到各个特征图的,因此池化不需要参数控制。

  2. 全连接层参数计算:

例如某一网络最后一层卷积层输出的大小为7∗7∗512,全连接层输出尺寸为 1∗1024;则需卷积层的尺寸为:7∗7∗1024,所以所需参数为512∗7∗7∗1024+1024=25691136。

  1. 卷积层的计算量:

  1. 普通卷积层的计算量
    卷积层计算量 = 卷积矩阵操作 + 融合操作 + 偏置项操作
    若有一张通道为3。大小为7∗7的图片,卷积核大小为:5∗5,stride为1,padding为0,输出通道数为64,其输出feature map的大小为 3∗3∗64,feature map中的每一个像素点,都是64个的5×5×3的filter共同作用于7∗7∗3的图片计算一次得到的。

  2. 全连接层的计算量

若需对7∗7∗512的数据进行全连接操作,其输出尺寸为1∗1024(即1*1024*1*1),其:filter:7∗7∗512 ,则所需计算量为:1024∗(1∗1)∗[(7∗7∗512)+512∗(7∗7−1)+(512−1)]+1024,其中1024为输出通道数,(1∗1)为输出feature map尺寸,7∗7∗512为某一通道feature map中某一像素的矩阵操作乘法的计算量,512∗(7∗7−1)为某一通道feature map中某一像素的矩阵操作加法的计算量,(512−1)为通道融合加法的计算量,最后的 1024 为偏置项操作计算量。

3 结语

针对卷积层参数量、计算量的计算方法, 其中全连接层参数计算参数量巨大,全连接层参数就可占整个网络参数80%左右,所以全连接层参数具有冗余性。

卷积层计算量的计算方法中,矩阵操作先乘法,再加法。在全连接层的计算量中,得知减少网络参数应主要针对全连接层,进行计算量优化时,重点应放在卷积层。


http://www.kler.cn/news/362734.html

相关文章:

  • 如何保护服务器的系统日志
  • python 爬虫抓取百度热搜
  • MFC小游戏设计
  • 【设计模式系列】命令模式
  • flex常用固定搭配
  • 各种包管理工具(npm,pip,yum,brew...)换镜像源
  • MySQL--基本介绍
  • HBuilder X 中Vue.js基础使用2(三)
  • 基于 Konva 实现Web PPT 编辑器(三)
  • qt生成uuid,转成int。ai回答亲测可以
  • 线性可分支持向量机的原理推导 9-32线性分类超平面的位置 公式解析
  • Dubbo接口解析
  • WordPress多站点子目录模式更换域名的教程方法
  • elementUI进度条el-progress不显示白色
  • 使用预测或实际LTV计算ROI
  • ubuntu22 安装labelimg制作自己的深度学习目标检测数据集
  • 微软大哥,全球第一(交易积累)
  • IDEA 如何导入NC65项目
  • 【贪心算法】(第十一篇)
  • Docker 基础入门
  • 将Django项目从PyCharm迁移到VSCode
  • Vue实现消息提示功能
  • Apache请求日志采集
  • 数据库聚合函数
  • Django自定义过滤器
  • 【软件测试】JUnit