
Deep Learning Camp: A Hands-On Walkthrough of ResNeXt-50

  • 🍨 This post is a learning-record blog from the 🔗365天深度学习训练营 (365-day deep learning training camp)
  • 🍖 Original author: K同学啊

I. Introduction

ResNeXt is an upgraded classification network built on ResNet, proposed by Kaiming He's team (Xie et al., "Aggregated Residual Transformations for Deep Neural Networks"). On top of the residual network, it introduces the concept of cardinality.
Let us briefly go over the core concepts of ResNeXt:

Core Concepts

  1. Cardinality

    • Cardinality is the number of parallel transformation paths in a ResNeXt block, i.e. the number of groups in its grouped convolution. It is a parameter describing the size of the set of transformations, and increasing it has been shown to be more effective than increasing width (the number of filters per layer) or depth (the number of layers).
    • In practice, increasing cardinality improves accuracy while keeping computational complexity roughly unchanged.
  2. Grouped Convolution

    • The core of ResNeXt is grouped convolution, inspired by its successful use in AlexNet. Grouped convolution splits the input feature channels into several groups, each processed by its own set of filters, which adds more "paths" through the network without increasing the computational burden (see the parameter-count sketch after this list).
  3. Modular Design

    • A key trait of ResNeXt is its modular design: every block has nearly the same layout, which makes the network easy to extend and modify. Each residual block uses grouped convolution internally, and the blocks can be stacked to build deeper networks.
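
To make the efficiency of grouped convolution concrete, here is a minimal sketch (my addition, not from the original post) comparing the parameter counts of a dense 3x3 convolution and a grouped one with 32 groups:

import torch.nn as nn

# Dense vs. grouped 3x3 convolution on 128 channels.
dense   = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False)
grouped = nn.Conv2d(128, 128, kernel_size=3, padding=1, groups=32, bias=False)

n_dense   = sum(p.numel() for p in dense.parameters())    # 128*128*3*3 = 147,456
n_grouped = sum(p.numel() for p in grouped.parameters())  # 32*(4*4*3*3) =   4,608
print(n_dense, n_grouped)  # the grouped conv uses 1/32 of the parameters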

II. Reproducing the ResNeXt-50 Network

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchsummary

class ResNeXtBottleneck(nn.Module):
    expansion = 4  # expansion factor for the output channels of the block's last conv layer
    
    def __init__(self, in_channels, out_channels, stride=1, cardinality=32, base_width=4):
        super(ResNeXtBottleneck, self).__init__()
        width_ratio = out_channels / (cardinality * base_width)  # ratio of this block's width to cardinality * base_width
        D = cardinality * int(base_width * width_ratio)  # total channels of the grouped 3x3 conv, split across `cardinality` groups
        
        self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn_reduce = nn.BatchNorm2d(D)
        
        self.conv_conv = nn.Conv2d(D, D, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
        self.bn = nn.BatchNorm2d(D)
        
        self.conv_expand = nn.Conv2d(D, out_channels * self.expansion, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn_expand = nn.BatchNorm2d(out_channels * self.expansion)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(out_channels * self.expansion)
            )
    
    def forward(self, x):
        out = F.relu(self.bn_reduce(self.conv_reduce(x)), inplace=True)
        out = F.relu(self.bn(self.conv_conv(out)), inplace=True)
        out = self.bn_expand(self.conv_expand(out))
        out += self.shortcut(x)
        out = F.relu(out, inplace=True)
        return out

class ResNeXt(nn.Module):
    def __init__(self, block, layers, cardinality, num_classes=1000, base_width=4):
        super(ResNeXt, self).__init__()
        self.in_channels = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], 1, cardinality, base_width)
        self.layer2 = self._make_layer(block, 128, layers[1], 2, cardinality, base_width)
        self.layer3 = self._make_layer(block, 256, layers[2], 2, cardinality, base_width)
        self.layer4 = self._make_layer(block, 512, layers[3], 2, cardinality, base_width)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride, cardinality, base_width):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride, cardinality, base_width))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        
        return x

def resnext50():
    model = ResNeXt(ResNeXtBottleneck, [3, 4, 6, 3], cardinality=32, num_classes=1000, base_width=4)
    return model

# create a model instance
model = resnext50().cuda()
model
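
As a quick sanity check (my addition, not in the original post), a dummy batch run through the network confirms the expected output shape:

# Minimal sanity check: forward a dummy batch and verify the output shape.
x = torch.randn(2, 3, 224, 224).cuda()  # must live on the same device as the model
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([2, 1000])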

The main design effort lies in ResNeXtBottleneck. This class plays the central role in the ResNeXt network; its main functions and characteristics can be summarized as follows:

Feature Transformation and Enhancement

  1. Efficient feature transformation
    • Dimensionality reduction and expansion: ResNeXtBottleneck uses 1x1 convolutions to reduce and then expand the channel dimension, cutting computation and parameter count while preserving the network's learning capacity.
    • Grouped convolution: using grouped convolution in the 3x3 layer effectively increases the network's capacity and expressiveness while keeping the parameter count and computational cost relatively low. Each convolution group attends to its own slice of the input, strengthening the model's ability to capture diverse features (the group width D is worked out in the sketch below).
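
To make the width computation concrete, here is the group-conv width D worked out per stage for the default configuration (cardinality=32, base_width=4), matching the model printout below:

# Worked example (my addition): group-conv width D per stage, following the
# formula in ResNeXtBottleneck.__init__ above.
cardinality, base_width = 32, 4
for out_channels in (64, 128, 256, 512):
    width_ratio = out_channels / (cardinality * base_width)
    D = cardinality * int(base_width * width_ratio)
    print(out_channels, D, D // cardinality)  # stage width, D, channels per group
# -> 64 64 2 | 128 128 4 | 256 256 8 | 512 512 16

Note that this formula makes D equal to out_channels, so the 3x3 group conv here is half as wide as in the reference ResNeXt-50 (32x4d), whose first-stage group conv has 128 channels; this is also why the total parameter count reported below (about 14.6M) is lower than the roughly 25M of the standard model.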

Enhancing Network Depth and Efficiency

  1. Improving network performance
    • Residual connections: by introducing shortcut (residual) connections, ResNeXtBottleneck lets gradients flow directly through them, mitigating the vanishing and exploding gradients that can arise when training deep networks. This design lets the network grow deeper while training stably.
    • Batch normalization: every convolution is followed by a batch-normalization layer, which speeds up training, improves convergence, and makes the network less sensitive to weight initialization.

Network Scalability

  1. Modular design
    • Standardized building block: ResNeXtBottleneck is a standardized building block; stacking more of them yields a deeper network. This modular design simplifies extending and experimenting with the architecture, letting researchers and developers flexibly adapt it to different tasks and datasets.

Application Flexibility

  1. Adapting to tasks and datasets of various scales
    • Configurable parameters: by adjusting cardinality (the number of groups) and base_width (the base width), ResNeXtBottleneck can tune its capacity and complexity to the needs of the task, from small datasets to large ones.

Overall, through efficient feature transformation, strong residual connectivity, and modular design, ResNeXtBottleneck improves model performance while staying scalable and adaptable. This is what lets ResNeXt achieve excellent results across many important visual recognition tasks.

Code output:

ResNeXt(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): ResNeXtBottleneck(
      (conv_reduce): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_expand): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): ResNeXtBottleneck(
      (conv_reduce): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_expand): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (2): ResNeXtBottleneck(
      (conv_reduce): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_expand): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
  )
  (layer2): Sequential(
    (0): ResNeXtBottleneck(
      (conv_reduce): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_expand): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): ResNeXtBottleneck(
      (conv_reduce): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_expand): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (2): ResNeXtBottleneck(
      (conv_reduce): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_expand): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (3): ResNeXtBottleneck(
      (conv_reduce): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_expand): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
  )
  (layer3): Sequential(
    (0): ResNeXtBottleneck(
      (conv_reduce): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_expand): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential(
        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): ResNeXtBottleneck(
      (conv_reduce): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_expand): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (2): ResNeXtBottleneck(
      (conv_reduce): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_expand): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (3): ResNeXtBottleneck(
      (conv_reduce): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_expand): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (4): ResNeXtBottleneck(
      (conv_reduce): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_expand): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (5): ResNeXtBottleneck(
      (conv_reduce): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_expand): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
  )
  (layer4): Sequential(
    (0): ResNeXtBottleneck(
      (conv_reduce): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_expand): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential(
        (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): ResNeXtBottleneck(
      (conv_reduce): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_expand): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (2): ResNeXtBottleneck(
      (conv_reduce): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_reduce): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv_expand): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn_expand): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

We can also view the overall structure of the network with torchsummary:

torchsummary.summary(model, input_size=(3, 224, 224))

Code output:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
            Conv2d-7           [-1, 64, 56, 56]           1,152
       BatchNorm2d-8           [-1, 64, 56, 56]             128
            Conv2d-9          [-1, 256, 56, 56]          16,384
      BatchNorm2d-10          [-1, 256, 56, 56]             512
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
ResNeXtBottleneck-13          [-1, 256, 56, 56]               0
           Conv2d-14           [-1, 64, 56, 56]          16,384
      BatchNorm2d-15           [-1, 64, 56, 56]             128
           Conv2d-16           [-1, 64, 56, 56]           1,152
      BatchNorm2d-17           [-1, 64, 56, 56]             128
           Conv2d-18          [-1, 256, 56, 56]          16,384
      BatchNorm2d-19          [-1, 256, 56, 56]             512
ResNeXtBottleneck-20          [-1, 256, 56, 56]               0
           Conv2d-21           [-1, 64, 56, 56]          16,384
      BatchNorm2d-22           [-1, 64, 56, 56]             128
           Conv2d-23           [-1, 64, 56, 56]           1,152
      BatchNorm2d-24           [-1, 64, 56, 56]             128
           Conv2d-25          [-1, 256, 56, 56]          16,384
      BatchNorm2d-26          [-1, 256, 56, 56]             512
ResNeXtBottleneck-27          [-1, 256, 56, 56]               0
           Conv2d-28          [-1, 128, 56, 56]          32,768
      BatchNorm2d-29          [-1, 128, 56, 56]             256
           Conv2d-30          [-1, 128, 28, 28]           4,608
      BatchNorm2d-31          [-1, 128, 28, 28]             256
           Conv2d-32          [-1, 512, 28, 28]          65,536
      BatchNorm2d-33          [-1, 512, 28, 28]           1,024
           Conv2d-34          [-1, 512, 28, 28]         131,072
      BatchNorm2d-35          [-1, 512, 28, 28]           1,024
ResNeXtBottleneck-36          [-1, 512, 28, 28]               0
           Conv2d-37          [-1, 128, 28, 28]          65,536
      BatchNorm2d-38          [-1, 128, 28, 28]             256
           Conv2d-39          [-1, 128, 28, 28]           4,608
      BatchNorm2d-40          [-1, 128, 28, 28]             256
           Conv2d-41          [-1, 512, 28, 28]          65,536
      BatchNorm2d-42          [-1, 512, 28, 28]           1,024
ResNeXtBottleneck-43          [-1, 512, 28, 28]               0
           Conv2d-44          [-1, 128, 28, 28]          65,536
      BatchNorm2d-45          [-1, 128, 28, 28]             256
           Conv2d-46          [-1, 128, 28, 28]           4,608
      BatchNorm2d-47          [-1, 128, 28, 28]             256
           Conv2d-48          [-1, 512, 28, 28]          65,536
      BatchNorm2d-49          [-1, 512, 28, 28]           1,024
ResNeXtBottleneck-50          [-1, 512, 28, 28]               0
           Conv2d-51          [-1, 128, 28, 28]          65,536
      BatchNorm2d-52          [-1, 128, 28, 28]             256
           Conv2d-53          [-1, 128, 28, 28]           4,608
      BatchNorm2d-54          [-1, 128, 28, 28]             256
           Conv2d-55          [-1, 512, 28, 28]          65,536
      BatchNorm2d-56          [-1, 512, 28, 28]           1,024
ResNeXtBottleneck-57          [-1, 512, 28, 28]               0
           Conv2d-58          [-1, 256, 28, 28]         131,072
      BatchNorm2d-59          [-1, 256, 28, 28]             512
           Conv2d-60          [-1, 256, 14, 14]          18,432
      BatchNorm2d-61          [-1, 256, 14, 14]             512
           Conv2d-62         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-63         [-1, 1024, 14, 14]           2,048
           Conv2d-64         [-1, 1024, 14, 14]         524,288
      BatchNorm2d-65         [-1, 1024, 14, 14]           2,048
ResNeXtBottleneck-66         [-1, 1024, 14, 14]               0
           Conv2d-67          [-1, 256, 14, 14]         262,144
      BatchNorm2d-68          [-1, 256, 14, 14]             512
           Conv2d-69          [-1, 256, 14, 14]          18,432
      BatchNorm2d-70          [-1, 256, 14, 14]             512
           Conv2d-71         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-72         [-1, 1024, 14, 14]           2,048
ResNeXtBottleneck-73         [-1, 1024, 14, 14]               0
           Conv2d-74          [-1, 256, 14, 14]         262,144
      BatchNorm2d-75          [-1, 256, 14, 14]             512
           Conv2d-76          [-1, 256, 14, 14]          18,432
      BatchNorm2d-77          [-1, 256, 14, 14]             512
           Conv2d-78         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-79         [-1, 1024, 14, 14]           2,048
ResNeXtBottleneck-80         [-1, 1024, 14, 14]               0
           Conv2d-81          [-1, 256, 14, 14]         262,144
      BatchNorm2d-82          [-1, 256, 14, 14]             512
           Conv2d-83          [-1, 256, 14, 14]          18,432
      BatchNorm2d-84          [-1, 256, 14, 14]             512
           Conv2d-85         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-86         [-1, 1024, 14, 14]           2,048
ResNeXtBottleneck-87         [-1, 1024, 14, 14]               0
           Conv2d-88          [-1, 256, 14, 14]         262,144
      BatchNorm2d-89          [-1, 256, 14, 14]             512
           Conv2d-90          [-1, 256, 14, 14]          18,432
      BatchNorm2d-91          [-1, 256, 14, 14]             512
           Conv2d-92         [-1, 1024, 14, 14]         262,144
      BatchNorm2d-93         [-1, 1024, 14, 14]           2,048
ResNeXtBottleneck-94         [-1, 1024, 14, 14]               0
           Conv2d-95          [-1, 256, 14, 14]         262,144
      BatchNorm2d-96          [-1, 256, 14, 14]             512
           Conv2d-97          [-1, 256, 14, 14]          18,432
      BatchNorm2d-98          [-1, 256, 14, 14]             512
           Conv2d-99         [-1, 1024, 14, 14]         262,144
     BatchNorm2d-100         [-1, 1024, 14, 14]           2,048
ResNeXtBottleneck-101         [-1, 1024, 14, 14]               0
          Conv2d-102          [-1, 512, 14, 14]         524,288
     BatchNorm2d-103          [-1, 512, 14, 14]           1,024
          Conv2d-104            [-1, 512, 7, 7]          73,728
     BatchNorm2d-105            [-1, 512, 7, 7]           1,024
          Conv2d-106           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-107           [-1, 2048, 7, 7]           4,096
          Conv2d-108           [-1, 2048, 7, 7]       2,097,152
     BatchNorm2d-109           [-1, 2048, 7, 7]           4,096
ResNeXtBottleneck-110           [-1, 2048, 7, 7]               0
          Conv2d-111            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-112            [-1, 512, 7, 7]           1,024
          Conv2d-113            [-1, 512, 7, 7]          73,728
     BatchNorm2d-114            [-1, 512, 7, 7]           1,024
          Conv2d-115           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-116           [-1, 2048, 7, 7]           4,096
ResNeXtBottleneck-117           [-1, 2048, 7, 7]               0
          Conv2d-118            [-1, 512, 7, 7]       1,048,576
     BatchNorm2d-119            [-1, 512, 7, 7]           1,024
          Conv2d-120            [-1, 512, 7, 7]          73,728
     BatchNorm2d-121            [-1, 512, 7, 7]           1,024
          Conv2d-122           [-1, 2048, 7, 7]       1,048,576
     BatchNorm2d-123           [-1, 2048, 7, 7]           4,096
ResNeXtBottleneck-124           [-1, 2048, 7, 7]               0
AdaptiveAvgPool2d-125           [-1, 2048, 1, 1]               0
          Linear-126                 [-1, 1000]       2,049,000
================================================================
Total params: 14,593,448
Trainable params: 14,593,448
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 219.37
Params size (MB): 55.67
Estimated Total Size (MB): 275.62
----------------------------------------------------------------

III. Application to Breast Cancer Recognition

Data preprocessing is the same as in earlier posts; a hedged sketch of a typical pipeline is shown below, after which we go straight to the results:
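
Since the original preprocessing code is not repeated in this post, the following is only a plausible sketch of such a pipeline; the dataset path, split ratio, and batch size are my assumptions, not the author's code:

# Hedged sketch: a typical loading/preprocessing pipeline for this task.
# The path "./data", the 80/20 split, and batch_size=32 are assumptions.
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet statistics
                         std=[0.229, 0.224, 0.225]),
])

dataset = datasets.ImageFolder("./data", transform=transform)  # hypothetical path
n_train = int(0.8 * len(dataset))
train_ds, test_ds = random_split(dataset, [n_train, len(dataset) - n_train])
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
test_dl  = DataLoader(test_ds, batch_size=32)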

D:\App\anaconda\Lib\site-packages\torch\optim\lr_scheduler.py:60: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.
  warnings.warn(

Epoch: 1, Train_acc:80.9%, Train_loss:0.573, Test_acc:83.0%, Test_loss:0.433, Lr:1.00E-04
Epoch: 2, Train_acc:84.9%, Train_loss:0.350, Test_acc:85.9%, Test_loss:0.321, Lr:1.00E-04
Epoch: 3, Train_acc:86.5%, Train_loss:0.323, Test_acc:85.8%, Test_loss:0.320, Lr:1.00E-04
Epoch: 4, Train_acc:87.4%, Train_loss:0.302, Test_acc:87.6%, Test_loss:0.294, Lr:1.00E-04
Epoch: 5, Train_acc:88.0%, Train_loss:0.292, Test_acc:87.6%, Test_loss:0.314, Lr:1.00E-04
Epoch: 6, Train_acc:88.5%, Train_loss:0.276, Test_acc:87.7%, Test_loss:0.290, Lr:1.00E-04
Epoch: 7, Train_acc:89.9%, Train_loss:0.248, Test_acc:88.7%, Test_loss:0.281, Lr:1.00E-04
Epoch: 8, Train_acc:89.7%, Train_loss:0.246, Test_acc:88.6%, Test_loss:0.279, Lr:1.00E-04
Epoch: 9, Train_acc:90.3%, Train_loss:0.233, Test_acc:89.5%, Test_loss:0.244, Lr:1.00E-04
Epoch:10, Train_acc:90.9%, Train_loss:0.227, Test_acc:88.7%, Test_loss:0.279, Lr:1.00E-04
Epoch:11, Train_acc:91.6%, Train_loss:0.205, Test_acc:85.4%, Test_loss:0.340, Lr:1.00E-04
Epoch:12, Train_acc:92.4%, Train_loss:0.185, Test_acc:86.2%, Test_loss:0.379, Lr:1.00E-04
Epoch:13, Train_acc:93.1%, Train_loss:0.177, Test_acc:87.1%, Test_loss:0.334, Lr:1.00E-04
Epoch:14, Train_acc:93.9%, Train_loss:0.152, Test_acc:88.0%, Test_loss:0.324, Lr:1.00E-04
Epoch:15, Train_acc:95.2%, Train_loss:0.120, Test_acc:88.0%, Test_loss:0.367, Lr:1.00E-05
Epoch:16, Train_acc:97.5%, Train_loss:0.069, Test_acc:89.8%, Test_loss:0.337, Lr:1.00E-05
Epoch:17, Train_acc:98.6%, Train_loss:0.045, Test_acc:89.7%, Test_loss:0.359, Lr:1.00E-05
Epoch:18, Train_acc:98.9%, Train_loss:0.033, Test_acc:89.1%, Test_loss:0.373, Lr:1.00E-05
Epoch:19, Train_acc:99.2%, Train_loss:0.024, Test_acc:89.2%, Test_loss:0.391, Lr:1.00E-05
Epoch:20, Train_acc:99.4%, Train_loss:0.022, Test_acc:89.8%, Test_loss:0.405, Lr:1.00E-05
Epoch:21, Train_acc:99.3%, Train_loss:0.022, Test_acc:89.0%, Test_loss:0.426, Lr:1.00E-06
Epoch:22, Train_acc:99.5%, Train_loss:0.016, Test_acc:89.1%, Test_loss:0.432, Lr:1.00E-06
Epoch:23, Train_acc:99.6%, Train_loss:0.016, Test_acc:89.1%, Test_loss:0.435, Lr:1.00E-06
Epoch:24, Train_acc:99.5%, Train_loss:0.017, Test_acc:89.1%, Test_loss:0.444, Lr:1.00E-06
Epoch:25, Train_acc:99.6%, Train_loss:0.014, Test_acc:89.1%, Test_loss:0.437, Lr:1.00E-06
Epoch:26, Train_acc:99.6%, Train_loss:0.014, Test_acc:89.0%, Test_loss:0.457, Lr:1.00E-06
Epoch:27, Train_acc:99.4%, Train_loss:0.024, Test_acc:89.3%, Test_loss:0.458, Lr:1.00E-07
Epoch:28, Train_acc:99.6%, Train_loss:0.016, Test_acc:89.1%, Test_loss:0.443, Lr:1.00E-07
Epoch:29, Train_acc:99.7%, Train_loss:0.013, Test_acc:89.2%, Test_loss:0.432, Lr:1.00E-07
Epoch:30, Train_acc:99.7%, Train_loss:0.012, Test_acc:89.3%, Test_loss:0.442, Lr:1.00E-07
Epoch:31, Train_acc:99.6%, Train_loss:0.019, Test_acc:89.3%, Test_loss:0.427, Lr:1.00E-07
Epoch:32, Train_acc:99.6%, Train_loss:0.016, Test_acc:89.6%, Test_loss:0.439, Lr:1.00E-07
Done
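
The deprecation warning at the top of the log shows that a learning-rate scheduler was created with verbose=True. Judging from the logged rates (1.00E-04 dropping tenfold several times), the setup plausibly resembled the sketch below; this is an inference, not the author's exact code:

# Hedged sketch: an optimizer/scheduler pairing consistent with the logged
# decay 1e-4 -> 1e-5 -> 1e-6 -> 1e-7. All hyperparameters are guesses.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=5, verbose=True)

# called once per epoch, after evaluating on the test set:
# scheduler.step(test_loss)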

Prediction accuracy is decent. Next we visualize the training history:
[Figure: training and test accuracy/loss curves over the 32 epochs]
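
The original figure is not reproduced here; a minimal plotting sketch follows, assuming per-epoch history lists train_acc, test_acc, train_loss, and test_loss were collected during training (the names are my assumptions):

# Hedged sketch: draw the accuracy/loss curves from assumed history lists.
import matplotlib.pyplot as plt

epochs = range(1, len(train_acc) + 1)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs, train_acc, label="Train_acc")
plt.plot(epochs, test_acc, label="Test_acc")
plt.title("Accuracy"); plt.legend()
plt.subplot(1, 2, 2)
plt.plot(epochs, train_loss, label="Train_loss")
plt.plot(epochs, test_loss, label="Test_loss")
plt.title("Loss"); plt.legend()
plt.show()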
Training-set accuracy keeps climbing and the training loss keeps falling, but test-set accuracy struggles to improve further. Let's also check accuracy on the validation set:

def validate(dataloader, model):
    model.eval()
    size = len(dataloader.dataset)

    validate_acc = 0

    with torch.no_grad():  # no gradients needed during evaluation
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)

            pred = model(x)

            validate_acc += (pred.argmax(1) == y).type(torch.float).sum().item()

    validate_acc /= size

    return validate_acc


# compute validation-set accuracy
validate_acc = validate(validate_dl, best_model)
print(f"Validation Accuracy: {validate_acc:.2%}")

Code output:

Validation Accuracy: 89.81%

Validation accuracy is close to the test accuracy.

