当前位置: 首页 > article >正文

《深度学习》VGG网络

文章目录

  • 1.VGG的网络架构
  • 2.案例:手写数字识别

学习目标:

  • 知道VGG网络结构的特点
  • 能够利用VGG网络完成图像分类

2014年,⽜津⼤学计算机视觉组(Visual Geometry Group)和GoogleDeepMind公司的研究员⼀起研发出了新的深度卷积神经⽹络:
VGGNet,并取得了ILSVRC2014⽐赛分类项⽬的第⼆名,主要贡献是使⽤很⼩的卷积核(3×3)构建卷积神经⽹络结构,能够取得较好的识别精
度,常⽤来提取图像特征的VGG-16和VGG-19。

1.VGG的网络架构

VGG可以看成是加深版的AlexNet,整个⽹络由卷积层和全连接层叠加⽽成,和AlexNet不同的是,VGG中使⽤的都是⼩尺⼨的卷积核(3×3),其⽹络架构如下图所示:
在这里插入图片描述
VGGNet使⽤的全部都是3x3的⼩卷积核和2x2的池化核,通过不断加深⽹络来提升性能。VGG可以通过重复使⽤简单的基础块来构建深度模型。
在这里插入图片描述
在tf.keras中实现VGG模型,⾸先来实现VGG块,它的组成规律是:连续使⽤多个相同的填充为1、卷积核⼤⼩为33的卷积层后接上⼀个步幅为2、窗⼝形状为22的最⼤池化层。卷积层保持输⼊的⾼和宽不变,⽽池化层则对其减半。我们使⽤ vgg_block 函数来实现这个基础的VGG块,它可以指定卷积层的数量 num_convs 和每层的卷积核个数num_filters:

# num_convs表示有几个卷积层,num_filters表示卷积核的个数
def vgg_block(num_convs, num_filters):
    # 构建序列模型
    blk = tf.keras.models.Sequential()
    #遍历卷积层
    for _ in range(num_convs):
        #设置卷积层
        blk.add(tf.keras.layers.Conv2D(num_filters, kernel_size=3, padding="same", activation="relu"))
    #池化层、
    blk.add(tf.keras.layers.MaxPool2D(pool_size=2, strides=2))
    return blk

VGG16⽹络有5个卷积块,前2块使⽤两个卷积层,⽽后3块使⽤三个卷积层。第⼀块的输出通道是64,之后每次对输出通道数翻倍,直到变为512。

#卷积快的参数
conv_arch = ((2, 64), (2, 128), (3, 256), (3, 512), (3, 512))

因为这个⽹络使⽤了13个卷积层和3个全连接层,所以经常被称为VGG-16,通过制定conv_arch得到模型架构后构建VGG16:

def vgg(con_arch):
    #序列模型构建
    net = tf.keras.models.Sequential()
    #生成卷积部分
    for(num_convs, num_filters) in con_arch:
        net.add(vgg_block(num_convs, num_filters))
    #全连接层
    net.add(tf.keras.models.Sequential(
        [
            #展开
            tf.keras.layers.Flatten(),
            #全连接层
            tf.keras.layers.Dense(4096, activation="relu"),
            #随机失活
            tf.keras.layers.Dropout(0.5),
            #全连接层
            tf.keras.layers.Dense(4096, activation="relu"),
            #随机失活
            tf.keras.layers.Dropout(0.5),
            #输出层
            tf.keras.layers.Dense(10, activation="softmax")
        ]
    ))
    return net
# 网络实例化
net = vgg(conv_arch)

我们构造⼀个⾼和宽均为224的单通道数据样本来看⼀下模型的架构:

x = tf.random.uniform((1, 224, 224, 1))
y = net(x)
net.summary()

在这里插入图片描述

2.案例:手写数字识别

之后补充。。。。


http://www.kler.cn/a/395086.html

相关文章:

  • Golang学习历程【第三篇 基本数据类型类型转换】
  • Spring(三)-SpringWeb-概述、特点、搭建、运行流程、组件、接受请求、获取请求数据、特殊处理、拦截器
  • ansible play-book玩法
  • 门户系统需要压测吗?以及门户系统如何压力测试?
  • panddleocr-文本检测+文本方向分类+文本识别整体流程
  • AOP 面向切面编程的实现原理
  • 【算法】区间DP
  • A3超级计算机虚拟机,为大型语言模型LLM和AIGC提供强大算力支持
  • King3399(ubuntu文件系统)wifi设备树分析
  • 学习日志009--面向对象的编程
  • 前后端、网关、协议方面补充
  • 41页PPT | 华为业务流程架构全景视图:全业务域L1-L3级流程全案
  • python中父类和子类继承学习
  • Django处理前端请求的流程梳理
  • 通过命令学习k8s
  • ABAP开发学习——权限控制 实例1
  • PHP代码审计 - SQL注入
  • LeetCode面试经典150题C++实现,更新中
  • gcc 1.c和g++ 1.c编译阶段有什么区别?如何知道g++编译默认会定义_GNU_SOURCE?
  • Mysql篇-三大日志
  • Linux设置Nginx开机启动
  • http拉取git仓库,每次都要输入帐号密码,Ubuntu上记住帐号密码
  • 微积分复习笔记 Calculus Volume 1 - 5.5 Substitution
  • sqlserver 常用分页函数
  • ssh key的生成密钥
  • 区间数位和模板(贪心)