
Deep Learning Camp, Week J5: DenseNet + SE-Net in Practice

  • 🍨 This post is a learning-log entry from the 🔗365天深度学习训练营 (365-Day Deep Learning Training Camp)
  • 🍖 Original author: K同学啊

📌 Tasks for this week (with my own tweaks):
●1. Insert the SE-Net channel-attention mechanism into a DenseNet-family model and complete recognition on the breast-cancer dataset
●2. Can this improvement idea be transferred elsewhere?
●3. Compare validation-set accuracy

1. Foreword

SE-Net (Squeeze-and-Excitation Network) is the winning model of the ImageNet 2017 competition, proposed by the WMW team. The SE module adds little complexity, few parameters, and little computation, and its design is simple, so it combines easily with existing architectures (such as Inception and ResNet) to boost their performance.

Traditional architectures (such as Inception) mainly improve features along the spatial dimensions, whereas SE-Net focuses on the relationships between feature channels: it models the importance of each channel rather than only spatial positions. SE-Net learns a weight for every feature channel and uses it to automatically recalibrate the features, boosting the channels that are useful for the current task while suppressing the unimportant ones. This process is called "feature recalibration".

The SE module is the core component of SE-Net and implements this recalibration. It consists of two main steps: global average pooling first gathers global information for each channel, then fully connected layers learn each channel's importance, and finally the channel outputs are rescaled by these learned weights, as shown in the figure below:
[Figure: structure of the SE module]
For a given input x with C′ feature channels, whose channel count becomes C after a series of transformations, SE-Net performs the following operations:

  1. Squeeze: global average pooling compresses each channel's feature map into a single number, yielding a channel descriptor that summarizes global spatial information. This step can be seen as "squeezing" each channel down to its global statistics.
  2. Excitation: a small fully connected network, usually two layers, first reduces the dimension (to cut model complexity and parameter count) and then restores it. A Sigmoid at the end outputs a weight coefficient for every channel, "exciting" the channels according to their importance.
  3. Scale: finally, the Excitation output (the per-channel weights) is multiplied element-wise with the original input, rescaling the features. Reweighting the channel outputs this way strengthens the model's grip on useful features while suppressing unimportant ones (a shape-level sketch follows this list).
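
Here is a minimal shape-level sketch of the three steps (illustrative only: the random tensor below stands in for the learned Excitation weights):

import torch

x = torch.randn(2, 64, 32, 32)          # input feature map, shape (B, C, H, W)
z = x.mean(dim=(2, 3))                  # Squeeze: global average pooling -> (B, C)
s = torch.sigmoid(torch.randn(2, 64))   # Excitation: stand-in for sigmoid(W2 · relu(W1 · z))
out = x * s.view(2, 64, 1, 1)           # Scale: broadcast channel weights over H and W
print(out.shape)                        # torch.Size([2, 64, 32, 32])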

2. Generality of the SE Module

One advantage of the SE module is that it can be dropped directly into existing network architectures. Taking Inception and ResNet as examples, we only need to append an SE module after each Inception module or residual block:
[Figure: SE-Inception and SE-ResNet module structures]
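
As a concrete illustration, below is a minimal sketch of a ResNet-style residual block with an SE module inserted before the residual addition. It assumes the SELayer class implemented in Section 3; the block itself is my own sketch, not code from the original papers:

import torch
import torch.nn as nn

class SEBasicBlock(nn.Module):
    # Residual block with SE recalibration applied to the residual branch
    def __init__(self, channels, reduction=16):
        super(SEBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(channels)
        self.se    = SELayer(channels, reduction)  # SELayer as defined in Section 3
        self.relu  = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)  # recalibrate channels before the skip connection
        return self.relu(out + x)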

3. Code Implementation of the SE Module

import torch
import torch.nn as nn

class SELayer(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # Squeeze: global average pooling
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()  # Excitation: per-channel weights in (0, 1)
        )

    def forward(self, x):
        b, c, h, w = x.size()
        y = self.avg_pool(x).view(b, c)    # (B, C, 1, 1) -> (B, C)
        y = self.fc(y).view(b, c, 1, 1)    # learned channel weights
        return x * y.expand_as(x)          # Scale: reweight each channel
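
A quick illustrative check that the layer preserves the input shape:

layer = SELayer(channels=64)
x = torch.randn(8, 64, 28, 28)
print(layer(x).shape)  # torch.Size([8, 64, 28, 28])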

4. Inserting the SE Module into DenseNet

Here is the full network definition:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SELayer(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # Squeeze: global average pooling
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()  # Excitation: per-channel weights in (0, 1)
        )

    def forward(self, x):
        b, c, h, w = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)  # Scale: reweight each channel

# Bottleneck convolution block: BN-ReLU-1x1 conv, then BN-ReLU-3x3 conv
class ConvBlock(nn.Module):
    def __init__(self, in_channels, growth_rate):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, 4 * growth_rate, kernel_size=1, bias=False),
            nn.BatchNorm2d(4 * growth_rate),
            nn.ReLU(inplace=True),
            nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
        )
        
    def forward(self, x):
        x = self.conv1(x)
        return x

class DenseBlock(nn.Module):
    def __init__(self, num_layers, in_channels, growth_rate):
        super(DenseBlock, self).__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(ConvBlock(in_channels + i * growth_rate, growth_rate))
    
    def forward(self, x):
        for layer in self.layers:
            new_features = layer(x)
            # dense connectivity: concatenate new features onto all previous ones
            x = torch.cat([x, new_features], dim=1)
        return x

# Transition layer: 1x1 conv compresses channels, average pooling halves the spatial size
class TransitionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(TransitionBlock, self).__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
        
    def forward(self, x):
        x = self.bn(x)
        x = self.relu(x)
        x = self.conv(x)
        x = self.avg_pool(x)
        return x

# Build the DenseNet backbone (with an SE layer before the classifier)
class DenseNet(nn.Module):
    def __init__(self, block_config, num_classes=1000, growth_rate=32):
        super(DenseNet, self).__init__()
        num_init_features = 64
        self.features = nn.Sequential(
            nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(num_init_features),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        num_features = num_init_features
        for i, num_blocks in enumerate(block_config):
            block = DenseBlock(num_blocks, num_features, growth_rate)
            self.features.add_module('denseblock{}'.format(i + 1), block)
            num_features += num_blocks * growth_rate
            if i != len(block_config) - 1:
                trans = TransitionBlock(num_features, num_features // 2)
                self.features.add_module('transition{}'.format(i + 1), trans)
                num_features = num_features // 2
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))
        self.classifier = nn.Linear(num_features, num_classes)
        self.SE_layer = SELayer(num_features)  # SE Layer should be initialized with the correct number of channels

    def forward(self, x):
        x = self.features(x)
        x = self.SE_layer(x)
        x = F.relu(x, inplace=True)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    
# Standard DenseNet depth configurations
def DenseNet121(num_classes=3):
    return DenseNet([6, 12, 24, 16], num_classes=num_classes)

def DenseNet169(num_classes=3):
    return DenseNet([6, 12, 32, 32], num_classes=num_classes)

def DenseNet201(num_classes=3):
    return DenseNet([6, 12, 48, 32], num_classes=num_classes)

import torchsummary

model1 = DenseNet121().cuda()
torchsummary.summary(model1, (3, 224, 224))
model2 = DenseNet169().cuda()
model3 = DenseNet201().cuda()
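
If no GPU is available, the models can also be built device-agnostically; as far as I know torchsummary.summary accepts a device string, so a hedged variant is:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model1 = DenseNet121().to(device)
torchsummary.summary(model1, (3, 224, 224), device=str(device))

The device handle defined here is also what the evaluation code further below assumes.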

Code output:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
       BatchNorm2d-5           [-1, 64, 56, 56]             128
              ReLU-6           [-1, 64, 56, 56]               0
            Conv2d-7          [-1, 128, 56, 56]           8,192
       BatchNorm2d-8          [-1, 128, 56, 56]             256
              ReLU-9          [-1, 128, 56, 56]               0
           Conv2d-10           [-1, 32, 56, 56]          36,864
        ConvBlock-11           [-1, 32, 56, 56]               0
      BatchNorm2d-12           [-1, 96, 56, 56]             192
             ReLU-13           [-1, 96, 56, 56]               0
           Conv2d-14          [-1, 128, 56, 56]          12,288
      BatchNorm2d-15          [-1, 128, 56, 56]             256
             ReLU-16          [-1, 128, 56, 56]               0
           Conv2d-17           [-1, 32, 56, 56]          36,864
        ConvBlock-18           [-1, 32, 56, 56]               0
      BatchNorm2d-19          [-1, 128, 56, 56]             256
             ReLU-20          [-1, 128, 56, 56]               0
           Conv2d-21          [-1, 128, 56, 56]          16,384
      BatchNorm2d-22          [-1, 128, 56, 56]             256
             ReLU-23          [-1, 128, 56, 56]               0
           Conv2d-24           [-1, 32, 56, 56]          36,864
        ConvBlock-25           [-1, 32, 56, 56]               0
      BatchNorm2d-26          [-1, 160, 56, 56]             320
             ReLU-27          [-1, 160, 56, 56]               0
           Conv2d-28          [-1, 128, 56, 56]          20,480
      BatchNorm2d-29          [-1, 128, 56, 56]             256
             ReLU-30          [-1, 128, 56, 56]               0
           Conv2d-31           [-1, 32, 56, 56]          36,864
        ConvBlock-32           [-1, 32, 56, 56]               0
      BatchNorm2d-33          [-1, 192, 56, 56]             384
             ReLU-34          [-1, 192, 56, 56]               0
           Conv2d-35          [-1, 128, 56, 56]          24,576
      BatchNorm2d-36          [-1, 128, 56, 56]             256
             ReLU-37          [-1, 128, 56, 56]               0
           Conv2d-38           [-1, 32, 56, 56]          36,864
        ConvBlock-39           [-1, 32, 56, 56]               0
      BatchNorm2d-40          [-1, 224, 56, 56]             448
             ReLU-41          [-1, 224, 56, 56]               0
           Conv2d-42          [-1, 128, 56, 56]          28,672
      BatchNorm2d-43          [-1, 128, 56, 56]             256
             ReLU-44          [-1, 128, 56, 56]               0
           Conv2d-45           [-1, 32, 56, 56]          36,864
        ConvBlock-46           [-1, 32, 56, 56]               0
       DenseBlock-47          [-1, 256, 56, 56]               0
      BatchNorm2d-48          [-1, 256, 56, 56]             512
             ReLU-49          [-1, 256, 56, 56]               0
           Conv2d-50          [-1, 128, 56, 56]          32,768
        AvgPool2d-51          [-1, 128, 28, 28]               0
  TransitionBlock-52          [-1, 128, 28, 28]               0
      BatchNorm2d-53          [-1, 128, 28, 28]             256
             ReLU-54          [-1, 128, 28, 28]               0
           Conv2d-55          [-1, 128, 28, 28]          16,384
      BatchNorm2d-56          [-1, 128, 28, 28]             256
             ReLU-57          [-1, 128, 28, 28]               0
           Conv2d-58           [-1, 32, 28, 28]          36,864
        ConvBlock-59           [-1, 32, 28, 28]               0
      BatchNorm2d-60          [-1, 160, 28, 28]             320
             ReLU-61          [-1, 160, 28, 28]               0
           Conv2d-62          [-1, 128, 28, 28]          20,480
      BatchNorm2d-63          [-1, 128, 28, 28]             256
             ReLU-64          [-1, 128, 28, 28]               0
           Conv2d-65           [-1, 32, 28, 28]          36,864
        ConvBlock-66           [-1, 32, 28, 28]               0
      BatchNorm2d-67          [-1, 192, 28, 28]             384
             ReLU-68          [-1, 192, 28, 28]               0
           Conv2d-69          [-1, 128, 28, 28]          24,576
      BatchNorm2d-70          [-1, 128, 28, 28]             256
             ReLU-71          [-1, 128, 28, 28]               0
           Conv2d-72           [-1, 32, 28, 28]          36,864
        ConvBlock-73           [-1, 32, 28, 28]               0
      BatchNorm2d-74          [-1, 224, 28, 28]             448
             ReLU-75          [-1, 224, 28, 28]               0
           Conv2d-76          [-1, 128, 28, 28]          28,672
      BatchNorm2d-77          [-1, 128, 28, 28]             256
             ReLU-78          [-1, 128, 28, 28]               0
           Conv2d-79           [-1, 32, 28, 28]          36,864
        ConvBlock-80           [-1, 32, 28, 28]               0
      BatchNorm2d-81          [-1, 256, 28, 28]             512
             ReLU-82          [-1, 256, 28, 28]               0
           Conv2d-83          [-1, 128, 28, 28]          32,768
      BatchNorm2d-84          [-1, 128, 28, 28]             256
             ReLU-85          [-1, 128, 28, 28]               0
           Conv2d-86           [-1, 32, 28, 28]          36,864
        ConvBlock-87           [-1, 32, 28, 28]               0
      BatchNorm2d-88          [-1, 288, 28, 28]             576
             ReLU-89          [-1, 288, 28, 28]               0
           Conv2d-90          [-1, 128, 28, 28]          36,864
      BatchNorm2d-91          [-1, 128, 28, 28]             256
             ReLU-92          [-1, 128, 28, 28]               0
           Conv2d-93           [-1, 32, 28, 28]          36,864
        ConvBlock-94           [-1, 32, 28, 28]               0
      BatchNorm2d-95          [-1, 320, 28, 28]             640
             ReLU-96          [-1, 320, 28, 28]               0
           Conv2d-97          [-1, 128, 28, 28]          40,960
      BatchNorm2d-98          [-1, 128, 28, 28]             256
             ReLU-99          [-1, 128, 28, 28]               0
          Conv2d-100           [-1, 32, 28, 28]          36,864
       ConvBlock-101           [-1, 32, 28, 28]               0
     BatchNorm2d-102          [-1, 352, 28, 28]             704
            ReLU-103          [-1, 352, 28, 28]               0
          Conv2d-104          [-1, 128, 28, 28]          45,056
     BatchNorm2d-105          [-1, 128, 28, 28]             256
            ReLU-106          [-1, 128, 28, 28]               0
          Conv2d-107           [-1, 32, 28, 28]          36,864
       ConvBlock-108           [-1, 32, 28, 28]               0
     BatchNorm2d-109          [-1, 384, 28, 28]             768
            ReLU-110          [-1, 384, 28, 28]               0
          Conv2d-111          [-1, 128, 28, 28]          49,152
     BatchNorm2d-112          [-1, 128, 28, 28]             256
            ReLU-113          [-1, 128, 28, 28]               0
          Conv2d-114           [-1, 32, 28, 28]          36,864
       ConvBlock-115           [-1, 32, 28, 28]               0
     BatchNorm2d-116          [-1, 416, 28, 28]             832
            ReLU-117          [-1, 416, 28, 28]               0
          Conv2d-118          [-1, 128, 28, 28]          53,248
     BatchNorm2d-119          [-1, 128, 28, 28]             256
            ReLU-120          [-1, 128, 28, 28]               0
          Conv2d-121           [-1, 32, 28, 28]          36,864
       ConvBlock-122           [-1, 32, 28, 28]               0
     BatchNorm2d-123          [-1, 448, 28, 28]             896
            ReLU-124          [-1, 448, 28, 28]               0
          Conv2d-125          [-1, 128, 28, 28]          57,344
     BatchNorm2d-126          [-1, 128, 28, 28]             256
            ReLU-127          [-1, 128, 28, 28]               0
          Conv2d-128           [-1, 32, 28, 28]          36,864
       ConvBlock-129           [-1, 32, 28, 28]               0
     BatchNorm2d-130          [-1, 480, 28, 28]             960
            ReLU-131          [-1, 480, 28, 28]               0
          Conv2d-132          [-1, 128, 28, 28]          61,440
     BatchNorm2d-133          [-1, 128, 28, 28]             256
            ReLU-134          [-1, 128, 28, 28]               0
          Conv2d-135           [-1, 32, 28, 28]          36,864
       ConvBlock-136           [-1, 32, 28, 28]               0
      DenseBlock-137          [-1, 512, 28, 28]               0
     BatchNorm2d-138          [-1, 512, 28, 28]           1,024
            ReLU-139          [-1, 512, 28, 28]               0
          Conv2d-140          [-1, 256, 28, 28]         131,072
       AvgPool2d-141          [-1, 256, 14, 14]               0
 TransitionBlock-142          [-1, 256, 14, 14]               0
     BatchNorm2d-143          [-1, 256, 14, 14]             512
            ReLU-144          [-1, 256, 14, 14]               0
          Conv2d-145          [-1, 128, 14, 14]          32,768
     BatchNorm2d-146          [-1, 128, 14, 14]             256
            ReLU-147          [-1, 128, 14, 14]               0
          Conv2d-148           [-1, 32, 14, 14]          36,864
       ConvBlock-149           [-1, 32, 14, 14]               0
     BatchNorm2d-150          [-1, 288, 14, 14]             576
            ReLU-151          [-1, 288, 14, 14]               0
          Conv2d-152          [-1, 128, 14, 14]          36,864
     BatchNorm2d-153          [-1, 128, 14, 14]             256
            ReLU-154          [-1, 128, 14, 14]               0
          Conv2d-155           [-1, 32, 14, 14]          36,864
       ConvBlock-156           [-1, 32, 14, 14]               0
     BatchNorm2d-157          [-1, 320, 14, 14]             640
            ReLU-158          [-1, 320, 14, 14]               0
          Conv2d-159          [-1, 128, 14, 14]          40,960
     BatchNorm2d-160          [-1, 128, 14, 14]             256
            ReLU-161          [-1, 128, 14, 14]               0
          Conv2d-162           [-1, 32, 14, 14]          36,864
       ConvBlock-163           [-1, 32, 14, 14]               0
     BatchNorm2d-164          [-1, 352, 14, 14]             704
            ReLU-165          [-1, 352, 14, 14]               0
          Conv2d-166          [-1, 128, 14, 14]          45,056
     BatchNorm2d-167          [-1, 128, 14, 14]             256
            ReLU-168          [-1, 128, 14, 14]               0
          Conv2d-169           [-1, 32, 14, 14]          36,864
       ConvBlock-170           [-1, 32, 14, 14]               0
     BatchNorm2d-171          [-1, 384, 14, 14]             768
            ReLU-172          [-1, 384, 14, 14]               0
          Conv2d-173          [-1, 128, 14, 14]          49,152
     BatchNorm2d-174          [-1, 128, 14, 14]             256
            ReLU-175          [-1, 128, 14, 14]               0
          Conv2d-176           [-1, 32, 14, 14]          36,864
       ConvBlock-177           [-1, 32, 14, 14]               0
     BatchNorm2d-178          [-1, 416, 14, 14]             832
            ReLU-179          [-1, 416, 14, 14]               0
          Conv2d-180          [-1, 128, 14, 14]          53,248
     BatchNorm2d-181          [-1, 128, 14, 14]             256
            ReLU-182          [-1, 128, 14, 14]               0
          Conv2d-183           [-1, 32, 14, 14]          36,864
       ConvBlock-184           [-1, 32, 14, 14]               0
     BatchNorm2d-185          [-1, 448, 14, 14]             896
            ReLU-186          [-1, 448, 14, 14]               0
          Conv2d-187          [-1, 128, 14, 14]          57,344
     BatchNorm2d-188          [-1, 128, 14, 14]             256
            ReLU-189          [-1, 128, 14, 14]               0
          Conv2d-190           [-1, 32, 14, 14]          36,864
       ConvBlock-191           [-1, 32, 14, 14]               0
     BatchNorm2d-192          [-1, 480, 14, 14]             960
            ReLU-193          [-1, 480, 14, 14]               0
          Conv2d-194          [-1, 128, 14, 14]          61,440
     BatchNorm2d-195          [-1, 128, 14, 14]             256
            ReLU-196          [-1, 128, 14, 14]               0
          Conv2d-197           [-1, 32, 14, 14]          36,864
       ConvBlock-198           [-1, 32, 14, 14]               0
     BatchNorm2d-199          [-1, 512, 14, 14]           1,024
            ReLU-200          [-1, 512, 14, 14]               0
          Conv2d-201          [-1, 128, 14, 14]          65,536
     BatchNorm2d-202          [-1, 128, 14, 14]             256
            ReLU-203          [-1, 128, 14, 14]               0
          Conv2d-204           [-1, 32, 14, 14]          36,864
       ConvBlock-205           [-1, 32, 14, 14]               0
     BatchNorm2d-206          [-1, 544, 14, 14]           1,088
            ReLU-207          [-1, 544, 14, 14]               0
          Conv2d-208          [-1, 128, 14, 14]          69,632
     BatchNorm2d-209          [-1, 128, 14, 14]             256
            ReLU-210          [-1, 128, 14, 14]               0
          Conv2d-211           [-1, 32, 14, 14]          36,864
       ConvBlock-212           [-1, 32, 14, 14]               0
     BatchNorm2d-213          [-1, 576, 14, 14]           1,152
            ReLU-214          [-1, 576, 14, 14]               0
          Conv2d-215          [-1, 128, 14, 14]          73,728
     BatchNorm2d-216          [-1, 128, 14, 14]             256
            ReLU-217          [-1, 128, 14, 14]               0
          Conv2d-218           [-1, 32, 14, 14]          36,864
       ConvBlock-219           [-1, 32, 14, 14]               0
     BatchNorm2d-220          [-1, 608, 14, 14]           1,216
            ReLU-221          [-1, 608, 14, 14]               0
          Conv2d-222          [-1, 128, 14, 14]          77,824
     BatchNorm2d-223          [-1, 128, 14, 14]             256
            ReLU-224          [-1, 128, 14, 14]               0
          Conv2d-225           [-1, 32, 14, 14]          36,864
       ConvBlock-226           [-1, 32, 14, 14]               0
     BatchNorm2d-227          [-1, 640, 14, 14]           1,280
            ReLU-228          [-1, 640, 14, 14]               0
          Conv2d-229          [-1, 128, 14, 14]          81,920
     BatchNorm2d-230          [-1, 128, 14, 14]             256
            ReLU-231          [-1, 128, 14, 14]               0
          Conv2d-232           [-1, 32, 14, 14]          36,864
       ConvBlock-233           [-1, 32, 14, 14]               0
     BatchNorm2d-234          [-1, 672, 14, 14]           1,344
            ReLU-235          [-1, 672, 14, 14]               0
          Conv2d-236          [-1, 128, 14, 14]          86,016
     BatchNorm2d-237          [-1, 128, 14, 14]             256
            ReLU-238          [-1, 128, 14, 14]               0
          Conv2d-239           [-1, 32, 14, 14]          36,864
       ConvBlock-240           [-1, 32, 14, 14]               0
     BatchNorm2d-241          [-1, 704, 14, 14]           1,408
            ReLU-242          [-1, 704, 14, 14]               0
          Conv2d-243          [-1, 128, 14, 14]          90,112
     BatchNorm2d-244          [-1, 128, 14, 14]             256
            ReLU-245          [-1, 128, 14, 14]               0
          Conv2d-246           [-1, 32, 14, 14]          36,864
       ConvBlock-247           [-1, 32, 14, 14]               0
     BatchNorm2d-248          [-1, 736, 14, 14]           1,472
            ReLU-249          [-1, 736, 14, 14]               0
          Conv2d-250          [-1, 128, 14, 14]          94,208
     BatchNorm2d-251          [-1, 128, 14, 14]             256
            ReLU-252          [-1, 128, 14, 14]               0
          Conv2d-253           [-1, 32, 14, 14]          36,864
       ConvBlock-254           [-1, 32, 14, 14]               0
     BatchNorm2d-255          [-1, 768, 14, 14]           1,536
            ReLU-256          [-1, 768, 14, 14]               0
          Conv2d-257          [-1, 128, 14, 14]          98,304
     BatchNorm2d-258          [-1, 128, 14, 14]             256
            ReLU-259          [-1, 128, 14, 14]               0
          Conv2d-260           [-1, 32, 14, 14]          36,864
       ConvBlock-261           [-1, 32, 14, 14]               0
     BatchNorm2d-262          [-1, 800, 14, 14]           1,600
            ReLU-263          [-1, 800, 14, 14]               0
          Conv2d-264          [-1, 128, 14, 14]         102,400
     BatchNorm2d-265          [-1, 128, 14, 14]             256
            ReLU-266          [-1, 128, 14, 14]               0
          Conv2d-267           [-1, 32, 14, 14]          36,864
       ConvBlock-268           [-1, 32, 14, 14]               0
     BatchNorm2d-269          [-1, 832, 14, 14]           1,664
            ReLU-270          [-1, 832, 14, 14]               0
          Conv2d-271          [-1, 128, 14, 14]         106,496
     BatchNorm2d-272          [-1, 128, 14, 14]             256
            ReLU-273          [-1, 128, 14, 14]               0
          Conv2d-274           [-1, 32, 14, 14]          36,864
       ConvBlock-275           [-1, 32, 14, 14]               0
     BatchNorm2d-276          [-1, 864, 14, 14]           1,728
            ReLU-277          [-1, 864, 14, 14]               0
          Conv2d-278          [-1, 128, 14, 14]         110,592
     BatchNorm2d-279          [-1, 128, 14, 14]             256
            ReLU-280          [-1, 128, 14, 14]               0
          Conv2d-281           [-1, 32, 14, 14]          36,864
       ConvBlock-282           [-1, 32, 14, 14]               0
     BatchNorm2d-283          [-1, 896, 14, 14]           1,792
            ReLU-284          [-1, 896, 14, 14]               0
          Conv2d-285          [-1, 128, 14, 14]         114,688
     BatchNorm2d-286          [-1, 128, 14, 14]             256
            ReLU-287          [-1, 128, 14, 14]               0
          Conv2d-288           [-1, 32, 14, 14]          36,864
       ConvBlock-289           [-1, 32, 14, 14]               0
     BatchNorm2d-290          [-1, 928, 14, 14]           1,856
            ReLU-291          [-1, 928, 14, 14]               0
          Conv2d-292          [-1, 128, 14, 14]         118,784
     BatchNorm2d-293          [-1, 128, 14, 14]             256
            ReLU-294          [-1, 128, 14, 14]               0
          Conv2d-295           [-1, 32, 14, 14]          36,864
       ConvBlock-296           [-1, 32, 14, 14]               0
     BatchNorm2d-297          [-1, 960, 14, 14]           1,920
            ReLU-298          [-1, 960, 14, 14]               0
          Conv2d-299          [-1, 128, 14, 14]         122,880
     BatchNorm2d-300          [-1, 128, 14, 14]             256
            ReLU-301          [-1, 128, 14, 14]               0
          Conv2d-302           [-1, 32, 14, 14]          36,864
       ConvBlock-303           [-1, 32, 14, 14]               0
     BatchNorm2d-304          [-1, 992, 14, 14]           1,984
            ReLU-305          [-1, 992, 14, 14]               0
          Conv2d-306          [-1, 128, 14, 14]         126,976
     BatchNorm2d-307          [-1, 128, 14, 14]             256
            ReLU-308          [-1, 128, 14, 14]               0
          Conv2d-309           [-1, 32, 14, 14]          36,864
       ConvBlock-310           [-1, 32, 14, 14]               0
      DenseBlock-311         [-1, 1024, 14, 14]               0
     BatchNorm2d-312         [-1, 1024, 14, 14]           2,048
            ReLU-313         [-1, 1024, 14, 14]               0
          Conv2d-314          [-1, 512, 14, 14]         524,288
       AvgPool2d-315            [-1, 512, 7, 7]               0
 TransitionBlock-316            [-1, 512, 7, 7]               0
     BatchNorm2d-317            [-1, 512, 7, 7]           1,024
            ReLU-318            [-1, 512, 7, 7]               0
          Conv2d-319            [-1, 128, 7, 7]          65,536
     BatchNorm2d-320            [-1, 128, 7, 7]             256
            ReLU-321            [-1, 128, 7, 7]               0
          Conv2d-322             [-1, 32, 7, 7]          36,864
       ConvBlock-323             [-1, 32, 7, 7]               0
     BatchNorm2d-324            [-1, 544, 7, 7]           1,088
            ReLU-325            [-1, 544, 7, 7]               0
          Conv2d-326            [-1, 128, 7, 7]          69,632
     BatchNorm2d-327            [-1, 128, 7, 7]             256
            ReLU-328            [-1, 128, 7, 7]               0
          Conv2d-329             [-1, 32, 7, 7]          36,864
       ConvBlock-330             [-1, 32, 7, 7]               0
     BatchNorm2d-331            [-1, 576, 7, 7]           1,152
            ReLU-332            [-1, 576, 7, 7]               0
          Conv2d-333            [-1, 128, 7, 7]          73,728
     BatchNorm2d-334            [-1, 128, 7, 7]             256
            ReLU-335            [-1, 128, 7, 7]               0
          Conv2d-336             [-1, 32, 7, 7]          36,864
       ConvBlock-337             [-1, 32, 7, 7]               0
     BatchNorm2d-338            [-1, 608, 7, 7]           1,216
            ReLU-339            [-1, 608, 7, 7]               0
          Conv2d-340            [-1, 128, 7, 7]          77,824
     BatchNorm2d-341            [-1, 128, 7, 7]             256
            ReLU-342            [-1, 128, 7, 7]               0
          Conv2d-343             [-1, 32, 7, 7]          36,864
       ConvBlock-344             [-1, 32, 7, 7]               0
     BatchNorm2d-345            [-1, 640, 7, 7]           1,280
            ReLU-346            [-1, 640, 7, 7]               0
          Conv2d-347            [-1, 128, 7, 7]          81,920
     BatchNorm2d-348            [-1, 128, 7, 7]             256
            ReLU-349            [-1, 128, 7, 7]               0
          Conv2d-350             [-1, 32, 7, 7]          36,864
       ConvBlock-351             [-1, 32, 7, 7]               0
     BatchNorm2d-352            [-1, 672, 7, 7]           1,344
            ReLU-353            [-1, 672, 7, 7]               0
          Conv2d-354            [-1, 128, 7, 7]          86,016
     BatchNorm2d-355            [-1, 128, 7, 7]             256
            ReLU-356            [-1, 128, 7, 7]               0
          Conv2d-357             [-1, 32, 7, 7]          36,864
       ConvBlock-358             [-1, 32, 7, 7]               0
     BatchNorm2d-359            [-1, 704, 7, 7]           1,408
            ReLU-360            [-1, 704, 7, 7]               0
          Conv2d-361            [-1, 128, 7, 7]          90,112
     BatchNorm2d-362            [-1, 128, 7, 7]             256
            ReLU-363            [-1, 128, 7, 7]               0
          Conv2d-364             [-1, 32, 7, 7]          36,864
       ConvBlock-365             [-1, 32, 7, 7]               0
     BatchNorm2d-366            [-1, 736, 7, 7]           1,472
            ReLU-367            [-1, 736, 7, 7]               0
          Conv2d-368            [-1, 128, 7, 7]          94,208
     BatchNorm2d-369            [-1, 128, 7, 7]             256
            ReLU-370            [-1, 128, 7, 7]               0
          Conv2d-371             [-1, 32, 7, 7]          36,864
       ConvBlock-372             [-1, 32, 7, 7]               0
     BatchNorm2d-373            [-1, 768, 7, 7]           1,536
            ReLU-374            [-1, 768, 7, 7]               0
          Conv2d-375            [-1, 128, 7, 7]          98,304
     BatchNorm2d-376            [-1, 128, 7, 7]             256
            ReLU-377            [-1, 128, 7, 7]               0
          Conv2d-378             [-1, 32, 7, 7]          36,864
       ConvBlock-379             [-1, 32, 7, 7]               0
     BatchNorm2d-380            [-1, 800, 7, 7]           1,600
            ReLU-381            [-1, 800, 7, 7]               0
          Conv2d-382            [-1, 128, 7, 7]         102,400
     BatchNorm2d-383            [-1, 128, 7, 7]             256
            ReLU-384            [-1, 128, 7, 7]               0
          Conv2d-385             [-1, 32, 7, 7]          36,864
       ConvBlock-386             [-1, 32, 7, 7]               0
     BatchNorm2d-387            [-1, 832, 7, 7]           1,664
            ReLU-388            [-1, 832, 7, 7]               0
          Conv2d-389            [-1, 128, 7, 7]         106,496
     BatchNorm2d-390            [-1, 128, 7, 7]             256
            ReLU-391            [-1, 128, 7, 7]               0
          Conv2d-392             [-1, 32, 7, 7]          36,864
       ConvBlock-393             [-1, 32, 7, 7]               0
     BatchNorm2d-394            [-1, 864, 7, 7]           1,728
            ReLU-395            [-1, 864, 7, 7]               0
          Conv2d-396            [-1, 128, 7, 7]         110,592
     BatchNorm2d-397            [-1, 128, 7, 7]             256
            ReLU-398            [-1, 128, 7, 7]               0
          Conv2d-399             [-1, 32, 7, 7]          36,864
       ConvBlock-400             [-1, 32, 7, 7]               0
     BatchNorm2d-401            [-1, 896, 7, 7]           1,792
            ReLU-402            [-1, 896, 7, 7]               0
          Conv2d-403            [-1, 128, 7, 7]         114,688
     BatchNorm2d-404            [-1, 128, 7, 7]             256
            ReLU-405            [-1, 128, 7, 7]               0
          Conv2d-406             [-1, 32, 7, 7]          36,864
       ConvBlock-407             [-1, 32, 7, 7]               0
     BatchNorm2d-408            [-1, 928, 7, 7]           1,856
            ReLU-409            [-1, 928, 7, 7]               0
          Conv2d-410            [-1, 128, 7, 7]         118,784
     BatchNorm2d-411            [-1, 128, 7, 7]             256
            ReLU-412            [-1, 128, 7, 7]               0
          Conv2d-413             [-1, 32, 7, 7]          36,864
       ConvBlock-414             [-1, 32, 7, 7]               0
     BatchNorm2d-415            [-1, 960, 7, 7]           1,920
            ReLU-416            [-1, 960, 7, 7]               0
          Conv2d-417            [-1, 128, 7, 7]         122,880
     BatchNorm2d-418            [-1, 128, 7, 7]             256
            ReLU-419            [-1, 128, 7, 7]               0
          Conv2d-420             [-1, 32, 7, 7]          36,864
       ConvBlock-421             [-1, 32, 7, 7]               0
     BatchNorm2d-422            [-1, 992, 7, 7]           1,984
            ReLU-423            [-1, 992, 7, 7]               0
          Conv2d-424            [-1, 128, 7, 7]         126,976
     BatchNorm2d-425            [-1, 128, 7, 7]             256
            ReLU-426            [-1, 128, 7, 7]               0
          Conv2d-427             [-1, 32, 7, 7]          36,864
       ConvBlock-428             [-1, 32, 7, 7]               0
      DenseBlock-429           [-1, 1024, 7, 7]               0
     BatchNorm2d-430           [-1, 1024, 7, 7]           2,048
AdaptiveAvgPool2d-431           [-1, 1024, 1, 1]               0
          Linear-432                   [-1, 64]          65,536
            ReLU-433                   [-1, 64]               0
          Linear-434                 [-1, 1024]          65,536
         Sigmoid-435                 [-1, 1024]               0
         SELayer-436           [-1, 1024, 7, 7]               0
          Linear-437                    [-1, 3]           3,075
================================================================
Total params: 7,088,003
Trainable params: 7,088,003
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 315.27
Params size (MB): 27.04
Estimated Total Size (MB): 342.88
----------------------------------------------------------------
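
As a sanity check on SELayer(num_features), the final channel count can be recomputed from the block configuration. The arithmetic below also explains why the SE layer's two Linear layers in the summary are 1024 -> 64 -> 1024 (reduction = 16):

# Recompute DenseNet121's final channel count
growth_rate, num_features = 32, 64
for i, num_blocks in enumerate([6, 12, 24, 16]):
    num_features += num_blocks * growth_rate  # each ConvBlock adds growth_rate channels
    if i != 3:                                # every transition halves the channels
        num_features //= 2
print(num_features)                           # 1024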

5. Breast-Cancer Recognition and Comparing the Networks

The data preprocessing this time is identical to previous weeks, so that code is not repeated here; we go straight to the training part.
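
The loops below call train/test helpers and the train_dl/test_dl loaders defined in earlier weeks. Since those helpers are not shown in this post, here is a minimal sketch consistent with the call signatures used below; treat the details as assumptions:

def train(dataloader, model, loss_fn, optimizer):
    size, num_batches = len(dataloader.dataset), len(dataloader)
    acc, loss_sum = 0.0, 0.0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        loss_sum += loss.item()
    return acc / size, loss_sum / num_batches

def test(dataloader, model, loss_fn):
    size, num_batches = len(dataloader.dataset), len(dataloader)
    acc, loss_sum = 0.0, 0.0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            acc += (pred.argmax(1) == y).type(torch.float).sum().item()
            loss_sum += loss_fn(pred, y).item()
    return acc / size, loss_sum / num_batches

With those helpers assumed, the training code is: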

import copy
from torch.optim.lr_scheduler import ReduceLROnPlateau

opt1 = torch.optim.Adam(model1.parameters(), lr= 1e-4)
scheduler1 = ReduceLROnPlateau(opt1, mode='min', factor=0.1, patience=5, verbose=True) # multiply the LR by 0.1 when the monitored loss fails to improve for 5 consecutive epochs

opt2 = torch.optim.Adam(model2.parameters(), lr= 1e-4)
scheduler2 = ReduceLROnPlateau(opt2, mode='min', factor=0.1, patience=5, verbose=True)

opt3 = torch.optim.Adam(model3.parameters(), lr= 1e-4)
scheduler3 = ReduceLROnPlateau(opt3, mode='min', factor=0.1, patience=5, verbose=True)

loss_fn = nn.CrossEntropyLoss() # cross-entropy loss

epochs = 32

train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []
# Note: these lists accumulate metrics for all three models back to back

best_acc1 = 0    # best test accuracy so far, used to pick the best checkpoint
best_acc2 = 0
best_acc3 = 0

for epoch in range(epochs):
    model1.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model1, loss_fn, opt1)
    
    model1.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model1, loss_fn)
    scheduler1.step(epoch_test_loss)
    
    if epoch_test_acc > best_acc1:
        best_acc1 = epoch_test_acc
        best_model1 = copy.deepcopy(model1)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

     # read the current learning rate
    lr = opt1.state_dict()['param_groups'][0]['lr']
    
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, 
                          epoch_test_acc*100, epoch_test_loss, lr))
    
# Save the best model's weights to disk
PATH = './best_model1.pth'  # checkpoint file name
torch.save(best_model1.state_dict(), PATH)

print('Done1')

for epoch in range(epochs):
    model2.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model2, loss_fn, opt2)
    
    model2.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model2, loss_fn)
    scheduler2.step(epoch_test_loss)
    
    if epoch_test_acc > best_acc2:
        best_acc2 = epoch_test_acc
        best_model2 = copy.deepcopy(model2)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

     # read the current learning rate
    lr = opt2.state_dict()['param_groups'][0]['lr']
    
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, 
                          epoch_test_acc*100, epoch_test_loss, lr))
    
# Save the best model's weights to disk
PATH = './best_model2.pth'  # checkpoint file name
torch.save(best_model2.state_dict(), PATH)

print('Done2')

for epoch in range(epochs):
    model3.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model3, loss_fn, opt3)
    
    model3.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model3, loss_fn)
    scheduler3.step(epoch_test_loss)
    
    if epoch_test_acc > best_acc3:
        best_acc3 = epoch_test_acc
        best_model3 = copy.deepcopy(model3)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

     # read the current learning rate
    lr = opt3.state_dict()['param_groups'][0]['lr']
    
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, 
                          epoch_test_acc*100, epoch_test_loss, lr))
    
# Save the best model's weights to disk
PATH = './best_model3.pth'  # checkpoint file name
torch.save(best_model3.state_dict(), PATH)

print('Done3')

Code output:

Epoch: 1, Train_acc:94.9%, Train_loss:0.132, Test_acc:94.2%, Test_loss:0.153, Lr:1.00E-04
Epoch: 2, Train_acc:95.5%, Train_loss:0.113, Test_acc:95.1%, Test_loss:0.137, Lr:1.00E-04
Epoch: 3, Train_acc:96.4%, Train_loss:0.098, Test_acc:94.4%, Test_loss:0.145, Lr:1.00E-04
Epoch: 4, Train_acc:96.4%, Train_loss:0.101, Test_acc:91.0%, Test_loss:0.317, Lr:1.00E-04
Epoch: 5, Train_acc:96.7%, Train_loss:0.092, Test_acc:93.6%, Test_loss:0.193, Lr:1.00E-04
Epoch: 6, Train_acc:96.9%, Train_loss:0.079, Test_acc:91.9%, Test_loss:0.217, Lr:1.00E-04
Epoch: 7, Train_acc:97.5%, Train_loss:0.065, Test_acc:95.1%, Test_loss:0.153, Lr:1.00E-04
Epoch: 8, Train_acc:97.7%, Train_loss:0.062, Test_acc:93.9%, Test_loss:0.154, Lr:1.00E-05
Epoch: 9, Train_acc:99.1%, Train_loss:0.031, Test_acc:96.2%, Test_loss:0.115, Lr:1.00E-05
Epoch:10, Train_acc:99.6%, Train_loss:0.016, Test_acc:96.4%, Test_loss:0.112, Lr:1.00E-05
Epoch:11, Train_acc:99.7%, Train_loss:0.015, Test_acc:96.3%, Test_loss:0.116, Lr:1.00E-05
Epoch:12, Train_acc:99.7%, Train_loss:0.012, Test_acc:96.5%, Test_loss:0.115, Lr:1.00E-05
Epoch:13, Train_acc:99.9%, Train_loss:0.008, Test_acc:96.5%, Test_loss:0.123, Lr:1.00E-05
Epoch:14, Train_acc:99.8%, Train_loss:0.013, Test_acc:96.3%, Test_loss:0.126, Lr:1.00E-05
Epoch:15, Train_acc:99.8%, Train_loss:0.008, Test_acc:96.2%, Test_loss:0.125, Lr:1.00E-05
Epoch:16, Train_acc:99.9%, Train_loss:0.006, Test_acc:96.3%, Test_loss:0.134, Lr:1.00E-06
Epoch:17, Train_acc:99.9%, Train_loss:0.006, Test_acc:96.4%, Test_loss:0.126, Lr:1.00E-06
Epoch:18, Train_acc:99.9%, Train_loss:0.006, Test_acc:96.5%, Test_loss:0.126, Lr:1.00E-06
Epoch:19, Train_acc:99.9%, Train_loss:0.005, Test_acc:96.5%, Test_loss:0.125, Lr:1.00E-06
Epoch:20, Train_acc:99.9%, Train_loss:0.005, Test_acc:96.4%, Test_loss:0.125, Lr:1.00E-06
Epoch:21, Train_acc:99.9%, Train_loss:0.007, Test_acc:96.6%, Test_loss:0.121, Lr:1.00E-06
Epoch:22, Train_acc:100.0%, Train_loss:0.005, Test_acc:96.4%, Test_loss:0.126, Lr:1.00E-07
Epoch:23, Train_acc:99.9%, Train_loss:0.004, Test_acc:96.5%, Test_loss:0.122, Lr:1.00E-07
Epoch:24, Train_acc:99.9%, Train_loss:0.011, Test_acc:96.2%, Test_loss:0.130, Lr:1.00E-07
Epoch:25, Train_acc:100.0%, Train_loss:0.004, Test_acc:96.6%, Test_loss:0.127, Lr:1.00E-07
Epoch:26, Train_acc:99.9%, Train_loss:0.005, Test_acc:96.6%, Test_loss:0.126, Lr:1.00E-07
Epoch:27, Train_acc:99.9%, Train_loss:0.004, Test_acc:96.6%, Test_loss:0.123, Lr:1.00E-07
Epoch:28, Train_acc:99.9%, Train_loss:0.005, Test_acc:96.5%, Test_loss:0.120, Lr:1.00E-08
Epoch:29, Train_acc:99.9%, Train_loss:0.005, Test_acc:96.6%, Test_loss:0.122, Lr:1.00E-08
Epoch:30, Train_acc:99.9%, Train_loss:0.005, Test_acc:96.5%, Test_loss:0.125, Lr:1.00E-08
Epoch:31, Train_acc:99.9%, Train_loss:0.005, Test_acc:96.2%, Test_loss:0.127, Lr:1.00E-08
Epoch:32, Train_acc:100.0%, Train_loss:0.004, Test_acc:96.6%, Test_loss:0.121, Lr:1.00E-08
Done1
Epoch: 1, Train_acc:93.5%, Train_loss:0.161, Test_acc:90.6%, Test_loss:0.243, Lr:1.00E-04
Epoch: 2, Train_acc:94.3%, Train_loss:0.145, Test_acc:93.8%, Test_loss:0.183, Lr:1.00E-04
Epoch: 3, Train_acc:94.6%, Train_loss:0.132, Test_acc:92.2%, Test_loss:0.213, Lr:1.00E-04
Epoch: 4, Train_acc:95.0%, Train_loss:0.126, Test_acc:89.1%, Test_loss:0.299, Lr:1.00E-04
Epoch: 5, Train_acc:95.5%, Train_loss:0.115, Test_acc:89.7%, Test_loss:0.248, Lr:1.00E-04
Epoch: 6, Train_acc:96.2%, Train_loss:0.102, Test_acc:94.1%, Test_loss:0.152, Lr:1.00E-04
Epoch: 7, Train_acc:96.4%, Train_loss:0.093, Test_acc:93.9%, Test_loss:0.172, Lr:1.00E-04
Epoch: 8, Train_acc:96.8%, Train_loss:0.077, Test_acc:94.1%, Test_loss:0.165, Lr:1.00E-04
Epoch: 9, Train_acc:97.3%, Train_loss:0.072, Test_acc:92.5%, Test_loss:0.268, Lr:1.00E-04
Epoch:10, Train_acc:97.8%, Train_loss:0.059, Test_acc:90.1%, Test_loss:0.384, Lr:1.00E-04
Epoch:11, Train_acc:97.6%, Train_loss:0.064, Test_acc:94.0%, Test_loss:0.204, Lr:1.00E-04
Epoch:12, Train_acc:97.5%, Train_loss:0.063, Test_acc:92.9%, Test_loss:0.234, Lr:1.00E-05
Epoch:13, Train_acc:99.4%, Train_loss:0.023, Test_acc:94.8%, Test_loss:0.162, Lr:1.00E-05
Epoch:14, Train_acc:99.8%, Train_loss:0.013, Test_acc:95.1%, Test_loss:0.161, Lr:1.00E-05
Epoch:15, Train_acc:99.7%, Train_loss:0.012, Test_acc:95.1%, Test_loss:0.164, Lr:1.00E-05
Epoch:16, Train_acc:99.7%, Train_loss:0.010, Test_acc:94.8%, Test_loss:0.179, Lr:1.00E-05
Epoch:17, Train_acc:99.9%, Train_loss:0.009, Test_acc:94.9%, Test_loss:0.163, Lr:1.00E-05
Epoch:18, Train_acc:99.8%, Train_loss:0.007, Test_acc:95.1%, Test_loss:0.168, Lr:1.00E-06
Epoch:19, Train_acc:99.9%, Train_loss:0.006, Test_acc:95.1%, Test_loss:0.175, Lr:1.00E-06
Epoch:20, Train_acc:99.9%, Train_loss:0.006, Test_acc:95.1%, Test_loss:0.174, Lr:1.00E-06
Epoch:21, Train_acc:99.9%, Train_loss:0.005, Test_acc:94.7%, Test_loss:0.179, Lr:1.00E-06
Epoch:22, Train_acc:100.0%, Train_loss:0.006, Test_acc:94.8%, Test_loss:0.188, Lr:1.00E-06
Epoch:23, Train_acc:99.9%, Train_loss:0.005, Test_acc:95.1%, Test_loss:0.179, Lr:1.00E-06
Epoch:24, Train_acc:99.9%, Train_loss:0.005, Test_acc:94.8%, Test_loss:0.178, Lr:1.00E-07
Epoch:25, Train_acc:99.9%, Train_loss:0.003, Test_acc:95.1%, Test_loss:0.181, Lr:1.00E-07
Epoch:26, Train_acc:99.9%, Train_loss:0.005, Test_acc:95.2%, Test_loss:0.179, Lr:1.00E-07
Epoch:27, Train_acc:99.9%, Train_loss:0.009, Test_acc:95.1%, Test_loss:0.191, Lr:1.00E-07
Epoch:28, Train_acc:99.9%, Train_loss:0.004, Test_acc:94.9%, Test_loss:0.185, Lr:1.00E-07
Epoch:29, Train_acc:99.9%, Train_loss:0.004, Test_acc:95.1%, Test_loss:0.191, Lr:1.00E-07
Epoch:30, Train_acc:100.0%, Train_loss:0.004, Test_acc:95.3%, Test_loss:0.175, Lr:1.00E-08
Epoch:31, Train_acc:100.0%, Train_loss:0.004, Test_acc:95.2%, Test_loss:0.175, Lr:1.00E-08
Epoch:32, Train_acc:99.9%, Train_loss:0.004, Test_acc:95.1%, Test_loss:0.179, Lr:1.00E-08
Done2
Epoch: 1, Train_acc:89.5%, Train_loss:0.254, Test_acc:89.7%, Test_loss:0.263, Lr:1.00E-04
Epoch: 2, Train_acc:89.7%, Train_loss:0.248, Test_acc:88.5%, Test_loss:0.290, Lr:1.00E-04
Epoch: 3, Train_acc:91.2%, Train_loss:0.222, Test_acc:91.3%, Test_loss:0.222, Lr:1.00E-04
Epoch: 4, Train_acc:91.2%, Train_loss:0.214, Test_acc:84.8%, Test_loss:0.399, Lr:1.00E-04
Epoch: 5, Train_acc:91.9%, Train_loss:0.205, Test_acc:91.0%, Test_loss:0.233, Lr:1.00E-04
Epoch: 6, Train_acc:92.0%, Train_loss:0.193, Test_acc:90.0%, Test_loss:0.267, Lr:1.00E-04
Epoch: 7, Train_acc:92.1%, Train_loss:0.188, Test_acc:89.4%, Test_loss:0.255, Lr:1.00E-04
Epoch: 8, Train_acc:92.9%, Train_loss:0.180, Test_acc:92.0%, Test_loss:0.193, Lr:1.00E-04
Epoch: 9, Train_acc:93.4%, Train_loss:0.165, Test_acc:89.0%, Test_loss:0.300, Lr:1.00E-04
Epoch:10, Train_acc:93.3%, Train_loss:0.169, Test_acc:86.3%, Test_loss:0.332, Lr:1.00E-04
Epoch:11, Train_acc:94.1%, Train_loss:0.145, Test_acc:91.0%, Test_loss:0.217, Lr:1.00E-04
Epoch:12, Train_acc:94.2%, Train_loss:0.146, Test_acc:89.7%, Test_loss:0.267, Lr:1.00E-04
Epoch:13, Train_acc:94.6%, Train_loss:0.136, Test_acc:90.7%, Test_loss:0.244, Lr:1.00E-04
Epoch:14, Train_acc:94.9%, Train_loss:0.131, Test_acc:92.0%, Test_loss:0.226, Lr:1.00E-05
Epoch:15, Train_acc:97.0%, Train_loss:0.083, Test_acc:93.0%, Test_loss:0.180, Lr:1.00E-05
Epoch:16, Train_acc:97.9%, Train_loss:0.061, Test_acc:93.2%, Test_loss:0.181, Lr:1.00E-05
Epoch:17, Train_acc:98.0%, Train_loss:0.056, Test_acc:93.2%, Test_loss:0.191, Lr:1.00E-05
Epoch:18, Train_acc:98.5%, Train_loss:0.047, Test_acc:92.9%, Test_loss:0.199, Lr:1.00E-05
Epoch:19, Train_acc:98.6%, Train_loss:0.041, Test_acc:93.0%, Test_loss:0.214, Lr:1.00E-05
Epoch:20, Train_acc:99.0%, Train_loss:0.031, Test_acc:93.1%, Test_loss:0.211, Lr:1.00E-05
Epoch:21, Train_acc:99.1%, Train_loss:0.029, Test_acc:93.0%, Test_loss:0.230, Lr:1.00E-06
Epoch:22, Train_acc:99.4%, Train_loss:0.024, Test_acc:92.7%, Test_loss:0.231, Lr:1.00E-06
Epoch:23, Train_acc:99.4%, Train_loss:0.022, Test_acc:93.2%, Test_loss:0.223, Lr:1.00E-06
Epoch:24, Train_acc:99.4%, Train_loss:0.022, Test_acc:92.9%, Test_loss:0.225, Lr:1.00E-06
Epoch:25, Train_acc:99.3%, Train_loss:0.024, Test_acc:92.9%, Test_loss:0.236, Lr:1.00E-06
Epoch:26, Train_acc:99.5%, Train_loss:0.019, Test_acc:92.9%, Test_loss:0.232, Lr:1.00E-06
Epoch:27, Train_acc:99.4%, Train_loss:0.021, Test_acc:93.0%, Test_loss:0.235, Lr:1.00E-07
Epoch:28, Train_acc:99.4%, Train_loss:0.024, Test_acc:92.8%, Test_loss:0.236, Lr:1.00E-07
Epoch:29, Train_acc:99.5%, Train_loss:0.019, Test_acc:93.1%, Test_loss:0.231, Lr:1.00E-07
Epoch:30, Train_acc:99.5%, Train_loss:0.018, Test_acc:93.0%, Test_loss:0.233, Lr:1.00E-07
Epoch:31, Train_acc:99.5%, Train_loss:0.025, Test_acc:93.2%, Test_loss:0.234, Lr:1.00E-07
Epoch:32, Train_acc:99.5%, Train_loss:0.018, Test_acc:93.0%, Test_loss:0.236, Lr:1.00E-07
Done3
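
Incidentally, the three training loops differ only in which model, optimizer, and scheduler they use, so they could be factored into one helper. A sketch (the helper name and signature are mine):

def run_training(model, optimizer, scheduler, ckpt_path, epochs=32):
    best_acc, best_model = 0.0, None
    for epoch in range(epochs):
        model.train()
        tr_acc, tr_loss = train(train_dl, model, loss_fn, optimizer)
        model.eval()
        te_acc, te_loss = test(test_dl, model, loss_fn)
        scheduler.step(te_loss)  # the plateau scheduler watches the test loss
        if te_acc > best_acc:
            best_acc, best_model = te_acc, copy.deepcopy(model)
        lr = optimizer.state_dict()['param_groups'][0]['lr']
        print(f'Epoch:{epoch+1:2d}, Train_acc:{tr_acc*100:.1f}%, Train_loss:{tr_loss:.3f}, '
              f'Test_acc:{te_acc*100:.1f}%, Test_loss:{te_loss:.3f}, Lr:{lr:.2E}')
    torch.save(best_model.state_dict(), ckpt_path)
    return best_acc, best_model

# e.g. best_acc1, best_model1 = run_training(model1, opt1, scheduler1, './best_model1.pth')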

Compare the best test-set accuracies of the three models:

print(best_acc1, best_acc2, best_acc3)

Code output:

0.9664179104477612 0.953150912106136 0.9324212271973465

DenseNet121 comes out best, but this is from a single training run, so the ranking should not be taken as definitive.
Now let's look at validation-set accuracy:

def validate(dataloader, model):
    model.eval()
    size = len(dataloader.dataset)
    validate_acc = 0

    with torch.no_grad():  # no gradients needed during evaluation
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            validate_acc += (pred.argmax(1) == y).type(torch.float).sum().item()

    validate_acc /= size
    return validate_acc


# Validation accuracy of each best model
validate_acc1 = validate(validate_dl, best_model1)
validate_acc2 = validate(validate_dl, best_model2)
validate_acc3 = validate(validate_dl, best_model3)
print(f"Validation Accuracy: {validate_acc1:.2%},{validate_acc2:.2%},{validate_acc3:.2%}")

Code output:

Validation Accuracy: 96.08%,94.97%,93.47%

Again, DenseNet121 achieves the best accuracy.

6. Summary and Discussion

SE-Net learns a weight for each feature channel, which helps the model train better. Compared with plain DenseNet on the same dataset, the SE-augmented network clearly improves: in an earlier week the plain DenseNet reached 93.23% validation accuracy, while the version here reaches 96.08%.

