ResNet(残差网络)
1️⃣ ResNet介绍
ResNet斩获2015
ImageNet竞赛图像分类任务第一名, 目标检测第一名。其重要贡献是发现了神经网络的“退化现象”,即一直加深网络深度并不能总是提高性能,反而会导致性能下降。针对这种现象,发明了Shortcut connection,极大的消除了深度过大的神经网络训练困难问题。
2️⃣ 原理分析
ResNet有两种block,一种是两层的BasicBlock,一种是三层的Bottleneck
- 两层的BasicBlock用于浅层网络(ResNet18、ResNet34),分实线和虚线的BasicBlock
- 实线BasicBlock :包含两个卷积层
根据这个图举个例子,例如输入是64×56×56- 第一个卷积:64个3×3卷积核, stride=1,padding=1
输入:64×56×56
输出:64×56×56
- 第二个卷积:64个3×3卷积核, stride=1,padding=1
输入:64×56×56
输出:64×56×56
- Shortcut connection
直接相加
相加后的最终维度:64×56×56
- 第一个卷积:64个3×3卷积核, stride=1,padding=1
- 虚线BasicBlock:里面包含两个卷积层,但注意的是第一个卷积层stride=2,进行了高宽的降维,因此shortcut connect加入了1×1的卷积进行x和F(x)维度统一。
根据这个图举个例子,例如输入是64×56×56- 第一个卷积:128个3×3卷积核, stride=2,padding=1
输入:64×56×56
输出:128×28×28
- 第二个卷积:128个3×3卷积核, stride=1,padding=1
输入:128×28×28
输出:128×28×28
- Shortcut connection
原始的x是64×56×56,经过两层卷积后的F(x)维度128×28×28,x与F(x)维度不一致,因此x需要经过卷积层来统一维度:128个1×1卷积核,stride=2,padding=1- 输入x:64×56×56
- 输出:128×28×28
- 维度一致了,因此可以相加
- 相加后的最终维度:
128×28×28
- 第一个卷积:128个3×3卷积核, stride=2,padding=1
- 实线BasicBlock :包含两个卷积层
- 三层的Bottleneck用于深层网络(ResNet50、ResNet101、ResNet152),也分实线和虚线
- 实线Bottleneck:包含三个卷积层
- 虚线Bottleneck:包含三个卷积层,但注意的是第二个卷积层stride=2,进行了高宽的降维,因此shortcut connect加入了1×1的卷积进行x和F(x)维度统一
- 实线Bottleneck:包含三个卷积层
-
两层的BasicBlock与三层的Bottleneck的对比,这里我们只分析实线的情况
Bottleneck能够减少参数和运算量,其中第一个1×1的卷积用于降维,第二个1×1的卷积用于升维。假设输入的通道数为256【CNN参数与输入的高宽无关】,Basicblock需要1179648个参数,右侧模块需要69632个参数
注:CNN参数个数 = 卷积核高度×卷积核宽度×输入通道数×卷积核个数
3️⃣ 网络结构
这里以ResNet18(1个卷积+16个卷积+1个fc)为例进行分析:
-
网络输入:3×224×224
-
conv1:64个7×7卷积核,stride=2,padding=3
输入:3×224×224
输出:64×112×112
经过BN(批量归一化) -
maxpool:3×3,stride=2,padding=1
输入:64×112×112
输出:64×56×56
-
con2_x:两个实线BasicBlock,每个BasicBlock有两层卷积
-
第一个实线BasicBlock
- 第一个卷积:64个3×3卷积核, stride=1,padding=1
输入:64×56×56
输出:64×56×56
经过BN(批量归一化)
经过ReLu - 第二个卷积:64个3×3卷积核, stride=1,padding=1
输入:64×56×56
输出:64×56×56
经过BN(批量归一化) - Shortcut connection
直接相加
经过ReLu
相加后的最终维度:64×56×56
- 第一个卷积:64个3×3卷积核, stride=1,padding=1
-
第二个实线BasicBlock
- 第一个卷积:64个3×3卷积核, stride=1,padding=1
输入:64×56×56
输出:64×56×56
- 第二个卷积:64个3×3卷积核, stride=1,padding=1
输入:64×56×56
输出:64×56×56
- Shortcut connection
直接相加
经过ReLu
相加后的最终维度:64×56×56
- 第一个卷积:64个3×3卷积核, stride=1,padding=1
-
-
conv3_x:一个虚线BasicBlock,一个实线BasicBlock,每个BasicBlock有两层卷积
- 第一个BasicBlock(虚线,高宽减半,深度加倍)
- 第一个卷积:128个3×3卷积核,stride=2【注意这里】,padding=1
输入:64×56×56
输出:128×28×28
经过BN(批量归一化)
经过ReLu - 第二个卷积:128个3×3卷积核,stride=1,padding=1
输入:128×28×28
输出:128×28×28
经过BN(批量归一化) - Shortcut connection:原始的x是64×56×56,经过两层卷积后的F(x)维度128×28×28,x与F(x)维度不一致,因此x需要经过卷积层来统一维度:128个1×1卷积核,stride=2,padding=1
- 输入x:64×56×56
- 输出:128×28×28
- 经过BN(批量归一化)
- 维度一致了,因此可以相加
- 经过ReLu
- 相加后的最终维度:
128×28×28
- 第一个卷积:128个3×3卷积核,stride=2【注意这里】,padding=1
- 第二个BasicBlock
- 第一个卷积:128个3×3卷积核,stride=1,padding=1
输入:64×56×56
输出:128×28×28
经过BN(批量归一化)
经过ReLu - 第二个卷积:128个3×3卷积核,stride=1,padding=1
输入:128×28×28
输出:128×28×28
经过BN(批量归一化) - Shortcut connection
直接相加
经过ReLu
相加后的最终维度:128×28×28
- 第一个卷积:128个3×3卷积核,stride=1,padding=1
- 第一个BasicBlock(虚线,高宽减半,深度加倍)
-
con4_x:一个虚线BasicBlock,一个实线BasicBlock,每个BasicBlock有两层卷积
- 第一个BasicBlock(虚线,高宽减半,深度加倍)
- 第一个卷积:256个3×3卷积核,stride=2【注意这里】,padding=1
输入:128×28×28
输出:256×14×14
经过BN(批量归一化)
经过ReLu - 第二个卷积:256个3×3卷积核,stride=1,padding=1
输入:256×14×14
输出:256×14×14
经过BN(批量归一化) - Shortcut connection:原始的x是128×28×28,经过两层卷积后的F(x)维度256×14×14,x与F(x)维度不一致,因此x需要经过卷积层来统一维度:256个1×1卷积核,stride=2,padding=1
- 输入x:128×28×28
- 输出:256×14×14
- 经过BN(批量归一化)
- 维度一致了,因此可以相加
- 经过ReLu
- 相加后的最终维度:
256×14×14
- 第一个卷积:256个3×3卷积核,stride=2【注意这里】,padding=1
- 第二个BasicBlock
-
第一个卷积:256个3×3卷积核,stride=1,padding=1
输入:256×14×14
输出:256×14×14
-
第二个卷积:256个3×3卷积核,stride=1,padding=1
输入:256×14×14
输出:256×14×14
-
Shortcut connection
直接相加
经过ReLu
相加后的最终维度:256×14×14
-
- 第一个BasicBlock(虚线,高宽减半,深度加倍)
-
con5_x:一个虚线BasicBlock,一个实线BasicBlock,每个BasicBlock有两层卷积
- 第一个BasicBlock(虚线,高宽减半,深度加倍)
- 第一个卷积:512个3×3卷积核,stride=2【注意这里】,padding=1
输入:256×14×14
输出:512×7×7
经过BN(批量归一化)
经过ReLu - 第二个卷积:512个3×3卷积核,stride=1,padding=1
输入:512×7×7
输出:512×7×7
经过BN(批量归一化) - Shortcut connection:原始的x是256×14×14,经过两层卷积后的F(x)维度512×7×7,x与F(x)维度不一致,因此x需要经过卷积层来统一维度:512个1×1卷积核,stride=2,padding=1
- 输入x:256×14×14
- 输出:512×7×7
- 经过BN(批量归一化)
- 维度一致了,因此可以相加
- 经过ReLu
- 相加后的最终维度:
512×7×7
- 第一个卷积:512个3×3卷积核,stride=2【注意这里】,padding=1
- 第二个BasicBlock
- 第一个卷积:512个3×3卷积核,stride=1,padding=1
输入:512×7×7
输出:512×7×7
经过BN(批量归一化)
经过ReLu - 第二个卷积:512个3×3卷积核,stride=1,padding=1
输入:512×7×7
输出:512×7×7
经过BN(批量归一化) - Shortcut connection
直接相加
经过ReLu
相加后的最终维度:512×7×7
- 第一个卷积:512个3×3卷积核,stride=1,padding=1
- 第一个BasicBlock(虚线,高宽减半,深度加倍)
-
avgpool:输出
512, 1, 1
-
FC:
10
4️⃣ 代码
1.这里以resnet18为例进行分析,创建一个名为resnet18.py的文件
# 仅针对Resnet18
import torch
from torch import nn
from torch.nn import functional as F
class BasicBlock(nn.Module):
def __init__(self, input_channels, out_channels,
use_1x1conv=False, strides=1):
super().__init__()
self.conv1 = nn.Conv2d(input_channels, out_channels,
kernel_size=3, padding=1, stride=strides)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu1=nn.ReLU()
self.conv2 = nn.Conv2d(out_channels, out_channels,
kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
if use_1x1conv:
self.conv3 = nn.Conv2d(input_channels, out_channels,
kernel_size=1, stride=strides)
self.bn3=nn.BatchNorm2d(out_channels)
else:
self.conv3 = None
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu1(out)
out=self.conv2(out)
out=self.bn2(out)
if self.conv3:
x = self.conv3(x)
x=self.bn3(x)
out += x
out = F.relu(out)
return out
def big_block(input_channels, out_channels, num_block,
first_bigblock=False):
blk = []
for i in range(num_block):
if first_bigblock:
blk.append(BasicBlock(out_channels, out_channels))
elif i==0:
blk.append(BasicBlock(input_channels, out_channels,
use_1x1conv=True, strides=2))
else:
blk.append(BasicBlock(out_channels, out_channels))
return blk
class Resnet(nn.Module):
def __init__(self):
super().__init__()
self.b1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
nn.BatchNorm2d(64),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
self.bigblock1 = nn.Sequential(*big_block(64, 64, 2, first_bigblock=True))
self.bigblock2 = nn.Sequential(*big_block(64, 128, 2))
self.bigblock3 = nn.Sequential(*big_block(128, 256, 2))
self.bigblock4 = nn.Sequential(*big_block(256, 512, 2))
# avgpool
self.averpool = nn.AdaptiveAvgPool2d((1, 1))
# 全连接层
self.linear = nn.Linear(512, 10)
def forward(self,x):
out = self.b1(x)
# 通过第一个 big_block
out = self.bigblock1(out)
# 通过第二个 big_block
out = self.bigblock2(out)
# 通过第三个 big_block
out = self.bigblock3(out)
# 通过第四个 big_block
out = self.bigblock4(out)
# 通过全局平均池化层
out = self.averpool(out)
# 将特征图展平成二维
out = out.view(out.size(0), -1)
# 通过全连接层得到输出
out = self.linear(out)
return out
def register_hooks(model):
# 定义钩子函数 hook:这是一个内部函数,它会被注册为模型每一层的钩子。
# 在前向传播过程中,当每一层执行完毕时,这个钩子函数会被触发
# module:表示当前被钩子的模块(层)。例如,它可以是卷积层、池化层、全连接层等。
# input:是传入当前模块的输入张量,形状与输入数据的形状一致(通常是元组形式)。
# output:是从当前模块得到的输出张量
def hook(module, input, output):
# 获取当前层的类名,例如 Conv2d、ReLU、MaxPool2d 等
class_name = module.__class__.__name__
# 打印当前层的类名和其输出张量的形状
print(f"{class_name} Output Shape: {output.shape}")
# 注册钩子函数
for layer in model.modules():
# 排除 nn.Sequential:nn.Sequential 是一个包含多个层的容器类,我们通常希望钩子直接注册到具体的层,而不是容器本身。
# 排除 nn.ModuleList:nn.ModuleList 是另一种容器类,通常用于将多个模块按列表存储。
# 排除顶级的模型类本身:检查 (layer == model) 以确保不会在整个模型对象上注册钩子
if not isinstance(layer, nn.Sequential) and not isinstance(layer, nn.ModuleList) and not (layer == model):
# 对于通过检查的每一个 layer,调用 register_forward_hook 方法,将之前定义的 hook 函数注册到该层上
layer.register_forward_hook(hook)
if __name__=='__main__':
x = torch.randn(size=(1, 3, 224, 224))
model=Resnet()
# 注册钩子以打印每层的输出尺寸
# register_hooks(model)
y = model(x)
print(y.shape)
2.输出结果
打印每一层的输出:
-------------------------------------------------------------------------------------------------
# conv1:
Conv2d Output Shape: torch.Size([1, 64, 112, 112])
BatchNorm2d Output Shape: torch.Size([1, 64, 112, 112])
-------------------------------------------------------------------------------------------------
# maxpool:
MaxPool2d Output Shape: torch.Size([1, 64, 56, 56])
-------------------------------------------------------------------------------------------------
# conv2_x:
# 第一个实线BasicBlock:第一个卷积
Conv2d Output Shape: torch.Size([1, 64, 56, 56])
BatchNorm2d Output Shape: torch.Size([1, 64, 56, 56])
ReLU Output Shape: torch.Size([1, 64, 56, 56])
# 第一个实线BasicBlock:第二个卷积
Conv2d Output Shape: torch.Size([1, 64, 56, 56])
BatchNorm2d Output Shape: torch.Size([1, 64, 56, 56])
# Shortcut连接,直接相加
BasicBlock Output Shape: torch.Size([1, 64, 56, 56])
---------------------------------------------------
# 第二个实线BasicBlock:第一个卷积
Conv2d Output Shape: torch.Size([1, 64, 56, 56])
BatchNorm2d Output Shape: torch.Size([1, 64, 56, 56])
ReLU Output Shape: torch.Size([1, 64, 56, 56])
# 第二个实线BasicBlock:第二个卷积
Conv2d Output Shape: torch.Size([1, 64, 56, 56])
BatchNorm2d Output Shape: torch.Size([1, 64, 56, 56])
# Shortcut连接,直接相加
BasicBlock Output Shape: torch.Size([1, 64, 56, 56])
-------------------------------------------------------------------------------------------------
# conv3_x:
# 第一个虚线BasicBlock:第一个卷积
Conv2d Output Shape: torch.Size([1, 128, 28, 28])
BatchNorm2d Output Shape: torch.Size([1, 128, 28, 28])
ReLU Output Shape: torch.Size([1, 128, 28, 28])
# 第一个虚线BasicBlock:第二个卷积
Conv2d Output Shape: torch.Size([1, 128, 28, 28])
BatchNorm2d Output Shape: torch.Size([1, 128, 28, 28])
# Shortcut连接,过1×1卷积
Conv2d Output Shape: torch.Size([1, 128, 28, 28])
BatchNorm2d Output Shape: torch.Size([1, 128, 28, 28])
# 相加
BasicBlock Output Shape: torch.Size([1, 128, 28, 28])
---------------------------------------------------
# 第二个实线BasicBlock:第一个卷积
Conv2d Output Shape: torch.Size([1, 128, 28, 28])
BatchNorm2d Output Shape: torch.Size([1, 128, 28, 28])
ReLU Output Shape: torch.Size([1, 128, 28, 28])
# 第二个实线BasicBlock:第二个卷积
Conv2d Output Shape: torch.Size([1, 128, 28, 28])
BatchNorm2d Output Shape: torch.Size([1, 128, 28, 28])
# Shortcut连接,直接相加
BasicBlock Output Shape: torch.Size([1, 128, 28, 28])
-------------------------------------------------------------------------------------------------
# conv4_x:
# 第一个虚线BasicBlock:第一个卷积
Conv2d Output Shape: torch.Size([1, 256, 14, 14])
BatchNorm2d Output Shape: torch.Size([1, 256, 14, 14])
ReLU Output Shape: torch.Size([1, 256, 14, 14])
# 第一个虚线BasicBlock:第二个卷积
Conv2d Output Shape: torch.Size([1, 256, 14, 14])
BatchNorm2d Output Shape: torch.Size([1, 256, 14, 14])
# Shortcut连接,过1×1卷积
Conv2d Output Shape: torch.Size([1, 256, 14, 14])
BatchNorm2d Output Shape: torch.Size([1, 256, 14, 14])
# 相加
BasicBlock Output Shape: torch.Size([1, 256, 14, 14])
---------------------------------------------------
# 第二个实线BasicBlock:第一个卷积
Conv2d Output Shape: torch.Size([1, 256, 14, 14])
BatchNorm2d Output Shape: torch.Size([1, 256, 14, 14])
ReLU Output Shape: torch.Size([1, 256, 14, 14])
# 第二个实线BasicBlock:第一个卷积
Conv2d Output Shape: torch.Size([1, 256, 14, 14])
BatchNorm2d Output Shape: torch.Size([1, 256, 14, 14])
# Shortcut连接,直接相加
BasicBlock Output Shape: torch.Size([1, 256, 14, 14])
-------------------------------------------------------------------------------------------------
# conv5_x:
# 第一个虚线BasicBlock:第一个卷积
Conv2d Output Shape: torch.Size([1, 512, 7, 7])
BatchNorm2d Output Shape: torch.Size([1, 512, 7, 7])
ReLU Output Shape: torch.Size([1, 512, 7, 7])
# 第一个虚线BasicBlock:第二个卷积
Conv2d Output Shape: torch.Size([1, 512, 7, 7])
BatchNorm2d Output Shape: torch.Size([1, 512, 7, 7])
# Shortcut连接,过1×1卷积
Conv2d Output Shape: torch.Size([1, 512, 7, 7])
BatchNorm2d Output Shape: torch.Size([1, 512, 7, 7])
# 相加
BasicBlock Output Shape: torch.Size([1, 512, 7, 7])
---------------------------------------------------
# 第二个实线BasicBlock:第一个卷积
Conv2d Output Shape: torch.Size([1, 512, 7, 7])
BatchNorm2d Output Shape: torch.Size([1, 512, 7, 7])
ReLU Output Shape: torch.Size([1, 512, 7, 7])
# 第二个实线BasicBlock:第一个卷积
Conv2d Output Shape: torch.Size([1, 512, 7, 7])
BatchNorm2d Output Shape: torch.Size([1, 512, 7, 7])
# Shortcut连接,直接相加
BasicBlock Output Shape: torch.Size([1, 512, 7, 7])
-------------------------------------------------------------------------------------------------
AdaptiveAvgPool2d Output Shape: torch.Size([1, 512, 1, 1])
-------------------------------------------------------------------------------------------------
Linear Output Shape: torch.Size([1, 10])
-------------------------------------------------------------------------------------------------