当前位置: 首页 > article >正文

Pytorch 第十回:卷积神经网络——DenseNet模型

Pytorch 第十回:卷积神经网络——DenseNet模型

本次开启深度学习第十回,基于Pytorch的DenseNet卷积神经网络模型。这是分享的第五个卷积神经网络模型。在第九回当中,分享了ResNet模型,该模型解决了梯度消失和网络退化的问题。本回的DenseNet模型在某种程度上来说是ResNet模型的升级版,接下来给大家分享具体思路。
本次学习,借助的平台是PyCharm 2024.1.3,python版本3.11 numpy版本是1.26.4,pytorch版本2.0.0+cu118,d2l的版本是1.0.3

文章目录

  • Pytorch 第十回:卷积神经网络——DenseNet模型
  • 前言
    • 1、 DenseNet(Densely Connected Convolutional Networks)模型
    • 2、稠密块
    • 3、过度层
  • 一、数据准备
  • 二、模型准备
    • 1.稠密块
    • 2.过渡层
    • 3.DenseNet模型
  • 三、模型训练
    • 1、实例化DenseNet模型
    • 2、迭代训练模型
    • 3、输出展示
  • 总结


前言

讲述模型前,先讲述两个概念,统一下思路:

1、 DenseNet(Densely Connected Convolutional Networks)模型

DenseNet模型是一种通过密集连接实现特征重用的深度卷积神经网络。其设计理念是让每一层的输入包含前面所有层的输出,通过通道维度拼接,从而形成密集的信息传递流。换而言之,DenseNet模型是采用跨层方式将特征在通道维度上进行拼接。
对比ResNet模型,在特征传递方式上,ResNet模型只是跨层求和,传递了浅层的部分特征。再看DenseNet模型,其每层的输出都可以传递到后续的网络层。这样不仅可以很好的进行梯度传播,还可以将数据的浅层特征和深层特征共同进行学习,从而得到良好的训练效果。

2、稠密块

稠密块是DenseNet模型的基层单元,其作用是通过密集连接实现特征重用和梯度流动优化。稠密连接的效果如下:
在这里插入图片描述
从上图可以清晰的看到数据的流动方向,尤其是每层的输出都可以传递到后续的网络层。

3、过度层

过渡层也是DenseNet模型的基础单元。稠密块会增加传递的通道数量,使模型复杂化;过渡层则是削减通道数,控制模型复杂度的。其组合结构如下图所示:
在这里插入图片描述

闲言少叙,直接展示逻辑,先上引用:

import numpy as np
import torch
from torch import nn
from torchvision.datasets import CIFAR10
import time
from torch.utils.data import DataLoader
from d2l import torch as d2l

一、数据准备

如前几回一样,本次仍然采用CIFAR10数据集,因此不做重点解释(有兴趣的可以查看第六回内容),本回只展示代码:

def data_treating(x):
    x = x.resize((96, 96))  #
    x = np.array(x, dtype='float32') / 255
    x = (x - 0.5) / 0.5  #
    x = x.transpose((2, 0, 1))  #
    x = torch.from_numpy(x)
    return x
train_set = CIFAR10('./data', train=True, transform=data_treating)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=data_treating)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

二、模型准备

1.稠密块

1)在定义稠密块时,为简化代码,这里将卷积操作单独进行封装。封装卷积块时,按照“批量规范化层->激活层->卷积层”的结构进行组合。代码如下所示:

def conv_block(channel_in, channel_out):
    return nn.Sequential(
        nn.BatchNorm2d(channel_in),
        nn.ReLU(True),
        nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1))

2)稠密块含有多个卷积块,有两点需要注意:一是卷积块的输入通道不同,但是输出通道相同;二是在前向传播中,不仅需要连接卷积块,输入特征也需要连接。
这里给稠密块设定了三个参数:参数一是首层卷积的输入通道数,参数二是输出通道数,参数三是卷积层的个数。代码如下所示:

class dense_block(nn.Module):
    def __init__(self, channel_in, channel_growth, num_conv):
        super(dense_block, self).__init__()
        block = []
        for i in range(num_conv):
            block.append(conv_block((channel_in + channel_growth * i), channel_growth))

        self.net = nn.Sequential(*block)

    def forward(self, x):
        for layer in self.net:
            y = layer(x)
            x = torch.cat((y, x), dim=1)
        return x

注:
1)为保证前后卷积块能够准确的进行连接,后一层卷积的输入通道个数应该为:
首层卷积数+输出通道数*该层的层数(首层为0层)
2)卷积块的个数影响着稠密块输出通道数相对于稠密块输入通道数的增长程度,因此参数num_conv也称为增长率

2.过渡层

过渡层的结构与前面封装的卷积块很类似,也存在“批量规范化层->激活层->卷积层”的结构。不同的是,这里在卷积层后还加上了平均池化层。代码如下所示:

def transition_block(channel_in, channel_out):
    return nn.Sequential(
        nn.BatchNorm2d(channel_in),
        nn.ReLU(True),
        nn.Conv2d(channel_in, channel_out, kernel_size=1),
        nn.AvgPool2d(kernel_size=2, stride=2)
    )

注:
这里的1*1卷积层用于降低通道数,池化层用于降低数据大小。

3.DenseNet模型

与第九回的ResNet模型比较类似,DenseNet模型也采用了首个网络块单独定义,后续网络块重叠定义的方式进行搭建。为了提高代码的简洁性,这里采用循环的方式搭建后续网络块。与ResNet类似,在网络的最后,也是采用连接全局池化层和全连接输出模型结果。代码如下所示:

class dense_net(nn.Module):
    def __init__(self, channel_in, classes_out, channel_out, block_num):
        super(dense_net, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(channel_in, 64, 7, 2, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.MaxPool2d(3, 2, padding=1)
        )

        channels = 64
        blocks = []
        for i, num in enumerate(block_num):
            blocks.append(dense_block(channels, channel_out, num))
            channels += num * channel_out
            if i != len(block_num) - 1:
                blocks.append(transition_block(channels, channels // 2))
                channels = channels // 2

        self.block2 = nn.Sequential(*blocks)
        self.block2.add_module('bn', nn.BatchNorm2d(channels))
        self.block2.add_module('relu', nn.ReLU(True))
        self.block2.add_module('avg_pool', nn.AvgPool2d(3))
        self.block2.add_module('flatten', nn.Flatten())
        self.classifier = nn.Linear(channels, classes_out)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.classifier(x)
        return x

三、模型训练

1、实例化DenseNet模型

这里输入为3个通道,因为彩色图片有三个数据通道。输出为10,因为数据集有10个类别(数据集的介绍,在第六回中)

classify_ResNet = dense_net(3, 10, channel_out=32, block_num=[6, 12, 24, 16])

2、迭代训练模型

本次训练采用d2l.train_ch6()函数,其参数有六个:第一个是模型,第二个是训练集,第三个是测试集,第四个是迭代次数(设定为20次),第五个是学习率(设定为0.01),第六个是进行训练的设备(设定为GPU训练)。

d2l.train_ch6(classify_ResNet, train_data, test_data, 20, 0.01, d2l.try_gpu())

注:
由于本回采用d2l.train_ch6()进行数据训练,里面集成了损失函数和优化器,因此不需要单独定义(在第八回小记中介绍了如何安装d2l库)。

3、输出展示

epoch0, loss 1.391, train acc 0.499, test acc 0.533,consume time 545.7
epoch4, loss 0.589, train acc 0.796, test acc 0.689,consume time 2732.1
epoch8, loss 0.303, train acc 0.896, test acc 0.618,consume time 4916.3
epoch12, loss 0.127, train acc 0.959, test acc 0.709,consume time 7100.8
epoch16, loss 0.054, train acc 0.984, test acc 0.735,consume time 9286.2
epoch19, loss 0.029, train acc 0.992, test acc 0.810,consume time 10926.9

对比第九回的ResNet模型,DenseNet模型的测试集精度有了较大的提升。

总结

1)数据准备:准备CIFAR10数据集
2)模型准备:准备稠密块,过渡层、DenseNet模型
3)数据训练:实例化训练模型,采用train_ch6函数进行迭代训练。


http://www.kler.cn/a/582576.html

相关文章:

  • 图论Day2·搜索
  • 大模型安全新范式:DeepSeek一体机内容安全卫士发布
  • JS—闭包:3分钟从入门到放弃
  • 数据结构:排序详解(使用语言:C语言)
  • 赶紧白P这款免费神器!
  • 差分数组题目
  • 机器学习(吴恩达)
  • 有关MyBatis的缓存(一级缓存和二级缓存)
  • 【第四节】windows sdk编程:windows 中的窗口
  • 基于Python+SQLite实现校园信息化统计平台
  • java校验String是否符合时间格式 yyyy-MM-dd HH:mm:ss
  • vs2022用git插件重置--删除更改(--hard)后恢复删除的内容
  • Qt 6.6.1 中 QPixmap::grabWindow() 的用法与替代方案
  • Spring之生命周期Bean的生成过程
  • python-leetcode-K 和数对的最大数目
  • 【Godot4.3】RenderingServer总结
  • c++介绍运算符重载九
  • vscode接入DeepSeek 免费送2000 万 Tokens 解决DeepSeek无法充值问题
  • 5秒学会excel中序号列自动增加,不是拖动,图解加说明,解决序号自增多了手拖太累
  • VSTO(C#)Excel开发5:调整表格到一页