Deep Learning Camp - Week J5: DenseNet + SE-Net in Practice
- 🍨 This article is a learning-record blog post from the 🔗365-day deep learning training camp
- 🍖 Original author: K同学啊
📌 This week's tasks (with my own tweaks):
●1. Insert the SE-Net channel attention mechanism into the DenseNet family of networks and complete recognition on the breast cancer dataset
●2. Can this improvement idea be transferred to other settings?
●3. Compare validation-set accuracy
I. Introduction
SE-Net (Squeeze-and-Excitation Network) was the winning model of the ImageNet 2017 classification competition, proposed by the WMW team. It has low complexity, adds few parameters, and costs little extra computation. Its design is simple and it combines easily with existing network structures such as Inception and ResNet, boosting their performance.
Traditional architectures such as Inception mainly improve features along the spatial dimensions, whereas SE-Net focuses on the relationships between feature channels: it cares about how important each channel is, not only about spatial positions. By learning the importance of each feature channel, SE-Net automatically adjusts the feature weights, boosting the channels that are useful for the current task while suppressing the unimportant ones. This process is called "feature recalibration".
The SE module is the core component of SE-Net and is what implements feature recalibration (the module diagram is not reproduced here). It works in three steps: global average pooling first gathers each channel's global information, fully connected layers then learn each channel's importance, and finally the channel outputs are rescaled according to these importances, as summarized next.
For a given input x with C' feature channels, a series of transformations turns it into a feature map with C channels. SE-Net then performs the following operations (written out as formulas right after the list):
- Squeeze: global average pooling compresses each channel's feature map into a single value, yielding a channel descriptor that summarizes global spatial information. This step can be seen as "squeezing" each channel down to its global statistics.
- Excitation: a small fully connected network, usually two layers, where the first layer reduces the dimensionality (cutting model complexity and parameter count) and the second restores it. A Sigmoid at the end outputs a weight coefficient for each channel, i.e. the "excitation" of each channel.
- Scale: finally, the output of the Excitation step (the per-channel weights) is multiplied element-wise with the original input, rescaling the features. Weighting the channel outputs in this way strengthens the model's ability to capture useful features while suppressing unimportant ones.
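Written out in the standard SE-Net formulation, with $U = [u_1, \dots, u_C]$ denoting the transformed feature map of size $H \times W$ and $r$ the reduction ratio (16 in the code below), the three steps are:

$$z_c = \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} u_c(i, j) \quad \text{(Squeeze)}$$
$$s = \sigma\!\left(W_2 \, \delta(W_1 z)\right), \quad W_1 \in \mathbb{R}^{(C/r) \times C}, \; W_2 \in \mathbb{R}^{C \times (C/r)} \quad \text{(Excitation)}$$
$$\tilde{x}_c = s_c \cdot u_c \quad \text{(Scale)}$$

where $\delta$ is the ReLU activation and $\sigma$ the Sigmoid.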
II. Generality of the SE Module
One advantage of the SE module is that it can be dropped directly into existing network structures. Taking Inception and ResNet as examples, we simply append an SE module after each Inception module or residual block; a small sketch follows below.
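As an illustration only (my own sketch, not one of the networks trained later in this post), a basic residual block with an SE module appended could look like this; it assumes torch.nn is imported as nn and uses the SELayer class defined in the next section:

# Illustrative sketch: a basic residual block with an SE module appended.
# Assumes `import torch.nn as nn` and the SELayer class from the next section.
class SEBasicBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SEBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.se = SELayer(channels, reduction)  # channel recalibration on the residual branch
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)           # reweight channels before the shortcut addition
        return self.relu(out + x)    # identity shortcut, then activation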
III. Code Implementation of the SE Module
import torch
import torch.nn as nn

class SELayer(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # Squeeze: global average pooling
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()  # Excitation: produce per-channel weights
        )

    def forward(self, x):
        b, c, h, w = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)  # Scale: reweight each channel
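A quick sanity check (my own addition): pass a random feature map through the layer and confirm that the output shape is unchanged, since the SE module only reweights channels.

se = SELayer(channels=64)
feat = torch.randn(2, 64, 56, 56)  # (batch, channels, height, width)
out = se(feat)
print(out.shape)  # torch.Size([2, 64, 56, 56])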
IV. Inserting the SE Module into DenseNet
Here is the complete network definition:
import torch
import torch.nn as nn
import torch.nn.functional as F
class SELayer(nn.Module):
    def __init__(self, channels, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # Squeeze: global average pooling
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()  # Excitation: produce per-channel weights
        )

    def forward(self, x):
        b, c, h, w = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)  # Scale: reweight each channel
# Bottleneck convolution block used inside each dense block (BN-ReLU-1x1 conv, then BN-ReLU-3x3 conv)
class ConvBlock(nn.Module):
def __init__(self, in_channels, growth_rate):
super(ConvBlock, self).__init__()
self.conv1 = nn.Sequential(
nn.BatchNorm2d(in_channels),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels, 4 * growth_rate, kernel_size=1, bias=False),
nn.BatchNorm2d(4 * growth_rate),
nn.ReLU(inplace=True),
nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
)
def forward(self, x):
x = self.conv1(x)
return x
class DenseBlock(nn.Module):
def __init__(self, num_layers, in_channels, growth_rate):
super(DenseBlock, self).__init__()
self.layers = nn.ModuleList()
for i in range(num_layers):
self.layers.append(ConvBlock(in_channels + i * growth_rate, growth_rate))
def forward(self, x):
for layer in self.layers:
new_features = layer(x)
x = torch.cat([x, new_features], dim=1)
return x
class TransitionBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(TransitionBlock, self).__init__()
self.bn = nn.BatchNorm2d(in_channels)
self.relu = nn.ReLU(inplace=True)
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
def forward(self, x):
x = self.bn(x)
x = self.relu(x)
x = self.conv(x)
x = self.avg_pool(x)
return x
# Build the DenseNet backbone
class DenseNet(nn.Module):
def __init__(self, block_config, num_classes=1000, growth_rate=32):
super(DenseNet, self).__init__()
num_init_features = 64
self.features = nn.Sequential(
nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(num_init_features),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
num_features = num_init_features
for i, num_blocks in enumerate(block_config):
block = DenseBlock(num_blocks, num_features, growth_rate)
self.features.add_module('denseblock{}'.format(i + 1), block)
num_features += num_blocks * growth_rate
if i != len(block_config) - 1:
trans = TransitionBlock(num_features, num_features // 2)
self.features.add_module('transition{}'.format(i + 1), trans)
num_features = num_features // 2
self.features.add_module('norm5', nn.BatchNorm2d(num_features))
self.classifier = nn.Linear(num_features, num_classes)
self.SE_layer = SELayer(num_features) # SE Layer should be initialized with the correct number of channels
def forward(self, x):
x = self.features(x)
x = self.SE_layer(x)
x = F.relu(x, inplace=True)
x = F.adaptive_avg_pool2d(x, (1, 1))
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
# Define the different DenseNet configurations
def DenseNet121(num_classes=3):
return DenseNet([6, 12, 24, 16], num_classes=num_classes)
def DenseNet169(num_classes=3):
return DenseNet([6, 12, 32, 32], num_classes=num_classes)
def DenseNet201(num_classes=3):
return DenseNet([6, 12, 48, 32], num_classes=num_classes)
import torchsummary
model1 = DenseNet121().cuda()
x = (3, 224, 224)
torchsummary.summary(model1,x)
model2 = DenseNet169().cuda()
model3 = DenseNet201().cuda()
Output:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 56, 56] 0
BatchNorm2d-5 [-1, 64, 56, 56] 128
ReLU-6 [-1, 64, 56, 56] 0
Conv2d-7 [-1, 128, 56, 56] 8,192
BatchNorm2d-8 [-1, 128, 56, 56] 256
ReLU-9 [-1, 128, 56, 56] 0
Conv2d-10 [-1, 32, 56, 56] 36,864
ConvBlock-11 [-1, 32, 56, 56] 0
BatchNorm2d-12 [-1, 96, 56, 56] 192
ReLU-13 [-1, 96, 56, 56] 0
Conv2d-14 [-1, 128, 56, 56] 12,288
BatchNorm2d-15 [-1, 128, 56, 56] 256
ReLU-16 [-1, 128, 56, 56] 0
Conv2d-17 [-1, 32, 56, 56] 36,864
ConvBlock-18 [-1, 32, 56, 56] 0
BatchNorm2d-19 [-1, 128, 56, 56] 256
ReLU-20 [-1, 128, 56, 56] 0
Conv2d-21 [-1, 128, 56, 56] 16,384
BatchNorm2d-22 [-1, 128, 56, 56] 256
ReLU-23 [-1, 128, 56, 56] 0
Conv2d-24 [-1, 32, 56, 56] 36,864
ConvBlock-25 [-1, 32, 56, 56] 0
BatchNorm2d-26 [-1, 160, 56, 56] 320
ReLU-27 [-1, 160, 56, 56] 0
Conv2d-28 [-1, 128, 56, 56] 20,480
BatchNorm2d-29 [-1, 128, 56, 56] 256
ReLU-30 [-1, 128, 56, 56] 0
Conv2d-31 [-1, 32, 56, 56] 36,864
ConvBlock-32 [-1, 32, 56, 56] 0
BatchNorm2d-33 [-1, 192, 56, 56] 384
ReLU-34 [-1, 192, 56, 56] 0
Conv2d-35 [-1, 128, 56, 56] 24,576
BatchNorm2d-36 [-1, 128, 56, 56] 256
ReLU-37 [-1, 128, 56, 56] 0
Conv2d-38 [-1, 32, 56, 56] 36,864
ConvBlock-39 [-1, 32, 56, 56] 0
BatchNorm2d-40 [-1, 224, 56, 56] 448
ReLU-41 [-1, 224, 56, 56] 0
Conv2d-42 [-1, 128, 56, 56] 28,672
BatchNorm2d-43 [-1, 128, 56, 56] 256
ReLU-44 [-1, 128, 56, 56] 0
Conv2d-45 [-1, 32, 56, 56] 36,864
ConvBlock-46 [-1, 32, 56, 56] 0
DenseBlock-47 [-1, 256, 56, 56] 0
BatchNorm2d-48 [-1, 256, 56, 56] 512
ReLU-49 [-1, 256, 56, 56] 0
Conv2d-50 [-1, 128, 56, 56] 32,768
AvgPool2d-51 [-1, 128, 28, 28] 0
TransitionBlock-52 [-1, 128, 28, 28] 0
BatchNorm2d-53 [-1, 128, 28, 28] 256
ReLU-54 [-1, 128, 28, 28] 0
Conv2d-55 [-1, 128, 28, 28] 16,384
BatchNorm2d-56 [-1, 128, 28, 28] 256
ReLU-57 [-1, 128, 28, 28] 0
Conv2d-58 [-1, 32, 28, 28] 36,864
ConvBlock-59 [-1, 32, 28, 28] 0
BatchNorm2d-60 [-1, 160, 28, 28] 320
ReLU-61 [-1, 160, 28, 28] 0
Conv2d-62 [-1, 128, 28, 28] 20,480
BatchNorm2d-63 [-1, 128, 28, 28] 256
ReLU-64 [-1, 128, 28, 28] 0
Conv2d-65 [-1, 32, 28, 28] 36,864
ConvBlock-66 [-1, 32, 28, 28] 0
BatchNorm2d-67 [-1, 192, 28, 28] 384
ReLU-68 [-1, 192, 28, 28] 0
Conv2d-69 [-1, 128, 28, 28] 24,576
BatchNorm2d-70 [-1, 128, 28, 28] 256
ReLU-71 [-1, 128, 28, 28] 0
Conv2d-72 [-1, 32, 28, 28] 36,864
ConvBlock-73 [-1, 32, 28, 28] 0
BatchNorm2d-74 [-1, 224, 28, 28] 448
ReLU-75 [-1, 224, 28, 28] 0
Conv2d-76 [-1, 128, 28, 28] 28,672
BatchNorm2d-77 [-1, 128, 28, 28] 256
ReLU-78 [-1, 128, 28, 28] 0
Conv2d-79 [-1, 32, 28, 28] 36,864
ConvBlock-80 [-1, 32, 28, 28] 0
BatchNorm2d-81 [-1, 256, 28, 28] 512
ReLU-82 [-1, 256, 28, 28] 0
Conv2d-83 [-1, 128, 28, 28] 32,768
BatchNorm2d-84 [-1, 128, 28, 28] 256
ReLU-85 [-1, 128, 28, 28] 0
Conv2d-86 [-1, 32, 28, 28] 36,864
ConvBlock-87 [-1, 32, 28, 28] 0
BatchNorm2d-88 [-1, 288, 28, 28] 576
ReLU-89 [-1, 288, 28, 28] 0
Conv2d-90 [-1, 128, 28, 28] 36,864
BatchNorm2d-91 [-1, 128, 28, 28] 256
ReLU-92 [-1, 128, 28, 28] 0
Conv2d-93 [-1, 32, 28, 28] 36,864
ConvBlock-94 [-1, 32, 28, 28] 0
BatchNorm2d-95 [-1, 320, 28, 28] 640
ReLU-96 [-1, 320, 28, 28] 0
Conv2d-97 [-1, 128, 28, 28] 40,960
BatchNorm2d-98 [-1, 128, 28, 28] 256
ReLU-99 [-1, 128, 28, 28] 0
Conv2d-100 [-1, 32, 28, 28] 36,864
ConvBlock-101 [-1, 32, 28, 28] 0
BatchNorm2d-102 [-1, 352, 28, 28] 704
ReLU-103 [-1, 352, 28, 28] 0
Conv2d-104 [-1, 128, 28, 28] 45,056
BatchNorm2d-105 [-1, 128, 28, 28] 256
ReLU-106 [-1, 128, 28, 28] 0
Conv2d-107 [-1, 32, 28, 28] 36,864
ConvBlock-108 [-1, 32, 28, 28] 0
BatchNorm2d-109 [-1, 384, 28, 28] 768
ReLU-110 [-1, 384, 28, 28] 0
Conv2d-111 [-1, 128, 28, 28] 49,152
BatchNorm2d-112 [-1, 128, 28, 28] 256
ReLU-113 [-1, 128, 28, 28] 0
Conv2d-114 [-1, 32, 28, 28] 36,864
ConvBlock-115 [-1, 32, 28, 28] 0
BatchNorm2d-116 [-1, 416, 28, 28] 832
ReLU-117 [-1, 416, 28, 28] 0
Conv2d-118 [-1, 128, 28, 28] 53,248
BatchNorm2d-119 [-1, 128, 28, 28] 256
ReLU-120 [-1, 128, 28, 28] 0
Conv2d-121 [-1, 32, 28, 28] 36,864
ConvBlock-122 [-1, 32, 28, 28] 0
BatchNorm2d-123 [-1, 448, 28, 28] 896
ReLU-124 [-1, 448, 28, 28] 0
Conv2d-125 [-1, 128, 28, 28] 57,344
BatchNorm2d-126 [-1, 128, 28, 28] 256
ReLU-127 [-1, 128, 28, 28] 0
Conv2d-128 [-1, 32, 28, 28] 36,864
ConvBlock-129 [-1, 32, 28, 28] 0
BatchNorm2d-130 [-1, 480, 28, 28] 960
ReLU-131 [-1, 480, 28, 28] 0
Conv2d-132 [-1, 128, 28, 28] 61,440
BatchNorm2d-133 [-1, 128, 28, 28] 256
ReLU-134 [-1, 128, 28, 28] 0
Conv2d-135 [-1, 32, 28, 28] 36,864
ConvBlock-136 [-1, 32, 28, 28] 0
DenseBlock-137 [-1, 512, 28, 28] 0
BatchNorm2d-138 [-1, 512, 28, 28] 1,024
ReLU-139 [-1, 512, 28, 28] 0
Conv2d-140 [-1, 256, 28, 28] 131,072
AvgPool2d-141 [-1, 256, 14, 14] 0
TransitionBlock-142 [-1, 256, 14, 14] 0
BatchNorm2d-143 [-1, 256, 14, 14] 512
ReLU-144 [-1, 256, 14, 14] 0
Conv2d-145 [-1, 128, 14, 14] 32,768
BatchNorm2d-146 [-1, 128, 14, 14] 256
ReLU-147 [-1, 128, 14, 14] 0
Conv2d-148 [-1, 32, 14, 14] 36,864
ConvBlock-149 [-1, 32, 14, 14] 0
BatchNorm2d-150 [-1, 288, 14, 14] 576
ReLU-151 [-1, 288, 14, 14] 0
Conv2d-152 [-1, 128, 14, 14] 36,864
BatchNorm2d-153 [-1, 128, 14, 14] 256
ReLU-154 [-1, 128, 14, 14] 0
Conv2d-155 [-1, 32, 14, 14] 36,864
ConvBlock-156 [-1, 32, 14, 14] 0
BatchNorm2d-157 [-1, 320, 14, 14] 640
ReLU-158 [-1, 320, 14, 14] 0
Conv2d-159 [-1, 128, 14, 14] 40,960
BatchNorm2d-160 [-1, 128, 14, 14] 256
ReLU-161 [-1, 128, 14, 14] 0
Conv2d-162 [-1, 32, 14, 14] 36,864
ConvBlock-163 [-1, 32, 14, 14] 0
BatchNorm2d-164 [-1, 352, 14, 14] 704
ReLU-165 [-1, 352, 14, 14] 0
Conv2d-166 [-1, 128, 14, 14] 45,056
BatchNorm2d-167 [-1, 128, 14, 14] 256
ReLU-168 [-1, 128, 14, 14] 0
Conv2d-169 [-1, 32, 14, 14] 36,864
ConvBlock-170 [-1, 32, 14, 14] 0
BatchNorm2d-171 [-1, 384, 14, 14] 768
ReLU-172 [-1, 384, 14, 14] 0
Conv2d-173 [-1, 128, 14, 14] 49,152
BatchNorm2d-174 [-1, 128, 14, 14] 256
ReLU-175 [-1, 128, 14, 14] 0
Conv2d-176 [-1, 32, 14, 14] 36,864
ConvBlock-177 [-1, 32, 14, 14] 0
BatchNorm2d-178 [-1, 416, 14, 14] 832
ReLU-179 [-1, 416, 14, 14] 0
Conv2d-180 [-1, 128, 14, 14] 53,248
BatchNorm2d-181 [-1, 128, 14, 14] 256
ReLU-182 [-1, 128, 14, 14] 0
Conv2d-183 [-1, 32, 14, 14] 36,864
ConvBlock-184 [-1, 32, 14, 14] 0
BatchNorm2d-185 [-1, 448, 14, 14] 896
ReLU-186 [-1, 448, 14, 14] 0
Conv2d-187 [-1, 128, 14, 14] 57,344
BatchNorm2d-188 [-1, 128, 14, 14] 256
ReLU-189 [-1, 128, 14, 14] 0
Conv2d-190 [-1, 32, 14, 14] 36,864
ConvBlock-191 [-1, 32, 14, 14] 0
BatchNorm2d-192 [-1, 480, 14, 14] 960
ReLU-193 [-1, 480, 14, 14] 0
Conv2d-194 [-1, 128, 14, 14] 61,440
BatchNorm2d-195 [-1, 128, 14, 14] 256
ReLU-196 [-1, 128, 14, 14] 0
Conv2d-197 [-1, 32, 14, 14] 36,864
ConvBlock-198 [-1, 32, 14, 14] 0
BatchNorm2d-199 [-1, 512, 14, 14] 1,024
ReLU-200 [-1, 512, 14, 14] 0
Conv2d-201 [-1, 128, 14, 14] 65,536
BatchNorm2d-202 [-1, 128, 14, 14] 256
ReLU-203 [-1, 128, 14, 14] 0
Conv2d-204 [-1, 32, 14, 14] 36,864
ConvBlock-205 [-1, 32, 14, 14] 0
BatchNorm2d-206 [-1, 544, 14, 14] 1,088
ReLU-207 [-1, 544, 14, 14] 0
Conv2d-208 [-1, 128, 14, 14] 69,632
BatchNorm2d-209 [-1, 128, 14, 14] 256
ReLU-210 [-1, 128, 14, 14] 0
Conv2d-211 [-1, 32, 14, 14] 36,864
ConvBlock-212 [-1, 32, 14, 14] 0
BatchNorm2d-213 [-1, 576, 14, 14] 1,152
ReLU-214 [-1, 576, 14, 14] 0
Conv2d-215 [-1, 128, 14, 14] 73,728
BatchNorm2d-216 [-1, 128, 14, 14] 256
ReLU-217 [-1, 128, 14, 14] 0
Conv2d-218 [-1, 32, 14, 14] 36,864
ConvBlock-219 [-1, 32, 14, 14] 0
BatchNorm2d-220 [-1, 608, 14, 14] 1,216
ReLU-221 [-1, 608, 14, 14] 0
Conv2d-222 [-1, 128, 14, 14] 77,824
BatchNorm2d-223 [-1, 128, 14, 14] 256
ReLU-224 [-1, 128, 14, 14] 0
Conv2d-225 [-1, 32, 14, 14] 36,864
ConvBlock-226 [-1, 32, 14, 14] 0
BatchNorm2d-227 [-1, 640, 14, 14] 1,280
ReLU-228 [-1, 640, 14, 14] 0
Conv2d-229 [-1, 128, 14, 14] 81,920
BatchNorm2d-230 [-1, 128, 14, 14] 256
ReLU-231 [-1, 128, 14, 14] 0
Conv2d-232 [-1, 32, 14, 14] 36,864
ConvBlock-233 [-1, 32, 14, 14] 0
BatchNorm2d-234 [-1, 672, 14, 14] 1,344
ReLU-235 [-1, 672, 14, 14] 0
Conv2d-236 [-1, 128, 14, 14] 86,016
BatchNorm2d-237 [-1, 128, 14, 14] 256
ReLU-238 [-1, 128, 14, 14] 0
Conv2d-239 [-1, 32, 14, 14] 36,864
ConvBlock-240 [-1, 32, 14, 14] 0
BatchNorm2d-241 [-1, 704, 14, 14] 1,408
ReLU-242 [-1, 704, 14, 14] 0
Conv2d-243 [-1, 128, 14, 14] 90,112
BatchNorm2d-244 [-1, 128, 14, 14] 256
ReLU-245 [-1, 128, 14, 14] 0
Conv2d-246 [-1, 32, 14, 14] 36,864
ConvBlock-247 [-1, 32, 14, 14] 0
BatchNorm2d-248 [-1, 736, 14, 14] 1,472
ReLU-249 [-1, 736, 14, 14] 0
Conv2d-250 [-1, 128, 14, 14] 94,208
BatchNorm2d-251 [-1, 128, 14, 14] 256
ReLU-252 [-1, 128, 14, 14] 0
Conv2d-253 [-1, 32, 14, 14] 36,864
ConvBlock-254 [-1, 32, 14, 14] 0
BatchNorm2d-255 [-1, 768, 14, 14] 1,536
ReLU-256 [-1, 768, 14, 14] 0
Conv2d-257 [-1, 128, 14, 14] 98,304
BatchNorm2d-258 [-1, 128, 14, 14] 256
ReLU-259 [-1, 128, 14, 14] 0
Conv2d-260 [-1, 32, 14, 14] 36,864
ConvBlock-261 [-1, 32, 14, 14] 0
BatchNorm2d-262 [-1, 800, 14, 14] 1,600
ReLU-263 [-1, 800, 14, 14] 0
Conv2d-264 [-1, 128, 14, 14] 102,400
BatchNorm2d-265 [-1, 128, 14, 14] 256
ReLU-266 [-1, 128, 14, 14] 0
Conv2d-267 [-1, 32, 14, 14] 36,864
ConvBlock-268 [-1, 32, 14, 14] 0
BatchNorm2d-269 [-1, 832, 14, 14] 1,664
ReLU-270 [-1, 832, 14, 14] 0
Conv2d-271 [-1, 128, 14, 14] 106,496
BatchNorm2d-272 [-1, 128, 14, 14] 256
ReLU-273 [-1, 128, 14, 14] 0
Conv2d-274 [-1, 32, 14, 14] 36,864
ConvBlock-275 [-1, 32, 14, 14] 0
BatchNorm2d-276 [-1, 864, 14, 14] 1,728
ReLU-277 [-1, 864, 14, 14] 0
Conv2d-278 [-1, 128, 14, 14] 110,592
BatchNorm2d-279 [-1, 128, 14, 14] 256
ReLU-280 [-1, 128, 14, 14] 0
Conv2d-281 [-1, 32, 14, 14] 36,864
ConvBlock-282 [-1, 32, 14, 14] 0
BatchNorm2d-283 [-1, 896, 14, 14] 1,792
ReLU-284 [-1, 896, 14, 14] 0
Conv2d-285 [-1, 128, 14, 14] 114,688
BatchNorm2d-286 [-1, 128, 14, 14] 256
ReLU-287 [-1, 128, 14, 14] 0
Conv2d-288 [-1, 32, 14, 14] 36,864
ConvBlock-289 [-1, 32, 14, 14] 0
BatchNorm2d-290 [-1, 928, 14, 14] 1,856
ReLU-291 [-1, 928, 14, 14] 0
Conv2d-292 [-1, 128, 14, 14] 118,784
BatchNorm2d-293 [-1, 128, 14, 14] 256
ReLU-294 [-1, 128, 14, 14] 0
Conv2d-295 [-1, 32, 14, 14] 36,864
ConvBlock-296 [-1, 32, 14, 14] 0
BatchNorm2d-297 [-1, 960, 14, 14] 1,920
ReLU-298 [-1, 960, 14, 14] 0
Conv2d-299 [-1, 128, 14, 14] 122,880
BatchNorm2d-300 [-1, 128, 14, 14] 256
ReLU-301 [-1, 128, 14, 14] 0
Conv2d-302 [-1, 32, 14, 14] 36,864
ConvBlock-303 [-1, 32, 14, 14] 0
BatchNorm2d-304 [-1, 992, 14, 14] 1,984
ReLU-305 [-1, 992, 14, 14] 0
Conv2d-306 [-1, 128, 14, 14] 126,976
BatchNorm2d-307 [-1, 128, 14, 14] 256
ReLU-308 [-1, 128, 14, 14] 0
Conv2d-309 [-1, 32, 14, 14] 36,864
ConvBlock-310 [-1, 32, 14, 14] 0
DenseBlock-311 [-1, 1024, 14, 14] 0
BatchNorm2d-312 [-1, 1024, 14, 14] 2,048
ReLU-313 [-1, 1024, 14, 14] 0
Conv2d-314 [-1, 512, 14, 14] 524,288
AvgPool2d-315 [-1, 512, 7, 7] 0
TransitionBlock-316 [-1, 512, 7, 7] 0
BatchNorm2d-317 [-1, 512, 7, 7] 1,024
ReLU-318 [-1, 512, 7, 7] 0
Conv2d-319 [-1, 128, 7, 7] 65,536
BatchNorm2d-320 [-1, 128, 7, 7] 256
ReLU-321 [-1, 128, 7, 7] 0
Conv2d-322 [-1, 32, 7, 7] 36,864
ConvBlock-323 [-1, 32, 7, 7] 0
BatchNorm2d-324 [-1, 544, 7, 7] 1,088
ReLU-325 [-1, 544, 7, 7] 0
Conv2d-326 [-1, 128, 7, 7] 69,632
BatchNorm2d-327 [-1, 128, 7, 7] 256
ReLU-328 [-1, 128, 7, 7] 0
Conv2d-329 [-1, 32, 7, 7] 36,864
ConvBlock-330 [-1, 32, 7, 7] 0
BatchNorm2d-331 [-1, 576, 7, 7] 1,152
ReLU-332 [-1, 576, 7, 7] 0
Conv2d-333 [-1, 128, 7, 7] 73,728
BatchNorm2d-334 [-1, 128, 7, 7] 256
ReLU-335 [-1, 128, 7, 7] 0
Conv2d-336 [-1, 32, 7, 7] 36,864
ConvBlock-337 [-1, 32, 7, 7] 0
BatchNorm2d-338 [-1, 608, 7, 7] 1,216
ReLU-339 [-1, 608, 7, 7] 0
Conv2d-340 [-1, 128, 7, 7] 77,824
BatchNorm2d-341 [-1, 128, 7, 7] 256
ReLU-342 [-1, 128, 7, 7] 0
Conv2d-343 [-1, 32, 7, 7] 36,864
ConvBlock-344 [-1, 32, 7, 7] 0
BatchNorm2d-345 [-1, 640, 7, 7] 1,280
ReLU-346 [-1, 640, 7, 7] 0
Conv2d-347 [-1, 128, 7, 7] 81,920
BatchNorm2d-348 [-1, 128, 7, 7] 256
ReLU-349 [-1, 128, 7, 7] 0
Conv2d-350 [-1, 32, 7, 7] 36,864
ConvBlock-351 [-1, 32, 7, 7] 0
BatchNorm2d-352 [-1, 672, 7, 7] 1,344
ReLU-353 [-1, 672, 7, 7] 0
Conv2d-354 [-1, 128, 7, 7] 86,016
BatchNorm2d-355 [-1, 128, 7, 7] 256
ReLU-356 [-1, 128, 7, 7] 0
Conv2d-357 [-1, 32, 7, 7] 36,864
ConvBlock-358 [-1, 32, 7, 7] 0
BatchNorm2d-359 [-1, 704, 7, 7] 1,408
ReLU-360 [-1, 704, 7, 7] 0
Conv2d-361 [-1, 128, 7, 7] 90,112
BatchNorm2d-362 [-1, 128, 7, 7] 256
ReLU-363 [-1, 128, 7, 7] 0
Conv2d-364 [-1, 32, 7, 7] 36,864
ConvBlock-365 [-1, 32, 7, 7] 0
BatchNorm2d-366 [-1, 736, 7, 7] 1,472
ReLU-367 [-1, 736, 7, 7] 0
Conv2d-368 [-1, 128, 7, 7] 94,208
BatchNorm2d-369 [-1, 128, 7, 7] 256
ReLU-370 [-1, 128, 7, 7] 0
Conv2d-371 [-1, 32, 7, 7] 36,864
ConvBlock-372 [-1, 32, 7, 7] 0
BatchNorm2d-373 [-1, 768, 7, 7] 1,536
ReLU-374 [-1, 768, 7, 7] 0
Conv2d-375 [-1, 128, 7, 7] 98,304
BatchNorm2d-376 [-1, 128, 7, 7] 256
ReLU-377 [-1, 128, 7, 7] 0
Conv2d-378 [-1, 32, 7, 7] 36,864
ConvBlock-379 [-1, 32, 7, 7] 0
BatchNorm2d-380 [-1, 800, 7, 7] 1,600
ReLU-381 [-1, 800, 7, 7] 0
Conv2d-382 [-1, 128, 7, 7] 102,400
BatchNorm2d-383 [-1, 128, 7, 7] 256
ReLU-384 [-1, 128, 7, 7] 0
Conv2d-385 [-1, 32, 7, 7] 36,864
ConvBlock-386 [-1, 32, 7, 7] 0
BatchNorm2d-387 [-1, 832, 7, 7] 1,664
ReLU-388 [-1, 832, 7, 7] 0
Conv2d-389 [-1, 128, 7, 7] 106,496
BatchNorm2d-390 [-1, 128, 7, 7] 256
ReLU-391 [-1, 128, 7, 7] 0
Conv2d-392 [-1, 32, 7, 7] 36,864
ConvBlock-393 [-1, 32, 7, 7] 0
BatchNorm2d-394 [-1, 864, 7, 7] 1,728
ReLU-395 [-1, 864, 7, 7] 0
Conv2d-396 [-1, 128, 7, 7] 110,592
BatchNorm2d-397 [-1, 128, 7, 7] 256
ReLU-398 [-1, 128, 7, 7] 0
Conv2d-399 [-1, 32, 7, 7] 36,864
ConvBlock-400 [-1, 32, 7, 7] 0
BatchNorm2d-401 [-1, 896, 7, 7] 1,792
ReLU-402 [-1, 896, 7, 7] 0
Conv2d-403 [-1, 128, 7, 7] 114,688
BatchNorm2d-404 [-1, 128, 7, 7] 256
ReLU-405 [-1, 128, 7, 7] 0
Conv2d-406 [-1, 32, 7, 7] 36,864
ConvBlock-407 [-1, 32, 7, 7] 0
BatchNorm2d-408 [-1, 928, 7, 7] 1,856
ReLU-409 [-1, 928, 7, 7] 0
Conv2d-410 [-1, 128, 7, 7] 118,784
BatchNorm2d-411 [-1, 128, 7, 7] 256
ReLU-412 [-1, 128, 7, 7] 0
Conv2d-413 [-1, 32, 7, 7] 36,864
ConvBlock-414 [-1, 32, 7, 7] 0
BatchNorm2d-415 [-1, 960, 7, 7] 1,920
ReLU-416 [-1, 960, 7, 7] 0
Conv2d-417 [-1, 128, 7, 7] 122,880
BatchNorm2d-418 [-1, 128, 7, 7] 256
ReLU-419 [-1, 128, 7, 7] 0
Conv2d-420 [-1, 32, 7, 7] 36,864
ConvBlock-421 [-1, 32, 7, 7] 0
BatchNorm2d-422 [-1, 992, 7, 7] 1,984
ReLU-423 [-1, 992, 7, 7] 0
Conv2d-424 [-1, 128, 7, 7] 126,976
BatchNorm2d-425 [-1, 128, 7, 7] 256
ReLU-426 [-1, 128, 7, 7] 0
Conv2d-427 [-1, 32, 7, 7] 36,864
ConvBlock-428 [-1, 32, 7, 7] 0
DenseBlock-429 [-1, 1024, 7, 7] 0
BatchNorm2d-430 [-1, 1024, 7, 7] 2,048
AdaptiveAvgPool2d-431 [-1, 1024, 1, 1] 0
Linear-432 [-1, 64] 65,536
ReLU-433 [-1, 64] 0
Linear-434 [-1, 1024] 65,536
Sigmoid-435 [-1, 1024] 0
SELayer-436 [-1, 1024, 7, 7] 0
Linear-437 [-1, 3] 3,075
================================================================
Total params: 7,088,003
Trainable params: 7,088,003
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 315.27
Params size (MB): 27.04
Estimated Total Size (MB): 342.88
----------------------------------------------------------------
V. Breast Cancer Recognition and Comparison of the Networks
The data preprocessing is exactly the same as in previous weeks, so that code is not repeated here. A sketch of the train/test helpers is given below, followed by the training code itself:
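Here is a minimal sketch of the train and test helpers the loops rely on, written the way they were defined in earlier weeks; the data loaders train_dl, test_dl, validate_dl and the device variable are assumed to come from the omitted preprocessing code.

# Sketch of the train/test helpers assumed by the training loops below
# (train_dl, test_dl and device come from the omitted preprocessing code).
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    train_acc, train_loss = 0, 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
    return train_acc / size, train_loss / num_batches

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_acc, test_loss = 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
            test_loss += loss_fn(pred, y).item()
    return test_acc / size, test_loss / num_batches

And now the training loops themselves: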
import copy
from torch.optim.lr_scheduler import ReduceLROnPlateau
opt1 = torch.optim.Adam(model1.parameters(), lr= 1e-4)
scheduler1 = ReduceLROnPlateau(opt1, mode='min', factor=0.1, patience=5, verbose=True) # multiply the learning rate by 0.1 when the monitored metric (here the test loss) fails to improve for 5 consecutive epochs
opt2 = torch.optim.Adam(model2.parameters(), lr= 1e-4)
scheduler2 = ReduceLROnPlateau(opt2, mode='min', factor=0.1, patience=5, verbose=True)
opt3 = torch.optim.Adam(model3.parameters(), lr= 1e-4)
scheduler3 = ReduceLROnPlateau(opt3, mode='min', factor=0.1, patience=5, verbose=True)
loss_fn = nn.CrossEntropyLoss() # cross-entropy loss
epochs = 32
train_loss = []
train_acc = []
test_loss = []
test_acc = []
best_acc1 = 0 # best test accuracy so far, used to select the best checkpoint for each model
best_acc2 = 0
best_acc3 = 0
for epoch in range(epochs):
model1.train()
epoch_train_acc, epoch_train_loss = train(train_dl, model1, loss_fn, opt1)
model1.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, model1, loss_fn)
scheduler1.step(epoch_test_loss)
if epoch_test_acc > best_acc1:
best_acc1 = epoch_test_acc
best_model1 = copy.deepcopy(model1)
train_acc.append(epoch_train_acc)
train_loss.append(epoch_train_loss)
test_acc.append(epoch_test_acc)
test_loss.append(epoch_test_loss)
    # get the current learning rate
lr = opt1.state_dict()['param_groups'][0]['lr']
template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,
epoch_test_acc*100, epoch_test_loss, lr))
# save the best model's weights to a file
PATH = './best_model1.pth' # filename for the saved weights
torch.save(best_model1.state_dict(), PATH)
print('Done1')
for epoch in range(epochs):
model2.train()
epoch_train_acc, epoch_train_loss = train(train_dl, model2, loss_fn, opt2)
model2.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, model2, loss_fn)
scheduler2.step(epoch_test_loss)
if epoch_test_acc > best_acc2:
best_acc2 = epoch_test_acc
best_model2 = copy.deepcopy(model2)
train_acc.append(epoch_train_acc)
train_loss.append(epoch_train_loss)
test_acc.append(epoch_test_acc)
test_loss.append(epoch_test_loss)
    # get the current learning rate
lr = opt2.state_dict()['param_groups'][0]['lr']
template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,
epoch_test_acc*100, epoch_test_loss, lr))
# save the best model's weights to a file
PATH = './best_model2.pth' # filename for the saved weights
torch.save(best_model2.state_dict(), PATH)
print('Done2')
for epoch in range(epochs):
model3.train()
epoch_train_acc, epoch_train_loss = train(train_dl, model3, loss_fn, opt3)
model3.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, model3, loss_fn)
scheduler3.step(epoch_test_loss)
if epoch_test_acc > best_acc3:
best_acc3 = epoch_test_acc
best_model3 = copy.deepcopy(model3)
train_acc.append(epoch_train_acc)
train_loss.append(epoch_train_loss)
test_acc.append(epoch_test_acc)
test_loss.append(epoch_test_loss)
    # get the current learning rate
lr = opt3.state_dict()['param_groups'][0]['lr']
template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,
epoch_test_acc*100, epoch_test_loss, lr))
# save the best model's weights to a file
PATH = './best_model3.pth' # filename for the saved weights
torch.save(best_model3.state_dict(), PATH)
print('Done3')
Output:
Epoch: 1, Train_acc:94.9%, Train_loss:0.132, Test_acc:94.2%, Test_loss:0.153, Lr:1.00E-04
Epoch: 2, Train_acc:95.5%, Train_loss:0.113, Test_acc:95.1%, Test_loss:0.137, Lr:1.00E-04
Epoch: 3, Train_acc:96.4%, Train_loss:0.098, Test_acc:94.4%, Test_loss:0.145, Lr:1.00E-04
Epoch: 4, Train_acc:96.4%, Train_loss:0.101, Test_acc:91.0%, Test_loss:0.317, Lr:1.00E-04
Epoch: 5, Train_acc:96.7%, Train_loss:0.092, Test_acc:93.6%, Test_loss:0.193, Lr:1.00E-04
Epoch: 6, Train_acc:96.9%, Train_loss:0.079, Test_acc:91.9%, Test_loss:0.217, Lr:1.00E-04
Epoch: 7, Train_acc:97.5%, Train_loss:0.065, Test_acc:95.1%, Test_loss:0.153, Lr:1.00E-04
Epoch: 8, Train_acc:97.7%, Train_loss:0.062, Test_acc:93.9%, Test_loss:0.154, Lr:1.00E-05
Epoch: 9, Train_acc:99.1%, Train_loss:0.031, Test_acc:96.2%, Test_loss:0.115, Lr:1.00E-05
Epoch:10, Train_acc:99.6%, Train_loss:0.016, Test_acc:96.4%, Test_loss:0.112, Lr:1.00E-05
Epoch:11, Train_acc:99.7%, Train_loss:0.015, Test_acc:96.3%, Test_loss:0.116, Lr:1.00E-05
Epoch:12, Train_acc:99.7%, Train_loss:0.012, Test_acc:96.5%, Test_loss:0.115, Lr:1.00E-05
Epoch:13, Train_acc:99.9%, Train_loss:0.008, Test_acc:96.5%, Test_loss:0.123, Lr:1.00E-05
Epoch:14, Train_acc:99.8%, Train_loss:0.013, Test_acc:96.3%, Test_loss:0.126, Lr:1.00E-05
Epoch:15, Train_acc:99.8%, Train_loss:0.008, Test_acc:96.2%, Test_loss:0.125, Lr:1.00E-05
Epoch:16, Train_acc:99.9%, Train_loss:0.006, Test_acc:96.3%, Test_loss:0.134, Lr:1.00E-06
Epoch:17, Train_acc:99.9%, Train_loss:0.006, Test_acc:96.4%, Test_loss:0.126, Lr:1.00E-06
Epoch:18, Train_acc:99.9%, Train_loss:0.006, Test_acc:96.5%, Test_loss:0.126, Lr:1.00E-06
Epoch:19, Train_acc:99.9%, Train_loss:0.005, Test_acc:96.5%, Test_loss:0.125, Lr:1.00E-06
Epoch:20, Train_acc:99.9%, Train_loss:0.005, Test_acc:96.4%, Test_loss:0.125, Lr:1.00E-06
Epoch:21, Train_acc:99.9%, Train_loss:0.007, Test_acc:96.6%, Test_loss:0.121, Lr:1.00E-06
Epoch:22, Train_acc:100.0%, Train_loss:0.005, Test_acc:96.4%, Test_loss:0.126, Lr:1.00E-07
Epoch:23, Train_acc:99.9%, Train_loss:0.004, Test_acc:96.5%, Test_loss:0.122, Lr:1.00E-07
Epoch:24, Train_acc:99.9%, Train_loss:0.011, Test_acc:96.2%, Test_loss:0.130, Lr:1.00E-07
Epoch:25, Train_acc:100.0%, Train_loss:0.004, Test_acc:96.6%, Test_loss:0.127, Lr:1.00E-07
Epoch:26, Train_acc:99.9%, Train_loss:0.005, Test_acc:96.6%, Test_loss:0.126, Lr:1.00E-07
Epoch:27, Train_acc:99.9%, Train_loss:0.004, Test_acc:96.6%, Test_loss:0.123, Lr:1.00E-07
Epoch:28, Train_acc:99.9%, Train_loss:0.005, Test_acc:96.5%, Test_loss:0.120, Lr:1.00E-08
Epoch:29, Train_acc:99.9%, Train_loss:0.005, Test_acc:96.6%, Test_loss:0.122, Lr:1.00E-08
Epoch:30, Train_acc:99.9%, Train_loss:0.005, Test_acc:96.5%, Test_loss:0.125, Lr:1.00E-08
Epoch:31, Train_acc:99.9%, Train_loss:0.005, Test_acc:96.2%, Test_loss:0.127, Lr:1.00E-08
Epoch:32, Train_acc:100.0%, Train_loss:0.004, Test_acc:96.6%, Test_loss:0.121, Lr:1.00E-08
Done1
Epoch: 1, Train_acc:93.5%, Train_loss:0.161, Test_acc:90.6%, Test_loss:0.243, Lr:1.00E-04
Epoch: 2, Train_acc:94.3%, Train_loss:0.145, Test_acc:93.8%, Test_loss:0.183, Lr:1.00E-04
Epoch: 3, Train_acc:94.6%, Train_loss:0.132, Test_acc:92.2%, Test_loss:0.213, Lr:1.00E-04
Epoch: 4, Train_acc:95.0%, Train_loss:0.126, Test_acc:89.1%, Test_loss:0.299, Lr:1.00E-04
Epoch: 5, Train_acc:95.5%, Train_loss:0.115, Test_acc:89.7%, Test_loss:0.248, Lr:1.00E-04
Epoch: 6, Train_acc:96.2%, Train_loss:0.102, Test_acc:94.1%, Test_loss:0.152, Lr:1.00E-04
Epoch: 7, Train_acc:96.4%, Train_loss:0.093, Test_acc:93.9%, Test_loss:0.172, Lr:1.00E-04
Epoch: 8, Train_acc:96.8%, Train_loss:0.077, Test_acc:94.1%, Test_loss:0.165, Lr:1.00E-04
Epoch: 9, Train_acc:97.3%, Train_loss:0.072, Test_acc:92.5%, Test_loss:0.268, Lr:1.00E-04
Epoch:10, Train_acc:97.8%, Train_loss:0.059, Test_acc:90.1%, Test_loss:0.384, Lr:1.00E-04
Epoch:11, Train_acc:97.6%, Train_loss:0.064, Test_acc:94.0%, Test_loss:0.204, Lr:1.00E-04
Epoch:12, Train_acc:97.5%, Train_loss:0.063, Test_acc:92.9%, Test_loss:0.234, Lr:1.00E-05
Epoch:13, Train_acc:99.4%, Train_loss:0.023, Test_acc:94.8%, Test_loss:0.162, Lr:1.00E-05
Epoch:14, Train_acc:99.8%, Train_loss:0.013, Test_acc:95.1%, Test_loss:0.161, Lr:1.00E-05
Epoch:15, Train_acc:99.7%, Train_loss:0.012, Test_acc:95.1%, Test_loss:0.164, Lr:1.00E-05
Epoch:16, Train_acc:99.7%, Train_loss:0.010, Test_acc:94.8%, Test_loss:0.179, Lr:1.00E-05
Epoch:17, Train_acc:99.9%, Train_loss:0.009, Test_acc:94.9%, Test_loss:0.163, Lr:1.00E-05
Epoch:18, Train_acc:99.8%, Train_loss:0.007, Test_acc:95.1%, Test_loss:0.168, Lr:1.00E-06
Epoch:19, Train_acc:99.9%, Train_loss:0.006, Test_acc:95.1%, Test_loss:0.175, Lr:1.00E-06
Epoch:20, Train_acc:99.9%, Train_loss:0.006, Test_acc:95.1%, Test_loss:0.174, Lr:1.00E-06
Epoch:21, Train_acc:99.9%, Train_loss:0.005, Test_acc:94.7%, Test_loss:0.179, Lr:1.00E-06
Epoch:22, Train_acc:100.0%, Train_loss:0.006, Test_acc:94.8%, Test_loss:0.188, Lr:1.00E-06
Epoch:23, Train_acc:99.9%, Train_loss:0.005, Test_acc:95.1%, Test_loss:0.179, Lr:1.00E-06
Epoch:24, Train_acc:99.9%, Train_loss:0.005, Test_acc:94.8%, Test_loss:0.178, Lr:1.00E-07
Epoch:25, Train_acc:99.9%, Train_loss:0.003, Test_acc:95.1%, Test_loss:0.181, Lr:1.00E-07
Epoch:26, Train_acc:99.9%, Train_loss:0.005, Test_acc:95.2%, Test_loss:0.179, Lr:1.00E-07
Epoch:27, Train_acc:99.9%, Train_loss:0.009, Test_acc:95.1%, Test_loss:0.191, Lr:1.00E-07
Epoch:28, Train_acc:99.9%, Train_loss:0.004, Test_acc:94.9%, Test_loss:0.185, Lr:1.00E-07
Epoch:29, Train_acc:99.9%, Train_loss:0.004, Test_acc:95.1%, Test_loss:0.191, Lr:1.00E-07
Epoch:30, Train_acc:100.0%, Train_loss:0.004, Test_acc:95.3%, Test_loss:0.175, Lr:1.00E-08
Epoch:31, Train_acc:100.0%, Train_loss:0.004, Test_acc:95.2%, Test_loss:0.175, Lr:1.00E-08
Epoch:32, Train_acc:99.9%, Train_loss:0.004, Test_acc:95.1%, Test_loss:0.179, Lr:1.00E-08
Done2
Epoch: 1, Train_acc:89.5%, Train_loss:0.254, Test_acc:89.7%, Test_loss:0.263, Lr:1.00E-04
Epoch: 2, Train_acc:89.7%, Train_loss:0.248, Test_acc:88.5%, Test_loss:0.290, Lr:1.00E-04
Epoch: 3, Train_acc:91.2%, Train_loss:0.222, Test_acc:91.3%, Test_loss:0.222, Lr:1.00E-04
Epoch: 4, Train_acc:91.2%, Train_loss:0.214, Test_acc:84.8%, Test_loss:0.399, Lr:1.00E-04
Epoch: 5, Train_acc:91.9%, Train_loss:0.205, Test_acc:91.0%, Test_loss:0.233, Lr:1.00E-04
Epoch: 6, Train_acc:92.0%, Train_loss:0.193, Test_acc:90.0%, Test_loss:0.267, Lr:1.00E-04
Epoch: 7, Train_acc:92.1%, Train_loss:0.188, Test_acc:89.4%, Test_loss:0.255, Lr:1.00E-04
Epoch: 8, Train_acc:92.9%, Train_loss:0.180, Test_acc:92.0%, Test_loss:0.193, Lr:1.00E-04
Epoch: 9, Train_acc:93.4%, Train_loss:0.165, Test_acc:89.0%, Test_loss:0.300, Lr:1.00E-04
Epoch:10, Train_acc:93.3%, Train_loss:0.169, Test_acc:86.3%, Test_loss:0.332, Lr:1.00E-04
Epoch:11, Train_acc:94.1%, Train_loss:0.145, Test_acc:91.0%, Test_loss:0.217, Lr:1.00E-04
Epoch:12, Train_acc:94.2%, Train_loss:0.146, Test_acc:89.7%, Test_loss:0.267, Lr:1.00E-04
Epoch:13, Train_acc:94.6%, Train_loss:0.136, Test_acc:90.7%, Test_loss:0.244, Lr:1.00E-04
Epoch:14, Train_acc:94.9%, Train_loss:0.131, Test_acc:92.0%, Test_loss:0.226, Lr:1.00E-05
Epoch:15, Train_acc:97.0%, Train_loss:0.083, Test_acc:93.0%, Test_loss:0.180, Lr:1.00E-05
Epoch:16, Train_acc:97.9%, Train_loss:0.061, Test_acc:93.2%, Test_loss:0.181, Lr:1.00E-05
Epoch:17, Train_acc:98.0%, Train_loss:0.056, Test_acc:93.2%, Test_loss:0.191, Lr:1.00E-05
Epoch:18, Train_acc:98.5%, Train_loss:0.047, Test_acc:92.9%, Test_loss:0.199, Lr:1.00E-05
Epoch:19, Train_acc:98.6%, Train_loss:0.041, Test_acc:93.0%, Test_loss:0.214, Lr:1.00E-05
Epoch:20, Train_acc:99.0%, Train_loss:0.031, Test_acc:93.1%, Test_loss:0.211, Lr:1.00E-05
Epoch:21, Train_acc:99.1%, Train_loss:0.029, Test_acc:93.0%, Test_loss:0.230, Lr:1.00E-06
Epoch:22, Train_acc:99.4%, Train_loss:0.024, Test_acc:92.7%, Test_loss:0.231, Lr:1.00E-06
Epoch:23, Train_acc:99.4%, Train_loss:0.022, Test_acc:93.2%, Test_loss:0.223, Lr:1.00E-06
Epoch:24, Train_acc:99.4%, Train_loss:0.022, Test_acc:92.9%, Test_loss:0.225, Lr:1.00E-06
Epoch:25, Train_acc:99.3%, Train_loss:0.024, Test_acc:92.9%, Test_loss:0.236, Lr:1.00E-06
Epoch:26, Train_acc:99.5%, Train_loss:0.019, Test_acc:92.9%, Test_loss:0.232, Lr:1.00E-06
Epoch:27, Train_acc:99.4%, Train_loss:0.021, Test_acc:93.0%, Test_loss:0.235, Lr:1.00E-07
Epoch:28, Train_acc:99.4%, Train_loss:0.024, Test_acc:92.8%, Test_loss:0.236, Lr:1.00E-07
Epoch:29, Train_acc:99.5%, Train_loss:0.019, Test_acc:93.1%, Test_loss:0.231, Lr:1.00E-07
Epoch:30, Train_acc:99.5%, Train_loss:0.018, Test_acc:93.0%, Test_loss:0.233, Lr:1.00E-07
Epoch:31, Train_acc:99.5%, Train_loss:0.025, Test_acc:93.2%, Test_loss:0.234, Lr:1.00E-07
Epoch:32, Train_acc:99.5%, Train_loss:0.018, Test_acc:93.0%, Test_loss:0.236, Lr:1.00E-07
Done3
Compare the best test-set accuracy reached by the three models:
print(best_acc1, best_acc2, best_acc3)
Output:
0.9664179104477612 0.953150912106136 0.9324212271973465
DenseNet121 gives the best test accuracy, but this is based on a single training run, so the conclusion should not be taken as absolute.
Now look at the validation-set accuracy:
def validate(dataloader, model):
    model.eval()
    size = len(dataloader.dataset)
    validate_acc = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            validate_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
    validate_acc /= size
    return validate_acc
# compute accuracy on the validation set
validate_acc1 = validate(validate_dl, best_model1)
validate_acc2 = validate(validate_dl, best_model2)
validate_acc3 = validate(validate_dl, best_model3)
print(f"Validation Accuracy: {validate_acc1:.2%},{validate_acc2:.2%},{validate_acc3:.2%}")
Output:
Validation Accuracy: 96.08%,94.97%,93.47%
Again, DenseNet121 achieves the best accuracy.
VI. Summary and Discussion
SE-Net adds learned weights to the feature channels, which helps the model train better. Compared with plain DenseNet on the same dataset, it gives a noticeably better result: the plain DenseNet we trained earlier reached 93.23% validation accuracy, while the SE-enhanced version now reaches 96.08%.
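For completeness, a small sketch of how the saved best weights could be reloaded later for inference (my own assumption about the workflow; it reuses the validate() helper, validate_dl, and device defined or assumed above):

# Sketch: reload the best DenseNet121+SE checkpoint saved during training.
model = DenseNet121(num_classes=3).to(device)
model.load_state_dict(torch.load('./best_model1.pth', map_location=device))
model.eval()
acc = validate(validate_dl, model)  # reuse the validation helper defined above
print(f"Reloaded model validation accuracy: {acc:.2%}")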