Deep Learning Notes 31: Exploring a Combination of ResNet and DenseNet
- 🍨 This post is a learning log for the 🔗365-day deep learning training camp
- 🍖 Original author: K同学啊 | tutoring and custom projects available
I. My Environment
1. Language: Python 3.9
2. IDE: PyCharm
3. Deep learning framework: PyTorch
II. GPU Setup
# Use the GPU if one is available, otherwise fall back to the CPU
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
III. Importing the Data
import copy
import pathlib
import warnings
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
warnings.filterwarnings("ignore")
data_dir = "data/bird_photos"
data_dir = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
data_paths = list(data_dir.glob("*"))
classnames = [path.name for path in data_paths]  # each sub-directory name is a class label
print(classnames)
#['Bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']
num_classes = len(classnames)
print(num_classes)
#4
IV. Data Preprocessing
'''Image transforms'''
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # resize every image to 224x224 so batches have a uniform shape
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet channel means
        std=[0.229, 0.224, 0.225],   # ImageNet channel standard deviations
    )
])
total_data = datasets.ImageFolder(data_dir, transform=train_transforms)
# Split the dataset
train_size = int(0.8 * len(total_data))   # training set size: 80% of all samples
test_size = len(total_data) - train_size  # test set size: the remaining samples
# torch.utils.data.random_split() randomly partitions total_data into subsets of the requested
# sizes ([train_size, test_size]) and returns them as train_ds and test_ds.
train_ds, test_ds = random_split(total_data, [train_size, test_size])
V. Loading the Data
batch_size = 8
train_dl = DataLoader(train_ds,
                      batch_size=batch_size,
                      shuffle=True,
                      num_workers=1)
test_dl = DataLoader(test_ds,
                     batch_size=batch_size,
                     shuffle=True,
                     num_workers=1)
Data visualization
import random
from PIL import Image

# Collect all image paths; each image sits in a sub-directory named after its class
img_list = list(data_dir.glob('*/*'))

plt.figure(figsize=(16, 10))
# plt.title("Dataset")
for i in range(20):
    plt.subplot(4, 5, i + 1)
    plt.axis("off")
    image = random.choice(img_list)
    label_name = image.parts[-2]  # the parent directory name is the class label
    plt.title(label_name)
    plt.imshow(Image.open(str(image)))
plt.show()
Check the data again
for X, y in test_dl:
    print("Shape of X [N, C, H, W]:", X.shape)
    print("Shape of y:", y.shape, y.dtype)
    break
VI. Building the Model
class Block(nn.Module):
"""
    param : in_channel    -- number of input channels
            mid_channel   -- number of channels in the intermediate (bottleneck) layers
            out_channel   -- channels used by the ResNet path (merged by summation, so this part of the output still has out_channel channels)
            dense_channel -- channels used by the DenseNet path (merged by concatenation, so this part of the output has 2*dense_channel channels)
            groups        -- number of groups for the grouped convolution in conv2
            is_shortcut   -- whether to apply a projection shortcut before the residual addition
"""
def __init__(self, in_channel, mid_channel, out_channel, dense_channel, stride, groups, is_shortcut=False):
super(Block, self).__init__()
self.is_shortcut = is_shortcut
self.out_channel = out_channel
self.conv1 = nn.Sequential(
nn.Conv2d(in_channel, mid_channel, kernel_size=1, bias=False),
nn.BatchNorm2d(mid_channel),
nn.ReLU()
)
self.conv2 = nn.Sequential(
nn.Conv2d(mid_channel, mid_channel, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False),
nn.BatchNorm2d(mid_channel),
nn.ReLU()
)
self.conv3 = nn.Sequential(
nn.Conv2d(mid_channel, out_channel + dense_channel, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channel + dense_channel)
)
if self.is_shortcut:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channel, out_channel + dense_channel, kernel_size=3, padding=1, stride=stride, bias=False),
nn.BatchNorm2d(out_channel + dense_channel)
)
self.relu = nn.ReLU(inplace=True)
    # a[:, :self.out_channel, :, :] + x[:, :self.out_channel, :, :] is the ResNet path: the two
    # feature maps are summed element-wise, so this part keeps out_channel channels.
    # a[:, self.out_channel:, :, :] and x[:, self.out_channel:, :, :] form the DenseNet path: they
    # are concatenated along the channel dimension, so this part has 2*dense_channel channels.
    # Concatenating the two parts gives out_channel + 2*dense_channel output channels in total
    # (a short sanity check follows the class definition).
def forward(self, x):
a = x
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
if self.is_shortcut:
a = self.shortcut(a)
x = torch.cat([a[:, :self.out_channel, :, :] + x[:, :self.out_channel, :, :], a[:, self.out_channel:, :, :],
x[:, self.out_channel:, :, :]], dim=1)
x = self.relu(x)
return x
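As a quick sanity check on the channel bookkeeping above, the small hypothetical snippet below (not part of the original post, assuming the Block class defined here) feeds a dummy tensor through one Block and confirms that the output has out_channel + 2*dense_channel channels:
# Hypothetical sanity check for the dual-path Block (illustration only).
import torch

blk = Block(in_channel=64, mid_channel=96, out_channel=256, dense_channel=16,
            stride=1, groups=32, is_shortcut=True)
dummy = torch.randn(2, 64, 56, 56)  # (N, C, H, W)
out = blk(dummy)
print(out.shape)  # expected: torch.Size([2, 288, 56, 56])
# 288 = out_channel (256, ResNet sum path) + 2 * dense_channel (2 * 16, DenseNet concat path)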
class DPN(nn.Module):
def __init__(self, cfg):
super(DPN, self).__init__()
self.group = cfg['group']
self.in_channel = cfg['in_channel']
mid_channels = cfg['mid_channels']
out_channels = cfg['out_channels']
dense_channels = cfg['dense_channels']
num = cfg['num']
self.conv1 = nn.Sequential(
nn.Conv2d(3, self.in_channel, 7, stride=2, padding=3, bias=False, padding_mode='zeros'),
nn.BatchNorm2d(self.in_channel),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
)
self.conv2 = self._make_layers(mid_channels[0], out_channels[0], dense_channels[0], num[0], stride=1)
self.conv3 = self._make_layers(mid_channels[1], out_channels[1], dense_channels[1], num[1], stride=2)
self.conv4 = self._make_layers(mid_channels[2], out_channels[2], dense_channels[2], num[2], stride=2)
self.conv5 = self._make_layers(mid_channels[3], out_channels[3], dense_channels[3], num[3], stride=2)
self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(cfg['out_channels'][3] + (num[3] + 1) * cfg['dense_channels'][3], cfg['classes'])  # the fc input size must be computed: final out_channels plus (num+1)*dense_channels
def _make_layers(self, mid_channel, out_channel, dense_channel, num, stride):
layers = []
        # is_shortcut=True applies a projection shortcut: the shallow features go through one extra
        # convolution and are then combined with the third convolution's output by both addition
        # (ResNet style) and concatenation (DenseNet style). Only the first Block of a stage needs
        # the shallow features in this way; the following Blocks keep the default is_shortcut=False.
layers.append(Block(self.in_channel, mid_channel, out_channel, dense_channel, stride=stride, groups=self.group,
is_shortcut=True))
self.in_channel = out_channel + dense_channel * 2
for i in range(1, num):
layers.append(Block(self.in_channel, mid_channel, out_channel, dense_channel, stride=1, groups=self.group))
            # Because the Block concatenates DenseNet features, the first Block adds 2*dense_channel
            # channels and every later Block adds another dense_channel (a worked trace follows the DPN class).
self.in_channel += dense_channel
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.conv5(x)
x = self.pool(x)
x = torch.flatten(x, start_dim=1)
x = self.fc(x)
return x
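To make the _make_layers bookkeeping concrete, here is a small, hypothetical trace (not from the original post) of how self.in_channel evolves in the first stage of DPN92 (out_channel=256, dense_channel=16, num=3); the numbers match the Block output shapes in the summary below (288, 304, 320):
# Illustration only: channel growth across one DPN stage.
out_channel, dense_channel, num = 256, 16, 3
in_channel = out_channel + dense_channel * 2  # after the first Block: 256 + 32 = 288
print(in_channel)                             # 288
for _ in range(1, num):
    in_channel += dense_channel               # each later Block concatenates dense_channel more
    print(in_channel)                         # 304, then 320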
def DPN92(n_class=4):
cfg = {
"group": 32,
"in_channel": 64,
"mid_channels": (96, 192, 384, 768),
"out_channels": (256, 512, 1024, 2048),
"dense_channels": (16, 32, 24, 128),
"num": (3, 4, 20, 3),
"classes": (n_class)
}
return DPN(cfg)
def DPN98(n_class=4):
cfg = {
"group": 40,
"in_channel": 96,
"mid_channels": (160, 320, 640, 1280),
"out_channels": (256, 512, 1024, 2048),
"dense_channels": (16, 32, 32, 128),
"num": (3, 6, 20, 3),
"classes": (n_class)
}
return DPN(cfg)
"""搭建DPN92模型"""
model = DPN92().to(device)
summary(model, (3, 224, 224)) # 查看模型的参数量以及相关指标
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 55, 55] 0
Conv2d-5 [-1, 96, 55, 55] 6,144
BatchNorm2d-6 [-1, 96, 55, 55] 192
ReLU-7 [-1, 96, 55, 55] 0
Conv2d-8 [-1, 96, 55, 55] 2,592
BatchNorm2d-9 [-1, 96, 55, 55] 192
ReLU-10 [-1, 96, 55, 55] 0
Conv2d-11 [-1, 272, 55, 55] 26,112
BatchNorm2d-12 [-1, 272, 55, 55] 544
Conv2d-13 [-1, 272, 55, 55] 156,672
BatchNorm2d-14 [-1, 272, 55, 55] 544
ReLU-15 [-1, 288, 55, 55] 0
Block-16 [-1, 288, 55, 55] 0
Conv2d-17 [-1, 96, 55, 55] 27,648
BatchNorm2d-18 [-1, 96, 55, 55] 192
ReLU-19 [-1, 96, 55, 55] 0
Conv2d-20 [-1, 96, 55, 55] 2,592
BatchNorm2d-21 [-1, 96, 55, 55] 192
ReLU-22 [-1, 96, 55, 55] 0
Conv2d-23 [-1, 272, 55, 55] 26,112
BatchNorm2d-24 [-1, 272, 55, 55] 544
ReLU-25 [-1, 304, 55, 55] 0
Block-26 [-1, 304, 55, 55] 0
Conv2d-27 [-1, 96, 55, 55] 29,184
BatchNorm2d-28 [-1, 96, 55, 55] 192
ReLU-29 [-1, 96, 55, 55] 0
Conv2d-30 [-1, 96, 55, 55] 2,592
BatchNorm2d-31 [-1, 96, 55, 55] 192
ReLU-32 [-1, 96, 55, 55] 0
Conv2d-33 [-1, 272, 55, 55] 26,112
BatchNorm2d-34 [-1, 272, 55, 55] 544
ReLU-35 [-1, 320, 55, 55] 0
Block-36 [-1, 320, 55, 55] 0
Conv2d-37 [-1, 192, 55, 55] 61,440
BatchNorm2d-38 [-1, 192, 55, 55] 384
ReLU-39 [-1, 192, 55, 55] 0
Conv2d-40 [-1, 192, 28, 28] 10,368
BatchNorm2d-41 [-1, 192, 28, 28] 384
ReLU-42 [-1, 192, 28, 28] 0
Conv2d-43 [-1, 544, 28, 28] 104,448
BatchNorm2d-44 [-1, 544, 28, 28] 1,088
Conv2d-45 [-1, 544, 28, 28] 1,566,720
BatchNorm2d-46 [-1, 544, 28, 28] 1,088
ReLU-47 [-1, 576, 28, 28] 0
Block-48 [-1, 576, 28, 28] 0
Conv2d-49 [-1, 192, 28, 28] 110,592
BatchNorm2d-50 [-1, 192, 28, 28] 384
ReLU-51 [-1, 192, 28, 28] 0
Conv2d-52 [-1, 192, 28, 28] 10,368
BatchNorm2d-53 [-1, 192, 28, 28] 384
ReLU-54 [-1, 192, 28, 28] 0
Conv2d-55 [-1, 544, 28, 28] 104,448
BatchNorm2d-56 [-1, 544, 28, 28] 1,088
ReLU-57 [-1, 608, 28, 28] 0
Block-58 [-1, 608, 28, 28] 0
Conv2d-59 [-1, 192, 28, 28] 116,736
BatchNorm2d-60 [-1, 192, 28, 28] 384
ReLU-61 [-1, 192, 28, 28] 0
Conv2d-62 [-1, 192, 28, 28] 10,368
BatchNorm2d-63 [-1, 192, 28, 28] 384
ReLU-64 [-1, 192, 28, 28] 0
Conv2d-65 [-1, 544, 28, 28] 104,448
BatchNorm2d-66 [-1, 544, 28, 28] 1,088
ReLU-67 [-1, 640, 28, 28] 0
Block-68 [-1, 640, 28, 28] 0
Conv2d-69 [-1, 192, 28, 28] 122,880
BatchNorm2d-70 [-1, 192, 28, 28] 384
ReLU-71 [-1, 192, 28, 28] 0
Conv2d-72 [-1, 192, 28, 28] 10,368
BatchNorm2d-73 [-1, 192, 28, 28] 384
ReLU-74 [-1, 192, 28, 28] 0
Conv2d-75 [-1, 544, 28, 28] 104,448
BatchNorm2d-76 [-1, 544, 28, 28] 1,088
ReLU-77 [-1, 672, 28, 28] 0
Block-78 [-1, 672, 28, 28] 0
Conv2d-79 [-1, 384, 28, 28] 258,048
BatchNorm2d-80 [-1, 384, 28, 28] 768
ReLU-81 [-1, 384, 28, 28] 0
Conv2d-82 [-1, 384, 14, 14] 41,472
BatchNorm2d-83 [-1, 384, 14, 14] 768
ReLU-84 [-1, 384, 14, 14] 0
Conv2d-85 [-1, 1048, 14, 14] 402,432
BatchNorm2d-86 [-1, 1048, 14, 14] 2,096
Conv2d-87 [-1, 1048, 14, 14] 6,338,304
BatchNorm2d-88 [-1, 1048, 14, 14] 2,096
ReLU-89 [-1, 1072, 14, 14] 0
Block-90 [-1, 1072, 14, 14] 0
Conv2d-91 [-1, 384, 14, 14] 411,648
BatchNorm2d-92 [-1, 384, 14, 14] 768
ReLU-93 [-1, 384, 14, 14] 0
Conv2d-94 [-1, 384, 14, 14] 41,472
BatchNorm2d-95 [-1, 384, 14, 14] 768
ReLU-96 [-1, 384, 14, 14] 0
Conv2d-97 [-1, 1048, 14, 14] 402,432
BatchNorm2d-98 [-1, 1048, 14, 14] 2,096
ReLU-99 [-1, 1096, 14, 14] 0
Block-100 [-1, 1096, 14, 14] 0
Conv2d-101 [-1, 384, 14, 14] 420,864
BatchNorm2d-102 [-1, 384, 14, 14] 768
ReLU-103 [-1, 384, 14, 14] 0
Conv2d-104 [-1, 384, 14, 14] 41,472
BatchNorm2d-105 [-1, 384, 14, 14] 768
ReLU-106 [-1, 384, 14, 14] 0
Conv2d-107 [-1, 1048, 14, 14] 402,432
BatchNorm2d-108 [-1, 1048, 14, 14] 2,096
ReLU-109 [-1, 1120, 14, 14] 0
Block-110 [-1, 1120, 14, 14] 0
Conv2d-111 [-1, 384, 14, 14] 430,080
BatchNorm2d-112 [-1, 384, 14, 14] 768
ReLU-113 [-1, 384, 14, 14] 0
Conv2d-114 [-1, 384, 14, 14] 41,472
BatchNorm2d-115 [-1, 384, 14, 14] 768
ReLU-116 [-1, 384, 14, 14] 0
Conv2d-117 [-1, 1048, 14, 14] 402,432
BatchNorm2d-118 [-1, 1048, 14, 14] 2,096
ReLU-119 [-1, 1144, 14, 14] 0
Block-120 [-1, 1144, 14, 14] 0
Conv2d-121 [-1, 384, 14, 14] 439,296
BatchNorm2d-122 [-1, 384, 14, 14] 768
ReLU-123 [-1, 384, 14, 14] 0
Conv2d-124 [-1, 384, 14, 14] 41,472
BatchNorm2d-125 [-1, 384, 14, 14] 768
ReLU-126 [-1, 384, 14, 14] 0
Conv2d-127 [-1, 1048, 14, 14] 402,432
BatchNorm2d-128 [-1, 1048, 14, 14] 2,096
ReLU-129 [-1, 1168, 14, 14] 0
Block-130 [-1, 1168, 14, 14] 0
Conv2d-131 [-1, 384, 14, 14] 448,512
BatchNorm2d-132 [-1, 384, 14, 14] 768
ReLU-133 [-1, 384, 14, 14] 0
Conv2d-134 [-1, 384, 14, 14] 41,472
BatchNorm2d-135 [-1, 384, 14, 14] 768
ReLU-136 [-1, 384, 14, 14] 0
Conv2d-137 [-1, 1048, 14, 14] 402,432
BatchNorm2d-138 [-1, 1048, 14, 14] 2,096
ReLU-139 [-1, 1192, 14, 14] 0
Block-140 [-1, 1192, 14, 14] 0
Conv2d-141 [-1, 384, 14, 14] 457,728
BatchNorm2d-142 [-1, 384, 14, 14] 768
ReLU-143 [-1, 384, 14, 14] 0
Conv2d-144 [-1, 384, 14, 14] 41,472
BatchNorm2d-145 [-1, 384, 14, 14] 768
ReLU-146 [-1, 384, 14, 14] 0
Conv2d-147 [-1, 1048, 14, 14] 402,432
BatchNorm2d-148 [-1, 1048, 14, 14] 2,096
ReLU-149 [-1, 1216, 14, 14] 0
Block-150 [-1, 1216, 14, 14] 0
Conv2d-151 [-1, 384, 14, 14] 466,944
BatchNorm2d-152 [-1, 384, 14, 14] 768
ReLU-153 [-1, 384, 14, 14] 0
Conv2d-154 [-1, 384, 14, 14] 41,472
BatchNorm2d-155 [-1, 384, 14, 14] 768
ReLU-156 [-1, 384, 14, 14] 0
Conv2d-157 [-1, 1048, 14, 14] 402,432
BatchNorm2d-158 [-1, 1048, 14, 14] 2,096
ReLU-159 [-1, 1240, 14, 14] 0
Block-160 [-1, 1240, 14, 14] 0
Conv2d-161 [-1, 384, 14, 14] 476,160
BatchNorm2d-162 [-1, 384, 14, 14] 768
ReLU-163 [-1, 384, 14, 14] 0
Conv2d-164 [-1, 384, 14, 14] 41,472
BatchNorm2d-165 [-1, 384, 14, 14] 768
ReLU-166 [-1, 384, 14, 14] 0
Conv2d-167 [-1, 1048, 14, 14] 402,432
BatchNorm2d-168 [-1, 1048, 14, 14] 2,096
ReLU-169 [-1, 1264, 14, 14] 0
Block-170 [-1, 1264, 14, 14] 0
Conv2d-171 [-1, 384, 14, 14] 485,376
BatchNorm2d-172 [-1, 384, 14, 14] 768
ReLU-173 [-1, 384, 14, 14] 0
Conv2d-174 [-1, 384, 14, 14] 41,472
BatchNorm2d-175 [-1, 384, 14, 14] 768
ReLU-176 [-1, 384, 14, 14] 0
Conv2d-177 [-1, 1048, 14, 14] 402,432
BatchNorm2d-178 [-1, 1048, 14, 14] 2,096
ReLU-179 [-1, 1288, 14, 14] 0
Block-180 [-1, 1288, 14, 14] 0
Conv2d-181 [-1, 384, 14, 14] 494,592
BatchNorm2d-182 [-1, 384, 14, 14] 768
ReLU-183 [-1, 384, 14, 14] 0
Conv2d-184 [-1, 384, 14, 14] 41,472
BatchNorm2d-185 [-1, 384, 14, 14] 768
ReLU-186 [-1, 384, 14, 14] 0
Conv2d-187 [-1, 1048, 14, 14] 402,432
BatchNorm2d-188 [-1, 1048, 14, 14] 2,096
ReLU-189 [-1, 1312, 14, 14] 0
Block-190 [-1, 1312, 14, 14] 0
Conv2d-191 [-1, 384, 14, 14] 503,808
BatchNorm2d-192 [-1, 384, 14, 14] 768
ReLU-193 [-1, 384, 14, 14] 0
Conv2d-194 [-1, 384, 14, 14] 41,472
BatchNorm2d-195 [-1, 384, 14, 14] 768
ReLU-196 [-1, 384, 14, 14] 0
Conv2d-197 [-1, 1048, 14, 14] 402,432
BatchNorm2d-198 [-1, 1048, 14, 14] 2,096
ReLU-199 [-1, 1336, 14, 14] 0
Block-200 [-1, 1336, 14, 14] 0
Conv2d-201 [-1, 384, 14, 14] 513,024
BatchNorm2d-202 [-1, 384, 14, 14] 768
ReLU-203 [-1, 384, 14, 14] 0
Conv2d-204 [-1, 384, 14, 14] 41,472
BatchNorm2d-205 [-1, 384, 14, 14] 768
ReLU-206 [-1, 384, 14, 14] 0
Conv2d-207 [-1, 1048, 14, 14] 402,432
BatchNorm2d-208 [-1, 1048, 14, 14] 2,096
ReLU-209 [-1, 1360, 14, 14] 0
Block-210 [-1, 1360, 14, 14] 0
Conv2d-211 [-1, 384, 14, 14] 522,240
BatchNorm2d-212 [-1, 384, 14, 14] 768
ReLU-213 [-1, 384, 14, 14] 0
Conv2d-214 [-1, 384, 14, 14] 41,472
BatchNorm2d-215 [-1, 384, 14, 14] 768
ReLU-216 [-1, 384, 14, 14] 0
Conv2d-217 [-1, 1048, 14, 14] 402,432
BatchNorm2d-218 [-1, 1048, 14, 14] 2,096
ReLU-219 [-1, 1384, 14, 14] 0
Block-220 [-1, 1384, 14, 14] 0
Conv2d-221 [-1, 384, 14, 14] 531,456
BatchNorm2d-222 [-1, 384, 14, 14] 768
ReLU-223 [-1, 384, 14, 14] 0
Conv2d-224 [-1, 384, 14, 14] 41,472
BatchNorm2d-225 [-1, 384, 14, 14] 768
ReLU-226 [-1, 384, 14, 14] 0
Conv2d-227 [-1, 1048, 14, 14] 402,432
BatchNorm2d-228 [-1, 1048, 14, 14] 2,096
ReLU-229 [-1, 1408, 14, 14] 0
Block-230 [-1, 1408, 14, 14] 0
Conv2d-231 [-1, 384, 14, 14] 540,672
BatchNorm2d-232 [-1, 384, 14, 14] 768
ReLU-233 [-1, 384, 14, 14] 0
Conv2d-234 [-1, 384, 14, 14] 41,472
BatchNorm2d-235 [-1, 384, 14, 14] 768
ReLU-236 [-1, 384, 14, 14] 0
Conv2d-237 [-1, 1048, 14, 14] 402,432
BatchNorm2d-238 [-1, 1048, 14, 14] 2,096
ReLU-239 [-1, 1432, 14, 14] 0
Block-240 [-1, 1432, 14, 14] 0
Conv2d-241 [-1, 384, 14, 14] 549,888
BatchNorm2d-242 [-1, 384, 14, 14] 768
ReLU-243 [-1, 384, 14, 14] 0
Conv2d-244 [-1, 384, 14, 14] 41,472
BatchNorm2d-245 [-1, 384, 14, 14] 768
ReLU-246 [-1, 384, 14, 14] 0
Conv2d-247 [-1, 1048, 14, 14] 402,432
BatchNorm2d-248 [-1, 1048, 14, 14] 2,096
ReLU-249 [-1, 1456, 14, 14] 0
Block-250 [-1, 1456, 14, 14] 0
Conv2d-251 [-1, 384, 14, 14] 559,104
BatchNorm2d-252 [-1, 384, 14, 14] 768
ReLU-253 [-1, 384, 14, 14] 0
Conv2d-254 [-1, 384, 14, 14] 41,472
BatchNorm2d-255 [-1, 384, 14, 14] 768
ReLU-256 [-1, 384, 14, 14] 0
Conv2d-257 [-1, 1048, 14, 14] 402,432
BatchNorm2d-258 [-1, 1048, 14, 14] 2,096
ReLU-259 [-1, 1480, 14, 14] 0
Block-260 [-1, 1480, 14, 14] 0
Conv2d-261 [-1, 384, 14, 14] 568,320
BatchNorm2d-262 [-1, 384, 14, 14] 768
ReLU-263 [-1, 384, 14, 14] 0
Conv2d-264 [-1, 384, 14, 14] 41,472
BatchNorm2d-265 [-1, 384, 14, 14] 768
ReLU-266 [-1, 384, 14, 14] 0
Conv2d-267 [-1, 1048, 14, 14] 402,432
BatchNorm2d-268 [-1, 1048, 14, 14] 2,096
ReLU-269 [-1, 1504, 14, 14] 0
Block-270 [-1, 1504, 14, 14] 0
Conv2d-271 [-1, 384, 14, 14] 577,536
BatchNorm2d-272 [-1, 384, 14, 14] 768
ReLU-273 [-1, 384, 14, 14] 0
Conv2d-274 [-1, 384, 14, 14] 41,472
BatchNorm2d-275 [-1, 384, 14, 14] 768
ReLU-276 [-1, 384, 14, 14] 0
Conv2d-277 [-1, 1048, 14, 14] 402,432
BatchNorm2d-278 [-1, 1048, 14, 14] 2,096
ReLU-279 [-1, 1528, 14, 14] 0
Block-280 [-1, 1528, 14, 14] 0
Conv2d-281 [-1, 768, 14, 14] 1,173,504
BatchNorm2d-282 [-1, 768, 14, 14] 1,536
ReLU-283 [-1, 768, 14, 14] 0
Conv2d-284 [-1, 768, 7, 7] 165,888
BatchNorm2d-285 [-1, 768, 7, 7] 1,536
ReLU-286 [-1, 768, 7, 7] 0
Conv2d-287 [-1, 2176, 7, 7] 1,671,168
BatchNorm2d-288 [-1, 2176, 7, 7] 4,352
Conv2d-289 [-1, 2176, 7, 7] 29,924,352
BatchNorm2d-290 [-1, 2176, 7, 7] 4,352
ReLU-291 [-1, 2304, 7, 7] 0
Block-292 [-1, 2304, 7, 7] 0
Conv2d-293 [-1, 768, 7, 7] 1,769,472
BatchNorm2d-294 [-1, 768, 7, 7] 1,536
ReLU-295 [-1, 768, 7, 7] 0
Conv2d-296 [-1, 768, 7, 7] 165,888
BatchNorm2d-297 [-1, 768, 7, 7] 1,536
ReLU-298 [-1, 768, 7, 7] 0
Conv2d-299 [-1, 2176, 7, 7] 1,671,168
BatchNorm2d-300 [-1, 2176, 7, 7] 4,352
ReLU-301 [-1, 2432, 7, 7] 0
Block-302 [-1, 2432, 7, 7] 0
Conv2d-303 [-1, 768, 7, 7] 1,867,776
BatchNorm2d-304 [-1, 768, 7, 7] 1,536
ReLU-305 [-1, 768, 7, 7] 0
Conv2d-306 [-1, 768, 7, 7] 165,888
BatchNorm2d-307 [-1, 768, 7, 7] 1,536
ReLU-308 [-1, 768, 7, 7] 0
Conv2d-309 [-1, 2176, 7, 7] 1,671,168
BatchNorm2d-310 [-1, 2176, 7, 7] 4,352
ReLU-311 [-1, 2560, 7, 7] 0
Block-312 [-1, 2560, 7, 7] 0
AdaptiveAvgPool2d-313 [-1, 2560, 1, 1] 0
Linear-314 [-1, 4] 10,244
================================================================
Total params: 67,994,324
Trainable params: 67,994,324
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 489.24
Params size (MB): 259.38
Estimated Total Size (MB): 749.20
----------------------------------------------------------------
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # number of training samples
    num_batches = len(dataloader)   # number of batches
    train_acc, train_loss = 0, 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        # Forward pass and loss
        pred = model(X)
        loss = loss_fn(pred, y)
        # Backpropagation and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate metrics
        train_loss += loss.item()
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
    train_loss /= num_batches
    train_acc /= size
    return train_acc, train_loss
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # size of the test set
    num_batches = len(dataloader)   # number of batches (size / batch_size, rounded up)
    test_loss, test_acc = 0, 0
    # Disable gradient tracking during evaluation to save memory and computation
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            # Compute the loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)
            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()
    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss
VII. Training the Model
epochs = 50
train_loss = []
train_acc = []
test_loss = []
test_acc = []
patience = 10
no_improve_epoch = 0
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.92)
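The StepLR scheduler above multiplies the learning rate by 0.92 every 2 epochs. A quick illustrative trace (not from the original post) of how the rate decays:
# Illustration only: learning-rate decay under StepLR(step_size=2, gamma=0.92).
lr = 1e-4
for epoch in range(1, 11):
    if epoch % 2 == 0:  # StepLR steps the rate every `step_size` epochs
        lr *= 0.92
print(f"{lr:.2e}")  # after 10 epochs: 1e-4 * 0.92**5 ≈ 6.59e-05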
def main():
    best_acc = 0
    no_improve_epoch = 0
    # Start training
    for epoch in range(epochs):
        model.train()
        epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
        model.eval()
        epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
        # Track the best model and count epochs without improvement
        if epoch_test_acc > best_acc:
            best_acc = epoch_test_acc
            best_model = copy.deepcopy(model)
            no_improve_epoch = 0
        else:
            no_improve_epoch += 1
        if no_improve_epoch > patience:
            print(f"Early stop triggered at epoch {epoch + 1}")
            break  # early stopping
        train_acc.append(epoch_train_acc)
        train_loss.append(epoch_train_loss)
        test_acc.append(epoch_test_acc)
        test_loss.append(epoch_test_loss)
        scheduler.step()  # update the learning rate
        lr = opt.state_dict()['param_groups'][0]['lr']
        template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
        print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,
                              epoch_test_acc * 100, epoch_test_loss, lr))
    print('Done')

main()
VIII. Model Evaluation
plt.rcParams['axes.unicode_minus'] = False  # display minus signs correctly
plt.rcParams['figure.dpi'] = 100            # figure resolution
epochs_range = range(len(train_acc))        # number of epochs actually trained (early stopping may end training sooner)
plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
IX. Summary
This week I studied how to combine ResNet and DenseNet for bird-species classification.
Ways to combine the two:
1. Combining residual and dense connections
- Residual blocks inside dense blocks: ResNet-style residual connections can be added inside each DenseNet dense block. For example, the output of a dense block can be added to its input through a residual connection, which makes it easier for the network to learn a residual mapping (a minimal sketch follows this list).
2. Hybrid architecture design
- Front/back structure: use ResNet for the first half of the network to extract initial features and DenseNet for the second half to further fuse and refine them, so the strengths of both are exploited.
- Branch structure: build a multi-branch network with one ResNet branch and one DenseNet branch, and merge the two branches with a feature-fusion layer (concatenation or addition).
3. Fusion-layer design
- Feature fusion: fusion layers can be inserted at different depths of the network, combining ResNet and DenseNet features either with an attention mechanism or with simple concatenation.
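A minimal sketch of idea 1 above, assuming a toy dense block; the DenseResidualBlock name, layer sizes, and growth rate are illustrative choices, not part of the original post:
# Hypothetical sketch: a small dense block whose output is also tied to its input through a
# residual (1x1 projection) connection. Channel sizes are chosen for illustration only.
import torch
import torch.nn as nn

class DenseResidualBlock(nn.Module):
    def __init__(self, in_channels, growth_rate=16, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList()
        channels = in_channels
        for _ in range(num_layers):
            # Each layer sees all previously produced feature maps (DenseNet-style concatenation)
            self.layers.append(nn.Sequential(
                nn.BatchNorm2d(channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(channels, growth_rate, kernel_size=3, padding=1, bias=False)
            ))
            channels += growth_rate
        # 1x1 projection so the block input can be added to the dense output (ResNet-style)
        self.proj = nn.Conv2d(in_channels, channels, kernel_size=1, bias=False)

    def forward(self, x):
        features = x
        for layer in self.layers:
            new_feat = layer(features)
            features = torch.cat([features, new_feat], dim=1)  # dense connection
        return features + self.proj(x)                          # residual connection

# Quick shape check: 64 input channels + 3 layers * growth_rate 16 = 112 output channels
blk = DenseResidualBlock(in_channels=64)
print(blk(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 112, 32, 32])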
Algorithm details:
Dual Path Networks (DPN) is a work from Yan Shuicheng's (颜水成) group, released on arXiv in 2017, and it gives a measurable improvement on image classification. Networks such as ResNet, ResNeXt and DenseNet have clearly proven themselves on image classification, and DPN can be seen as a fusion of the core ideas of ResNeXt and DenseNet. It is described as fusing ResNeXt rather than ResNet with DenseNet because the authors also use grouped convolutions, and the group operation is exactly the main difference between ResNeXt and ResNet.
Advantages:
1. On model complexity, the paper states: "The DPN-92 costs about 15% fewer parameters than ResNeXt-101 (32×4d), while the DPN-98 costs about 26% fewer parameters than ResNeXt-101 (64×4d)."
2. On computational complexity, the paper states: "DPN-92 consumes about 19% less FLOPs than ResNeXt-101 (32×4d), and the DPN-98 consumes about 25% less FLOPs than ResNeXt-101 (64×4d)."
Table 1 of the paper gives the network structure and a good first impression: DPN is structurally very similar to ResNeXt (and ResNet). It starts with a 7×7 convolution and a max-pooling layer, followed by four stages, each containing several sub-stages (described below), then a global average pooling layer and a fully connected layer, and finally a softmax. The interesting part is what happens inside each stage, which is the core of the DPN algorithm.
Since DPN essentially merges ResNeXt and DenseNet into one network, it helps to briefly review the core of ResNet (ResNeXt and ResNet share the same macro sub-structure) and of DenseNet before looking inside a DPN stage. In the figure, (a) shows part of one ResNet stage. The tall rectangle on the left of (a) represents the input/output feature maps: an input x is sent down two paths, one is x itself and the other passes x through a 1×1 convolution, a 3×3 convolution and another 1×1 convolution (this three-layer combination is called a bottleneck); the two paths are then merged by element-wise addition (the "+" in (a)), and the result becomes the input of the next identical module; several such modules stacked together form one stage (for example conv3 in Table 1). (b) shows the core of DenseNet. The tall polygon on the left of (b) represents the input/output feature maps: the input x follows a single path, passes through a few convolution layers, and is then concatenated (concat) with x along the channel dimension; the result becomes the input of the next small module, so the input of each module keeps accumulating. For example, the input of the second module contains both the output and the input of the first module, and so on.
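To make the difference between the two merge operations concrete, here is a small illustrative tensor example (not from the original post):
# Illustration only: ResNet merges by element-wise addition (channel count unchanged),
# DenseNet merges by concatenation along the channel dimension (channel count grows).
import torch

x = torch.randn(1, 64, 28, 28)  # input feature map
f = torch.randn(1, 64, 28, 28)  # output of the convolution branch

res_out = x + f                       # ResNet-style: shape stays (1, 64, 28, 28)
dense_out = torch.cat([x, f], dim=1)  # DenseNet-style: shape becomes (1, 128, 28, 28)
print(res_out.shape, dense_out.shape)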
Experimental results:
Table 2 of the paper compares DPN with the strongest contemporary models on ImageNet-1k: ResNet, ResNeXt and DenseNet. DPN comes out ahead in model size, GFLOPs and accuracy. In this comparison DenseNet does not look quite as strong as in the original DenseNet paper, probably because DenseNet needs more training tricks to reach those numbers.
DPN model architecture: