当前位置: 首页 > article >正文

从0开始深度学习(16)——暂退法(Dropout)

上一章的过拟合是由于数据不足导致的,但如果我们有比特征多得多的样本,深度神经网络也有可能过拟合

1 扰动的稳健性

经典泛化理论认为,为了缩小训练和测试性能之间的差距,应该以简单的模型为目标,即模型以较小的维度的形式呈现。

简单性的另一个角度是平滑性,即函数不应该对其输入的微小变化敏感。例如,当我们对图像进行分类时,我们预计向像素添加一些随机噪声应该是基本无影响的。

在2014年,斯里瓦斯塔瓦等人提出了一个想法: 在训练过程中,他们建议在计算后续层之前向网络的每一层注入噪声。 因为当训练一个有多层的深层网络时,注入噪声只会在输入-输出映射上增强平滑性,这个想法被称为暂退法(dropout)

暂退法在前向传播过程中,计算每一内部层的同时注入噪声,这已经成为训练神经网络的常用技术。

如何注入这种噪声? 一种想法是以一种无偏向(unbiased)的方式注入噪声。 这样在固定住其他层时,每一层的期望值等于没有噪音时的值。

2 实践中的暂退法

以多层感知机为例,当我们将暂退法应用到隐藏层,以 P P P的概率将隐藏单元置为零时, 结果可以看作一个只包含原始神经元子集的网络。比如在下图中,删除了 h 2 h_{2} h2 h 5 h_{5} h5,并且它们各自的梯度在执行反向传播时也会消失
在这里插入图片描述
通常,我们在测试时不用暂退法。

3 从零实现暂退法

要实现单层的暂退法函数, 我们从均匀分布 U [ 0 , 1 ] U[0,1] U[0,1]中抽取样本,样本数与这层神经网络的维度一致。 然后我们保留那些对应样本大于 p p p的节点,把剩下的丢弃。

在下面的代码中,我们实现 dropout_layer 函数, 该函数以dropout的概率丢弃张量输入X中的元素, 如上所述重新缩放剩余部分:将剩余部分除以1.0-dropout。

import torch
from torch import nn
from d2l import torch as d2l


def dropout_layer(X, dropout):
    assert 0 <= dropout <= 1
    # 当等于1时,所有元素都被丢弃
    if dropout == 1:
        return torch.zeros_like(X)
    # 当等于0时,所有元素都被保留
    if dropout == 0:
        return X
    # 用0去填补丢弃的元素的位置
    mask = (torch.rand(X.shape) > dropout).float()
    return mask * X / (1.0 - dropout)

以将暂退法应用于每个隐藏层的输出(在激活函数之后), 并且可以为每一层分别设置暂退概率: 常见的技巧是在靠近输入层的地方设置较低的暂退概率。

4 调用API实现暂退法

net = nn.Sequential(nn.Flatten(),
        nn.Linear(784, 256),
        nn.ReLU(),
        # 在第一个全连接层之后添加一个dropout层
        nn.Dropout(dropout1),
        nn.Linear(256, 256),
        nn.ReLU(),
        # 在第二个全连接层之后添加一个dropout层
        nn.Dropout(dropout2),
        nn.Linear(256, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights);

http://www.kler.cn/news/367569.html

相关文章:

  • 生成式 AI 与向量搜索如何扩大零售运营:巨大潜力尚待挖掘
  • 【C++】函数的返回、重载以及匹配、函数指针
  • 使用 ASP.NET Core 8.0 创建最小 API
  • 51单片机STC8G串口Uart配置
  • <Project-11 Calculator> 计算器 0.3 年龄计算器 age Calculator HTML JS
  • 为什么需要MQ?MQ具有哪些作用?你用过哪些MQ产品?请结合过往的项目经验谈谈具体是怎么用的?
  • C++笔记---位图
  • PHP如何抛出和接收错误
  • C语言[求x的y次方]
  • 7.hyperf安装【Docker】
  • 京东电商下单黄金链路:防止订单重复提交与支付的深度解析
  • Pseudo Multi-Camera Editing 数据集:通过常规视频生成的伪标记多摄像机推荐数据集,显著提升模型在未知领域的准确性。
  • 背包九讲——混合背包问题
  • 虾类图像分割系统:改进亮点优化
  • 前端项目接入sqlite轻量级数据库sql.js指南
  • ffmpeg视频滤镜: 色温- colortemperature
  • Windows 11 绕过 TPM 方法总结,24H2 通用免 TPM 镜像下载 (Updated Oct 2024)
  • java项目之在线考试系统设计与实现(springboot)
  • 通过AWS Bedrock探索 Claude 的虚拟桌面魔力:让 AI 代替你动手完成任务!
  • 时间数据可视化基础实验(南丁格尔玫瑰图)——Python热狗大胃王比赛数据集
  • 蓝桥杯普及题
  • Android中导入讯飞大模型ai智能系统
  • nodejs写入日志文件
  • Linux: Shell编程中的应用之基于sh进行数据统计
  • 【C++ 真题】B2106 矩阵转置
  • 基于java SpringBoot和Vue校园求职招聘系统设计