PyTorch model.train() vs. model.eval(): Differences and Source Code Analysis (Chinese-English Bilingual)
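In deep learning, model.train() and model.eval() are two methods called constantly during training and inference. They put the model into training mode and evaluation (inference) mode, respectively. Their purpose looks simple, but they change how certain layers behave, most notably Dropout and BatchNorm.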
This article explains the difference between model.train() and model.eval() in detail and walks through the relevant PyTorch source code.
1. What model.train() and model.eval() do
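model.train():
model.train() puts the model into training mode. This means:
- Normalization layers (e.g., BatchNorm) use per-batch statistics. During training, a BatchNorm layer normalizes each batch with that batch's own mean and variance, and it also updates its running estimates of the mean and variance.
- Dropout layers are active. During training, Dropout randomly zeroes elements of its input with probability p and scales the remaining ones by 1/(1-p) to reduce overfitting. A new random mask is drawn on every forward pass, which improves the model's ability to generalize.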
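model.eval():
model.eval() puts the model into evaluation (inference) mode. This means:
- Normalization layers (e.g., BatchNorm) use running statistics. In evaluation mode, a BatchNorm layer normalizes with the mean and variance accumulated during training instead of the current batch's statistics. At inference time a batch may be tiny (often a single sample), and a prediction should not depend on which other samples happen to share its batch. For both reasons, the stored statistics are used.
- Dropout layers are disabled. At inference we want a stable output, so Dropout becomes an identity mapping and every unit in the network participates in the computation.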
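For reference, the per-feature computations in the two lists above can be written as follows. The notation is chosen here for illustration and follows the formulas in the PyTorch BatchNorm and Dropout documentation:
- $m$ is the batch size and $p$ is the dropout probability.
- $\alpha$ is BatchNorm's `momentum` argument (0.1 by default; with `momentum=None` PyTorch uses a cumulative average instead).
- $\hat{\mu}$ and $\hat{\sigma}^2$ stand for the `running_mean` and `running_var` buffers.

$$
\begin{aligned}
&\textbf{BatchNorm, train:} && \mu_B=\frac{1}{m}\sum_{i=1}^{m}x_i,\quad \sigma_B^2=\frac{1}{m}\sum_{i=1}^{m}(x_i-\mu_B)^2,\quad y_i=\gamma\,\frac{x_i-\mu_B}{\sqrt{\sigma_B^2+\epsilon}}+\beta \\
& && \hat{\mu}\leftarrow(1-\alpha)\,\hat{\mu}+\alpha\,\mu_B,\quad \hat{\sigma}^2\leftarrow(1-\alpha)\,\hat{\sigma}^2+\alpha\,\frac{m}{m-1}\,\sigma_B^2 \\
&\textbf{BatchNorm, eval:} && y=\gamma\,\frac{x-\hat{\mu}}{\sqrt{\hat{\sigma}^2+\epsilon}}+\beta \\
&\textbf{Dropout, train:} && y_i=\frac{r_i\,x_i}{1-p},\quad r_i\sim\mathrm{Bernoulli}(1-p) \\
&\textbf{Dropout, eval:} && y=x
\end{aligned}
$$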
2. Source code analysis: how model.train() and model.eval() are implemented
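The PyTorch source shows how the two methods work. Both modes are switched by the same train() method: model.train() is that method with its default mode=True, and model.eval() simply calls train(False). Here is the relevant part of the PyTorch source: https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L2824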
```python
# Excerpt: method of torch.nn.Module (torch/nn/modules/module.py); T is a TypeVar bound to Module
def train(self: T, mode: bool = True) -> T:
    r"""Set the module in training mode.

    This has an effect only on certain modules. See the documentation of
    particular modules for details of their behaviors in training/evaluation
    mode, i.e., whether they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
    etc.

    Args:
        mode (bool): whether to set training mode (``True``) or evaluation
                     mode (``False``). Default: ``True``.

    Returns:
        Module: self
    """
    if not isinstance(mode, bool):
        raise ValueError("training mode is expected to be boolean")
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self
```
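The train() method first checks that the mode argument is a bool and raises a ValueError if it is not. It then sets the module's training attribute to mode, which determines whether the module is in training or evaluation mode. Finally, it calls train(mode) on each direct child returned by self.children(). Because every child does the same for its own children, the call recurses through the whole module tree and every layer ends up in the same mode.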
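The recursion can be reproduced outside PyTorch with a toy class that keeps only what train() and eval() need: a training flag and a list of children. This is a sketch, not torch.nn.Module. ToyModule and all_flags are hypothetical names. train() copies the logic of the excerpt above, and eval() is written as train(False), which is how PyTorch defines it.

```python
class ToyModule:
    # Hypothetical stand-in for torch.nn.Module: a training flag plus children
    def __init__(self, *children):
        self.training = True              # modules start in training mode
        self._children = list(children)

    def children(self):
        # Like Module.children(): direct submodules only
        return iter(self._children)

    def train(self, mode=True):
        # Same three steps as the excerpt: validate, set own flag, recurse
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

    def eval(self):
        # eval() is just train(False)
        return self.train(False)


def all_flags(module):
    # Depth-first list of `training` for the module and every descendant
    flags = [module.training]
    for child in module.children():
        flags.extend(all_flags(child))
    return flags


leaf_a, leaf_b = ToyModule(), ToyModule()
model = ToyModule(ToyModule(leaf_a), leaf_b)   # three levels deep

model.eval()
print(all_flags(model))   # [False, False, False, False]
model.train()
print(all_flags(model))   # [True, True, True, True]
```

A single call on the root reaches the grandchild leaf_a because each level forwards the call to its own children.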
```python
# Excerpt: method of torch.nn.Module (torch/nn/modules/module.py)
def eval(self: T) -> T:
    r"""Set the module in evaluation mode.

    This has an effect only on certain modules. See the documentation of
    particular modules for details of their behaviors in training/evaluation
    mode, i.e., whether they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
    etc.

    This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.

    Returns:
        Module: self
    """
    return self.train(False)
```
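eval() puts the model into evaluation mode by calling self.train(False). In other words, eval() is only a thin wrapper around train(False). It sets the training attribute of the module and all of its submodules to False, which switches the model into evaluation mode. Like train(), it returns self, so it can be chained with other calls.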
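Because train(mode) overwrites the flag of every descendant, an eval()/train() round trip also flips submodules that were deliberately kept in evaluation mode, for instance a frozen BatchNorm. The sketch below saves each module's flag before calling eval() and restores it afterwards. It is written against a toy stand-in class rather than torch.nn.Module, and ToyModule and the eval_mode context manager are hypothetical names introduced for illustration.

```python
from contextlib import contextmanager


class ToyModule:
    # Hypothetical stand-in for nn.Module: a training flag plus child modules
    def __init__(self, *children):
        self.training = True
        self._children = list(children)

    def modules(self):
        # Like nn.Module.modules(): yield self, then every descendant
        yield self
        for child in self._children:
            yield from child.modules()

    def train(self, mode=True):
        # Same recursion as nn.Module.train(): overwrite every descendant
        self.training = mode
        for child in self._children:
            child.train(mode)
        return self

    def eval(self):
        return self.train(False)


@contextmanager
def eval_mode(model):
    # Hypothetical helper: record each module's flag, switch to eval,
    # then restore every flag exactly as it was, even on exceptions
    saved = [(m, m.training) for m in model.modules()]
    model.eval()
    try:
        yield model
    finally:
        for m, flag in saved:
            m.training = flag


frozen_bn = ToyModule().eval()   # e.g. a BatchNorm kept in eval mode on purpose
model = ToyModule(ToyModule(), frozen_bn)

with eval_mode(model):
    print([m.training for m in model.modules()])  # [False, False, False]
print([m.training for m in model.modules()])      # [True, True, False]

model.eval()
model.train()                    # plain eval()/train() round trip
print(frozen_bn.training)        # True: the frozen layer was switched back on
```

Restoring flags per module, instead of calling train() on the root, is what keeps the deliberately frozen layer frozen.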
3. How train() and eval() affect model behavior
The difference between train() and eval() shows up mainly in the following modules:
1) Dropout layers
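In training mode, a Dropout layer randomly zeroes part of its input and scales the rest by 1/(1-p), which helps prevent overfitting. In evaluation mode, Dropout is disabled and passes its input through unchanged. Every unit then participates, and inference results are deterministic.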
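The nn.Dropout documentation describes inverted dropout. During training each element is zeroed with probability p and the survivors are scaled by 1/(1-p); in evaluation the layer is the identity. That behavior can be sketched in a few lines of plain Python. The dropout function below is a hypothetical helper for illustration, not torch.nn.functional.dropout.

```python
import random


def dropout(xs, p=0.5, training=True, rng=random):
    # Inverted dropout: in training, zero each element with probability p
    # and scale survivors by 1/(1-p); in evaluation, return the input as-is
    if not training or p == 0.0:
        return list(xs)
    if p == 1.0:
        return [0.0 for _ in xs]
    scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else x * scale for x in xs]


rng = random.Random(0)
x = [1.0] * 8
print(dropout(x, p=0.5, training=True, rng=rng))   # some zeros, survivors become 2.0
print(dropout(x, p=0.5, training=False))           # unchanged: all 1.0

# The 1/(1-p) scaling keeps the expected value unchanged, which is why
# evaluation mode can skip dropout without rescaling anything
big = dropout([1.0] * 100_000, p=0.3, training=True, rng=rng)
print(sum(big) / len(big))                         # close to 1.0
```

Scaling during training rather than at inference is the design choice that lets eval() simply turn the layer off.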
2) BatchNorm layers
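BatchNorm layers also behave differently in the two modes. In training mode, BatchNorm normalizes with the mean and variance of the current batch. It also updates its running_mean and running_var buffers as an exponential moving average, controlled by momentum (0.1 by default); this is not a plain average over all training batches. In evaluation mode, BatchNorm normalizes with these running estimates instead. This keeps inference stable: the output for a sample no longer depends on the batch it happens to be in, and even a batch of one sample works.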
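The running-statistics bookkeeping described above can be sketched for a single feature in plain Python. This is a simplified illustration, not PyTorch's implementation. ToyBatchNorm1Feature is a hypothetical name, and it omits the learnable γ/β. It follows the conventions stated in the BatchNorm documentation:
- running_mean starts at 0 and running_var at 1.
- momentum defaults to 0.1.
- The biased variance normalizes the batch, while the unbiased variance is stored in running_var.

```python
def batch_stats(xs):
    # Mean, biased variance (used to normalize) and unbiased variance
    # (used for the running estimate); needs at least two values
    n = len(xs)
    mean = sum(xs) / n
    sq = sum((x - mean) ** 2 for x in xs)
    return mean, sq / n, sq / (n - 1)


class ToyBatchNorm1Feature:
    # Hypothetical single-feature BatchNorm without gamma/beta
    def __init__(self, momentum=0.1, eps=1e-5):
        self.training = True
        self.momentum, self.eps = momentum, eps
        self.running_mean, self.running_var = 0.0, 1.0

    def __call__(self, xs):
        if self.training:
            mean, var_biased, var_unbiased = batch_stats(xs)
            m = self.momentum
            # Exponential moving average: new = (1 - momentum) * old + momentum * observed
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * var_unbiased
            var = var_biased
        else:
            # Evaluation mode: use the stored estimates, do not update them
            mean, var = self.running_mean, self.running_var
        return [(x - mean) / (var + self.eps) ** 0.5 for x in xs]


bn = ToyBatchNorm1Feature()
for batch in ([1.0, 3.0], [2.0, 4.0], [0.0, 2.0]):
    bn(batch)                  # training mode: updates the running statistics
print(bn.running_mean, bn.running_var)   # roughly 0.532 and 1.271

bn.training = False            # what eval() does to this layer
print(bn([2.0]))               # a batch of one works in evaluation mode
```

In training mode this toy would fail on a single-element batch (division by n - 1 = 0). In evaluation mode a single sample is fine, which illustrates why inference relies on the stored statistics.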
3) Other layers
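Besides Dropout and BatchNorm, some other layers also depend on the mode:
- InstanceNorm: layers created with track_running_stats=True use running statistics in evaluation mode, just like BatchNorm. With the default track_running_stats=False, they use per-instance statistics in both modes.
- Recurrent layers (RNN, LSTM, GRU): layers created with dropout > 0 apply dropout to the outputs of every layer except the last, but only in training mode. In evaluation mode this dropout is skipped.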
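Where that RNN dropout sits can be sketched with plain functions standing in for stacked recurrent layers: it is applied between layers, never after the last one, and only in training. This is a hypothetical illustration of the placement described in the nn.RNN/nn.LSTM/nn.GRU documentation, not PyTorch's RNN code. The double and stacked_forward helpers are made up for the example.

```python
import random


def double(xs):
    # Stand-in for one recurrent layer (a real layer would carry hidden state)
    return [2.0 * v for v in xs]


def stacked_forward(layers, xs, p, training, rng):
    # Dropout on the output of every layer except the last, training only
    for i, layer in enumerate(layers):
        xs = layer(xs)
        is_last = i == len(layers) - 1
        if training and p > 0 and not is_last:
            xs = [0.0 if rng.random() < p else v / (1 - p) for v in xs]
    return xs


layers = [double, double, double]
x = [1.0] * 6

# Evaluation mode: deterministic, every value is 1 * 2 * 2 * 2
print(stacked_forward(layers, x, p=0.5, training=False, rng=random.Random(0)))
# Training mode: each value is 0.0 or 32.0 (two dropouts, each scaling by 2)
print(stacked_forward(layers, x, p=0.5, training=True, rng=random.Random(0)))
```

With three layers, dropout runs twice (after layers 1 and 2), which is why surviving values are scaled by 2 × 2 on top of the layers' own doubling.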
4. Use cases
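- Training: during training, call model.train() to enable Dropout and to make BatchNorm use (and update) per-batch statistics. Dropout acts as a regularizer that helps prevent overfitting.
- Evaluation: for inference (validation, testing, prediction), call model.eval() to disable Dropout and make BatchNorm use the statistics accumulated during training. This keeps the model's output stable and consistent. Note that model.eval() does not turn off gradient tracking; wrap inference code in torch.no_grad() for that.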
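A common bug is running validation while the model is still in training mode. Besides making outputs stochastic, this silently updates BatchNorm's running statistics with validation data. The sketch below reproduces the effect with a hypothetical RunningMeanTracker class that keeps only the running-mean update, using momentum 0.1 as in BatchNorm's default. It is an illustration, not PyTorch code.

```python
class RunningMeanTracker:
    # Hypothetical stand-in for a BatchNorm layer that keeps only the
    # running-mean update; the default momentum matches BatchNorm's 0.1
    def __init__(self, momentum=0.1):
        self.training = True
        self.momentum = momentum
        self.running_mean = 0.0

    def __call__(self, batch):
        if self.training:  # only training mode touches the running statistic
            mean = sum(batch) / len(batch)
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
        return batch


layer = RunningMeanTracker()
for _ in range(50):
    layer([1.0, 1.0])            # training data with mean 1.0
after_training = layer.running_mean

for _ in range(20):
    layer([10.0, 10.0])          # "validation" accidentally run in training mode
print(after_training, layer.running_mean)   # the statistic drifts toward 10

layer.running_mean = after_training         # undo the damage
layer.training = False                      # what model.eval() does to this layer
for _ in range(20):
    layer([10.0, 10.0])          # validation in evaluation mode
print(layer.running_mean == after_training)  # True
```

Calling model.eval() before validation, and model.train() again at the start of the next epoch, avoids this drift.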
5. Summary
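- model.train() puts the model into training mode: Dropout is enabled and layers such as BatchNorm use per-batch statistics.
- model.eval() puts the model into evaluation mode: Dropout is disabled and BatchNorm uses the running (global) statistics.
- Both methods significantly change the behavior of modules such as Dropout and BatchNorm, so you must call model.eval() before inference or evaluation.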
Understanding and correctly using model.train() and model.eval() is key to making a model behave consistently and stably during both training and inference.
English version
Differences Between model.train() and model.eval() in PyTorch: A Detailed Source Code Analysis
In deep learning model training and inference, model.train() and model.eval() are two commonly used methods. They set the model to training mode and evaluation (inference) mode, respectively. Although their purpose seems simple, they significantly affect the behavior of certain layers in the model, such as Dropout and BatchNorm.
This article explains the differences between model.train() and model.eval() in detail and dives into the PyTorch source code to show how they work.
1. Basic Functions of model.train() and model.eval()
model.train():
model.train() sets the model to training mode. This means:
- Normalization layers (e.g., BatchNorm) use per-batch statistics. During training, BatchNorm layers normalize each batch with its own mean and variance and also update their running estimates of the mean and variance.
- Dropout layers are enabled. During training, Dropout randomly zeroes elements of its input with probability p and scales the rest by 1/(1-p) to reduce overfitting. A new random mask is drawn for each forward pass, which improves the model's generalization.
model.eval():
model.eval() sets the model to evaluation (inference) mode. This means:
- Normalization layers (e.g., BatchNorm) use running statistics. In evaluation mode, BatchNorm layers normalize with the mean and variance accumulated during training instead of per-batch statistics. At inference time a batch may be very small (often a single sample), and a prediction should not depend on the other samples in its batch, so the stored statistics are used.
- Dropout layers are disabled. During inference we want a stable output, so Dropout becomes an identity mapping and every unit in the network participates in the computation.
2. Source Code Analysis: Implementation of model.train() and model.eval()
Reviewing the PyTorch source code [1] gives a deeper understanding of how model.train() and model.eval() are implemented. In PyTorch, both modes are switched by the same train() method: model.train() is that method with its default mode=True, and model.eval() calls train(False). Below is the relevant part of the PyTorch source code:
```python
# Excerpt: method of torch.nn.Module (torch/nn/modules/module.py); T is a TypeVar bound to Module
def train(self: T, mode: bool = True) -> T:
    r"""Set the module in training mode.

    This has an effect only on certain modules. See the documentation of
    particular modules for details of their behaviors in training/evaluation
    mode, i.e., whether they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
    etc.

    Args:
        mode (bool): whether to set training mode (``True``) or evaluation
                     mode (``False``). Default: ``True``.

    Returns:
        Module: self
    """
    if not isinstance(mode, bool):
        raise ValueError("training mode is expected to be boolean")
    self.training = mode
    for module in self.children():
        module.train(mode)
    return self
```
The train() method first checks whether the mode argument is a bool and raises a ValueError if it is not. It then sets the module's training attribute to mode, which determines whether the module is in training or evaluation mode. Finally, it calls train(mode) on each direct child module. Since every child does the same, the call recurses through the whole module tree and each layer of the model ends up in the same mode.
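One consequence of this design is that the recursion lives in train(), not in the training attribute itself. Assigning model.training = False by hand changes only that one module; in PyTorch, training is a plain attribute, so this holds there as well. The toy class below shows the difference. It is a hypothetical stand-in that copies only the recursion from the excerpt above, not torch.nn.Module.

```python
class ToyModule:
    # Hypothetical stand-in with the same train() recursion as nn.Module
    def __init__(self, *children):
        self.training = True
        self._children = list(children)

    def train(self, mode=True):
        self.training = mode
        for child in self._children:
            child.train(mode)
        return self


child = ToyModule()
model = ToyModule(child)

model.training = False                   # sets the parent's flag only
print(model.training, child.training)    # False True

model.train(False)                       # recurses, like model.eval()
print(model.training, child.training)    # False False
```

To switch a whole model, call model.eval() or model.train(False) rather than assigning the attribute.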
```python
# Excerpt: method of torch.nn.Module (torch/nn/modules/module.py)
def eval(self: T) -> T:
    r"""Set the module in evaluation mode.

    This has an effect only on certain modules. See the documentation of
    particular modules for details of their behaviors in training/evaluation
    mode, i.e., whether they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
    etc.

    This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.

    Returns:
        Module: self
    """
    return self.train(False)
```
The eval() method puts the model into evaluation mode by calling self.train(False). In fact, eval() is simply a wrapper around train(False), which sets the training attribute of the module and all its submodules to False, thereby switching the model to evaluation mode.
3. Effects of train() and eval() on Model Behavior
The main differences between train() and eval() show up in how certain layers behave:
1) Dropout Layer
In training mode, the Dropout layer randomly zeroes elements during the forward pass and scales the survivors by 1/(1-p), which helps prevent overfitting. In evaluation mode, Dropout is disabled and passes its input through unchanged. Every unit then participates in the forward pass, and the output during inference is deterministic.
2) BatchNorm Layer
The behavior of the BatchNorm layer differs between training and evaluation modes. During training, BatchNorm normalizes the data with the mean and variance of the current batch. It also updates its running_mean and running_var buffers as an exponential moving average (momentum 0.1 by default). In evaluation mode, BatchNorm normalizes with these running estimates instead. This keeps inference stable: the output for a sample no longer depends on which other samples share its batch, and even a single-sample batch works.
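To see why evaluation mode makes inference independent of the batch, the sketch below runs the same sample through a single-feature BatchNorm (without γ/β) inside two different batches. The running statistics (0.5 and 2.0) are made-up values standing in for what training would have accumulated. bn_forward and normalize are hypothetical helpers, not PyTorch functions.

```python
def normalize(xs, mean, var, eps=1e-5):
    # (x - mean) / sqrt(var + eps): the BatchNorm formula without gamma/beta
    return [(x - mean) / (var + eps) ** 0.5 for x in xs]


def bn_forward(xs, training, running_mean=0.5, running_var=2.0):
    # Hypothetical single-feature BatchNorm forward; the running statistics
    # are made-up stand-ins for values accumulated during training
    if training:
        mean = sum(xs) / len(xs)
        var = sum((x - mean) ** 2 for x in xs) / len(xs)  # biased, as in PyTorch
        return normalize(xs, mean, var)
    return normalize(xs, running_mean, running_var)


sample = 1.0
for batch in ([sample, 0.0], [sample, 5.0]):   # same sample, different batchmates
    train_out = bn_forward(batch, training=True)[0]
    eval_out = bn_forward(batch, training=False)[0]
    print(round(train_out, 4), round(eval_out, 4))
# 1.0 0.3536
# -1.0 0.3536
```

In training mode the normalized value of the same sample flips from about 1.0 to about -1.0 depending on its batchmates. In evaluation mode it is identical in both batches.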
3) Other Layers
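Apart from Dropout and BatchNorm, some other layers also depend on the mode:
- InstanceNorm: layers created with track_running_stats=True use running statistics in evaluation mode, as BatchNorm does. With the default track_running_stats=False, they use per-instance statistics in both modes.
- Recurrent layers (RNN, LSTM, GRU): layers created with dropout > 0 apply dropout to the outputs of each layer except the last, but only in training mode. In evaluation mode that dropout is skipped.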
4. Use Cases
- During Training: When training a model, call model.train() to enable Dropout and make BatchNorm use (and update) batch-based statistics. Dropout acts as a regularizer that helps prevent overfitting.
- During Evaluation: When performing model inference (e.g., validation, testing, prediction), call model.eval() to disable Dropout and make BatchNorm use the statistics accumulated during training. This ensures that the model's output is stable and consistent. Note that model.eval() does not disable gradient tracking; wrap inference code in torch.no_grad() for that.
5. Conclusion
- model.train() sets the model to training mode, enabling the Dropout layer and using per-batch statistics for BatchNorm.
- model.eval() sets the model to evaluation mode, disabling the Dropout layer and using global statistics for BatchNorm.
- These functions have important effects on the model's behavior, particularly on modules like Dropout and BatchNorm, so it is crucial to use model.eval() during inference or evaluation.
Understanding and correctly using model.train() and model.eval() is key to ensuring that your model performs consistently and reliably during both training and inference.
References
[1] https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L2824
[2] https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.eval
[3] https://stackoverflow.com/questions/60018578/what-does-model-eval-do-in-pytorch
Afterword
Written in Shanghai at 16:56 on December 25, 2024, with the assistance of the GPT4o mini large language model.