当前位置: 首页 > article >正文

droppath

DropPath 是一种用于正则化深度学习模型的技术,它在训练过程中随机丢弃路径(或者说随机让某些部分的输出变为零),从而增强模型的鲁棒性和泛化能力。

代码解释:

import torch
import torch.nn as nn

# 定义 DropPath 类
class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

def drop_path(x, drop_prob: float = 0., training: bool = False):
#drop_path(输入,将drop_prob初始化为0., 判断是否为训练模式)
    if drop_prob == 0. or not training:
        return x
#如果drop_prob等于0或者不是训练模式直接将输入输出
    keep_prob = 1 - drop_prob
#保留的概率
    shape = (x.shape[0],) + (1,) * (x.ndim - 1) 
# 形状:(batch_size, 1, 1, ...)
# x.shape[0]获取xshape的第一维也就是batch_size
# (1,) * (x.ndim - 1) 将shape用1填充和x的形状一样
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
# torch.rand(shape, dtype=x.dtype, device=x.device)生成随机数(生成均值为0,标准差为1的正态# 分布随机数)形状和shape一致的也就是和x一致,数据类型,设备都和x一致
# 将随机数和keep_prob相加得到随机数(范围[keep_prob,1+keep_prob])
    random_tensor.floor_() 
# 二值化,生成 0 或 1 的 mask
# 也就是将随机数向下取整
    output = x.div(keep_prob) * random_tensor
#x.div(keep_prob)将输入张量x的所有值除以keep_prob,目的是 放大保留下来的部分

#* random_tensor根据0 或 1 的 mask决定哪些路径会被保留(1)或丢弃(0)
    return output

为什么要放大保留下来的部分:

  • 丢弃路径会导致部分值被置为零,模型整体输出的总期望值会下降。
  • 为了补偿这种下降,需要对保留下来的部分放大,使得丢弃路径后的总期望值和丢弃前一致。

因为只是补偿所以并不一定等与原期望

数学解释:

假设输入张量是 x=\begin{bmatrix} x_{1,}&x_{2,} & ... &, x_{n} \end{bmatrix},其中每个元素 xi表示特征。

期望:E=\frac{1}{n}\sum_{1}^{n}x_{i}

丢弃之后:E=\frac{1}{n}\sum_{1}^{n}{keepprob}\cdot x_{i}

放大之后:E=\frac{1}{n}\sum_{1}^{n}\frac{​{keepprob}\cdot x_{i}}{keepprob}=\frac{1}{n}\sum_{1}^{n}x_{i}

实例:

import torch
import torch.nn as nn

# 定义 DropPath 类
class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

def drop_path(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # 形状:(batch_size, 1, 1, ...)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # 二值化,生成 0 或 1 的 mask
    print(f'mask: {random_tensor}')
    output = x.div(keep_prob) * random_tensor
    return output

# 定义简单模型
class SimpleModel(nn.Module):
    def __init__(self, drop_prob):
        super().__init__()
        self.linear = nn.Linear(4, 4)  # 简单的线性层
        self.drop_path = DropPath(drop_prob)  # 使用 DropPath
        self.activation = nn.ReLU()  # ReLU 激活

    def forward(self, x):
        print("输入数据:")
        print(x)

        x = self.linear(x)  # 线性层
        print("线性层输出:")
        print(x)

        x = self.activation(x)  # ReLU 激活
        print("激活后输出:")
        print(x)

        x = self.drop_path(x)  # DropPath
        print("DropPath 后输出:")
        print(x)

        return x

# 创建模型
model = SimpleModel(drop_prob=0.5)
model.train()  # 设置为训练模式以启用 DropPath

# 输入数据
input_data = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                           [5.0, 6.0, 7.0, 8.0]], dtype=torch.float32)

# 运行模型
output = model(input_data)

输出: 简单理解就是根据mask的1,0值对每个样本进行保留或置零

输入数据:
tensor([[1., 2., 3., 4.],
        [5., 6., 7., 8.]])
线性层输出:
tensor([[ 1.2836, -1.4602,  2.2660, -1.7250],
        [ 1.3035, -4.1391,  4.5453, -2.5738]], grad_fn=<AddmmBackward0>)
激活后输出:
tensor([[1.2836, 0.0000, 2.2660, 0.0000],
        [1.3035, 0.0000, 4.5453, 0.0000]], grad_fn=<ReluBackward0>)
mask: tensor([[1.],
        [0.]])
DropPath 后输出:
tensor([[2.5672, 0.0000, 4.5321, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=<MulBackward0>)

 

 


http://www.kler.cn/a/420557.html

相关文章:

  • 游戏引擎学习第25天
  • C_字符串的一些函数
  • 《智能体雏形开发(高阶实操)》开发计划概述
  • 使用 LLaMA-Factory 微调
  • 【closerAI ComfyUI】物体转移术之图案转移,Flux三重控制万物一致性生图,实现LOGO和图案的精准迁移
  • 图像与文字的创意融合:使用Python进行视觉艺术创作
  • Qt的定时器应用案例 || Qt的图片添加显示
  • 2017 NHOI小学(C++)
  • MySQL 单表练习
  • C#中的集合初始化器
  • TongRDS分布式内存数据缓存中间件
  • 《数据结构》学习系列——图(下)
  • flink学习(14)—— 双流join
  • Redis开发05:使用stackexchange.redis库对redis进行增删改查
  • 前端【9种前端常见的设计模式】
  • 详解Qt Pdf之QPdfBookmarkModel 读取pdf标签页并显示
  • 创建 EC2块存储磁盘并将其连接到 Linux 实例
  • Vue3.5新版本特性一览-数组操作10倍性能提升+响应式属性解构+自定义组件优化+ssr水合改善+teleport支持defer!
  • Maven、JAVAWeb、Servlet
  • CS144 (二)
  • Redhat8部署docker27.3.0 防火墙策略怎样配置
  • 使用pymupdf提取PDF文档中的文字和其颜色
  • 前端基础的讲解-JS(18)
  • CentOS修改yum.repos.d源,避免“Could not resolve host: mirrorlist.centos.org”错误
  • 【C++】多线程
  • 如何成为一名优秀的炼丹师(三)