当前位置: 首页 > article >正文

动手学深度学习(pytorch)学习记录28-使用块的网络(VGG)[学习记录]

目录

  • VGG块
  • VGG网络
  • 训练模型

VGG块

定义了一个名为vgg_block的函数来实现一个VGG块

import torch
from torch import nn
from d2l import torch as d2l
def vgg_block(num_convs, in_channels, out_channels):
    layers = []
    for _ in range(num_convs):
        layers.append(nn.Conv2d(in_channels, out_channels,
                                kernel_size=3, padding=1))
        layers.append(nn.ReLU())
        in_channels = out_channels
    layers.append(nn.MaxPool2d(kernel_size=2,stride=2))
    return nn.Sequential(*layers)

VGG网络

与AlexNet、LeNet一样,VGG网络可以分为两部分:第一部分主要由卷积层和汇聚层组成,第二部分由全连接层组成。
VGG神经网络连接的几个VGG块(在vgg_block函数中定义)。其中有超参数变量conv_arch。该变量指定了每个VGG块里卷积层个数和输出通道数。全连接模块则与AlexNet中的相同。

原始VGG网络有5个卷积块,其中前两个块各有一个卷积层,后三个块各包含两个卷积层。 第一个模块有64个输出通道,每个后续模块将输出通道数量翻倍,直到该数字达到512。由于该网络使用8个卷积层和3个全连接层,因此它通常被称为VGG-11。

conv_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))

实现VGG11

def vgg(conv_arch):
    conv_blks = []
    in_channels = 1
    # 卷积层部分
    for (num_convs, out_channels) in conv_arch:
        conv_blks.append(vgg_block(num_convs, in_channels, out_channels))
        in_channels = out_channels

    return nn.Sequential(
        *conv_blks, nn.Flatten(),
        # 全连接层部分
        nn.Linear(out_channels * 7 * 7, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, 4096), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(4096, 10))

net = vgg(conv_arch)

构建一个高度和宽度为224的单通道数据样本,以观察每个层输出的形状。

X = torch.randn(size=(1, 1, 224, 224))
for blk in net:
    X = blk(X)
    print(blk.__class__.__name__,'output shape:\t',X.shape)
Sequential output shape:	 torch.Size([1, 64, 112, 112])
Sequential output shape:	 torch.Size([1, 128, 56, 56])
Sequential output shape:	 torch.Size([1, 256, 28, 28])
Sequential output shape:	 torch.Size([1, 512, 14, 14])
Sequential output shape:	 torch.Size([1, 512, 7, 7])
Flatten output shape:	 torch.Size([1, 25088])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 10])

训练模型

# 由于Fashion-MNIST数据集比较简单,缩放前面几个块的通道数,足够训练这个数据集
ratio = 4 # 设置比例
small_conv_arch = [(pair[0], pair[1] // ratio) for pair in conv_arch]
net = vgg(small_conv_arch)
X = torch.randn(size=(1, 1, 224, 224))
for blk in net:
    X = blk(X)
    print(blk.__class__.__name__,'output shape:\t',X.shape)
Sequential output shape:	 torch.Size([1, 16, 112, 112])
Sequential output shape:	 torch.Size([1, 32, 56, 56])
Sequential output shape:	 torch.Size([1, 64, 28, 28])
Sequential output shape:	 torch.Size([1, 128, 14, 14])
Sequential output shape:	 torch.Size([1, 128, 7, 7])
Flatten output shape:	 torch.Size([1, 6272])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 10])
lr, num_epochs, batch_size = 0.05, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.181, train acc 0.933, test acc 0.919
378.9 examples/sec on cuda:0

在这里插入图片描述

· 本文使用了大量d2l包,这极大地减少了代码编辑量,需要安装d2l包才能运行本文代码
封面图片来源
欢迎点击我的主页查看更多文章。
本人学习地址https://zh-v2.d2l.ai/
恳请大佬批评指正。


http://www.kler.cn/news/305171.html

相关文章:

  • 【计算机网络】HTTPHTTPS
  • vue前端实现下载导入模板文件
  • 1405 问题 E: 世界杯
  • 基于深度学习的信号滤波:创新技术与应用挑战
  • PyTorch 和 TensorFlow
  • 【深度学习】神经网络-怎么分清DNN、CNN、RNN?
  • Anaconda pytorch-gpu CUDA CUDNN 安装指南
  • clickhouse 保证幂等性
  • 前端面试记录
  • mybatis-plu分页出现问题
  • JVM面试真题总结(九)
  • windows检查端口占用并关闭应用
  • git报错,error: bad signature 0x00000000fatal: index file corrupt
  • 3. 进阶指南:自定义 Prompt 提升大模型解题能力
  • 新手教学系列——用Nginx将页面请求分发到不同后端模块
  • 足球大小球及亚盘数据分析与机器学习实战详解:从数据清洗到模型优化
  • vue项目中引入组件时出现的Module is not installed问题
  • 上图为是否色发
  • 15、Python如何获取文件的状态
  • ARM V2处理器微架构分析
  • input和editor一起使用在ios上聚焦异常
  • 【计算机网络 - 基础问题】每日 3 题(四)
  • 目标检测中的解耦和耦合、anchor-free和anchor-base
  • 分销系统后端技术文档
  • 大数据Flink(一百一十八):SQL水印操作(Watermark)
  • Linux基础---07文件传输(网络和Win文件)
  • 9 递归——50. Pow(x, n) ★★
  • linux 操作系统下的curl 命令介绍和使用案例
  • docker如何实现资源隔离
  • Tomcat 版本怎么选?JMeter 真实压测多版本 Tomcat 数据给你最直接的参考,快收藏备用吧!