pytorch nn.Dropout类介绍
在 PyTorch 中,nn.Dropout
是一种正则化方法,随机将输入张量的一部分元素置为零,以防止过拟合并提高模型的泛化能力。其基本用法如下:
import torch
import torch.nn as nn
dropout = nn.Dropout(p=0.5) # 丢弃概率为 50%
x = torch.ones((2, 3, 4)) # 输入张量
output = dropout(x) # 输出的部分元素会被置为零
- 它在训练阶段,对于输入张量中的每个元素,会以概率
p
将其置为 0。对于未被置为 0 的元素,需要进行数值缩放,缩放因子为1 / (1 - p)
。 - 在给定的代码中,
p = 0.5
,这意味着每个元素有 0.5 的概率被置为 0,而未被置为 0 的元素将乘以1 / (1 - 0.5)=2
。
注: 输入张量的每个元素会以概率p
将其置为 0,没有维度限制。
如何在指定维度上进行 Dropout?
PyTorch 的标准 nn.Dropout
无法直接指定某个维度进行 Dropout,但可以通过以下几种方法实现在指定维度共享 Dropout 掩码:
方法 1:自定义 Dropout 类(参考上文)
可以继承 nn.Module
,实现一个支持沿指定