
Residual: Residual Blocks

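There are two ways to implement a residual block.

One stacks two conv3x3 layers, so every convolution in the block uses the same 3x3 kernel size.

The other places a conv1x1 layer at each end with a conv3x3 in the middle; this design is called the bottleneck.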

[Figure: the two residual block designs, basic block (left) and bottleneck (right)]

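The original paper proposes two kinds of block, as shown in the figure above:
the one on the left is called the basic block,
and the one on the right is called the bottleneck.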

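Both blocks follow the same pattern: a shortcut connection skips over the convolutional layers, and the block's input is added element-wise to the output of the convolution stack (through the `downsample` projection in the code when the shapes differ). The sum then passes through a ReLU.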
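To make the shortcut concrete, here is a minimal NumPy sketch of a generic residual unit, y = ReLU(F(x) + x). The helpers `relu` and `residual_block` are hypothetical, and the lambdas passed as `residual_fn` are toy stand-ins for the stacked conv + BN layers, not real convolutions.

```python
import numpy as np


def relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)


def residual_block(x: np.ndarray, residual_fn) -> np.ndarray:
    """Hypothetical generic residual unit: y = ReLU(F(x) + x).

    `residual_fn` plays the role of F (the conv/BN stack) and must return an
    array with the same shape as `x`, otherwise the shortcut sum is invalid.
    """
    out = residual_fn(x)
    if out.shape != x.shape:
        raise ValueError("F(x) and x must match in shape; use a projection shortcut")
    return relu(out + x)


x = np.array([0.5, -1.0, 2.0])

# If F outputs zeros, the block just passes ReLU(x) through.
y_zero = residual_block(x, lambda t: np.zeros_like(t))

# A toy F that scales its input, standing in for real conv layers.
y_scaled = residual_block(x, lambda t: 0.5 * t)
print(y_zero, y_scaled)
```

The zero case shows why the shortcut helps: the layers only have to learn a correction on top of the input instead of reproducing the input themselves.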

1. Basic block

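Looking at the structure: a basic block contains two convolutional layers with the same number of kernels (output channels), and all kernels are 3x3.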
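Because both 3x3 convolutions use padding 1, only the block's stride changes the spatial size. The sketch below applies the `nn.Conv2d` output-size formula per dimension; the helper names are hypothetical, and the shortcut is assumed to be a 1x1 convolution with the same stride (like `conv1x1` in the code below), which is why the two paths can still be added.

```python
def conv_out_size(size: int, kernel_size: int, stride: int = 1,
                  padding: int = 0, dilation: int = 1) -> int:
    """Output length along one spatial dimension, per the nn.Conv2d shape formula."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def basic_block_out_size(size: int, stride: int) -> int:
    """Hypothetical helper: spatial size after the two 3x3 convs of a basic block."""
    # conv1: 3x3, padding=1, carries the block's stride
    size = conv_out_size(size, kernel_size=3, stride=stride, padding=1)
    # conv2: 3x3, padding=1, stride 1, so the size is unchanged
    return conv_out_size(size, kernel_size=3, stride=1, padding=1)


def projection_out_size(size: int, stride: int) -> int:
    """Assumed shortcut: a 1x1 conv (no padding) with the same stride as conv1."""
    return conv_out_size(size, kernel_size=1, stride=stride, padding=0)


for size in (56, 7):
    for stride in (1, 2):
        print(size, stride, basic_block_out_size(size, stride),
              projection_out_size(size, stride))
```

For a 56-pixel input at stride 2, both the residual path and the projected shortcut come out at 28, so the element-wise addition is well defined.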

1.1 Code implementation:

import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
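The `downsample` argument is only needed when the input cannot be added to the block output as-is. A small sketch of that rule, using the hypothetical helper `needs_projection`; as far as I know, torchvision's ResNet builder (`_make_layer`) applies an equivalent test and then passes a strided `conv1x1` followed by a norm layer as `downsample`.

```python
def needs_projection(inplanes: int, planes: int, stride: int, expansion: int) -> bool:
    """Hypothetical helper: does the shortcut need a `downsample` projection?

    The block outputs planes * expansion channels, and with stride != 1 it also
    shrinks the spatial size; either mismatch means x cannot be added to F(x) as-is.
    """
    return stride != 1 or inplanes != planes * expansion


# BasicBlock (expansion = 1)
print(needs_projection(64, 64, stride=1, expansion=1))    # same shape: identity shortcut
print(needs_projection(64, 128, stride=2, expansion=1))   # channels and size both change
# Bottleneck (expansion = 4)
print(needs_projection(64, 64, stride=1, expansion=4))    # 64 -> 256 channels
print(needs_projection(256, 64, stride=1, expansion=4))   # 256 -> 256: identity shortcut
```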


2. Bottleneck

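In the bottleneck, the first two layers have the same number of filters, and the third layer has 4 times as many (this factor is `expansion = 4` in the code). The second layer is 3x3; the other two are 1x1.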
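One way to see what the 1x1 layers buy is to count convolution weights. This sketch assumes `bias=False` (as in the code), `groups=1`, `base_width=64`, and ignores BatchNorm parameters; the helper names are hypothetical.

```python
def conv_params(in_ch: int, out_ch: int, k: int) -> int:
    # bias=False, so only the k x k x in x out weight tensor counts
    return k * k * in_ch * out_ch


def basic_block_params(channels: int) -> int:
    # two 3x3 convs, channels -> channels
    return 2 * conv_params(channels, channels, 3)


def bottleneck_params(inplanes: int, planes: int, expansion: int = 4) -> int:
    # 1x1 reduce, 3x3 at the reduced width, 1x1 expand back
    return (conv_params(inplanes, planes, 1)
            + conv_params(planes, planes, 3)
            + conv_params(planes, planes * expansion, 1))


print(basic_block_params(64))      # basic block on 64 channels
print(bottleneck_params(256, 64))  # bottleneck on 256-channel input/output
print(basic_block_params(256))     # basic block on 256 channels, for contrast
```

The bottleneck works on 256-channel features with 69,632 weights, slightly fewer than the 73,728 of a basic block on 64 channels, whereas a basic block on 256 channels would need 1,179,648.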

2.1 Code implementation:

class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
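The line `width = int(planes * (base_width / 64.)) * groups` lets the same `Bottleneck` class serve wider or grouped variants, which `BasicBlock` explicitly rejects. The sketch below evaluates that expression for a few settings; to my knowledge, the 32x4d and `base_width=128` settings match what torchvision uses for `resnext50_32x4d` and `wide_resnet50_2`, but treat the labels as illustrative.

```python
def bottleneck_width(planes: int, base_width: int = 64, groups: int = 1) -> int:
    """Inner width of conv1/conv2, same expression as in Bottleneck.__init__."""
    return int(planes * (base_width / 64.)) * groups


EXPANSION = 4  # Bottleneck.expansion

configs = {
    "plain bottleneck (groups=1, base_width=64)": dict(groups=1, base_width=64),
    "ResNeXt-style 32x4d (groups=32, base_width=4)": dict(groups=32, base_width=4),
    "wide variant (groups=1, base_width=128)": dict(groups=1, base_width=128),
}
for name, cfg in configs.items():
    width = bottleneck_width(64, **cfg)
    # The output still has planes * expansion channels, so the shortcut add is unchanged.
    print(f"{name}: inner width {width}, output channels {64 * EXPANSION}")
```

Only the inner width changes (64, 128, 128 for `planes=64`); `conv3` always maps back to `planes * expansion` channels.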

