autocast in PyTorch: How Mixed Precision Training Is Implemented
In deep learning, and especially when training large models, compute and GPU memory consumption are often the critical constraints. Mixed precision training was introduced to reduce exactly these costs. autocast is the tool PyTorch provides to automate the choice of numeric types during mixed precision training, so that computation gets faster while the loss of precision stays as small as possible.
1. What is autocast?
autocast is PyTorch's context manager for enabling automatic mixed precision. Inside the region it wraps, operations automatically pick a suitable floating-point precision (for example float16 or bfloat16), which improves throughput and saves GPU memory while preserving training accuracy as much as possible.
- Goal: improve performance and reduce GPU memory usage.
- Approach: within the designated code region, computation automatically uses a lower precision (such as float16 or bfloat16); once the region is left, subsequent operations run in the default higher precision (such as float32) for the gradient update. A minimal usage sketch follows this list.
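The minimal sketch below (it assumes a CUDA device and uses a throwaway nn.Linear model purely for illustration) shows the basic shape of such a region; only the forward computation is wrapped, as recommended by the official docstring reproduced in the appendix:

import torch
from torch import nn

# A tiny stand-in model; any nn.Module behaves the same way.
model = nn.Linear(16, 4).cuda()
x = torch.randn(8, 16, device="cuda")

# Ops inside this region may run in float16; the parameters themselves stay float32.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    out = model(x)

print(out.dtype)           # typically torch.float16
print(model.weight.dtype)  # torch.float32: autocast never converts the weights in place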
2. How does autocast work?
Workflow
autocast is essentially a context manager, and it works as follows:
- Entering the autocast context: when execution enters the autocast context, PyTorch automatically switches the relevant operations (matrix multiplication, convolution, and so on) to a low-precision floating-point type (usually float16 or bfloat16) to speed up computation and save GPU memory.
- Choosing the op dtype: autocast picks a suitable precision based on the hardware and device type (for example CUDA or CPU). On CUDA devices it typically uses float16; on CPU it uses bfloat16.
- Returning to the default precision: once the computation is done and the context is exited, subsequent operations run in the default precision again (usually float32). This matters for gradient computation and the weight update, because performing them in low precision can cause numerical instability or loss of accuracy. Note that tensors created inside the region keep their low-precision dtype, so they may need an explicit .float() cast before being mixed with float32 tensors.
- Avoiding precision loss in gradients and weight updates: the forward pass inside autocast runs in low precision (float16 or bfloat16), while backward ops run in the same dtype autocast used for the corresponding forward ops; the float32 master weights and the optimizer update stay in float32 to keep things numerically stable. A dtype-inspection sketch follows this list.
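To make the dtype switching described above concrete, the sketch below (adapted from the CUDA example in the docstring quoted in the appendix; it assumes a CUDA device) checks what comes out of an op inside and outside the region:

import torch

a = torch.rand(8, 8, device="cuda")  # float32 by default
b = torch.rand(8, 8, device="cuda")

with torch.autocast(device_type="cuda"):
    # torch.mm is on autocast's float16 op list, so the result is float16.
    c = torch.mm(a, b)
    print(c.dtype)  # torch.float16

# Outside the region, ops run in the default dtype again; cast the
# float16 result back before mixing it with float32 tensors.
d = torch.mm(a, c.float())
print(d.dtype)      # torch.float32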
A concrete example
Below is a basic example of using autocast in PyTorch; the torch.amp entry points are used here, since torch.cuda.amp.autocast does not accept a device_type argument:

import torch
from torch import nn, optim
from torch.amp import autocast, GradScaler

# Create a model and an optimizer (both in the default float32 precision)
model = nn.Linear(10, 1).cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# GradScaler performs loss scaling
scaler = GradScaler("cuda")

# A simple training loop
for epoch in range(10):
    optimizer.zero_grad()

    # Enter the autocast context: wrap only the forward pass and the loss
    with autocast(device_type="cuda"):
        inputs = torch.randn(32, 10).cuda()   # input data
        targets = torch.randn(32, 1).cuda()   # target data
        outputs = model(inputs)
        loss = nn.MSELoss()(outputs, targets)

    # Loss scaling, backward pass, and the optimizer step happen outside the region
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
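On CPU the same pattern applies with dtype=torch.bfloat16; because bfloat16 keeps float32's exponent range, loss scaling is usually unnecessary there. The sketch below is a reduced variant of the CPU training example in the docstring quoted in the appendix:

import torch
from torch import nn, optim

model = nn.Linear(10, 1)              # stays on the CPU, float32 weights
optimizer = optim.SGD(model.parameters(), lr=1e-3)

for epoch in range(10):
    optimizer.zero_grad()

    # bfloat16 autocast on CPU
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        inputs = torch.randn(32, 10)
        targets = torch.randn(32, 1)
        outputs = model(inputs)
        loss = nn.MSELoss()(outputs, targets)

    loss.backward()                   # no GradScaler needed with bfloat16
    optimizer.step()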
3. Reading the autocast source code
After installing torch, the source can be found at a path similar to ~/anaconda3/envs/xxx/lib/python3.10/site-packages/torch/amp/autocast_mode.py; the full source is reproduced at the end of this post.
Let's walk through the implementation of autocast to understand how it works.
The __enter__ method:
When the autocast context is entered, the __enter__ method is called. It does the following:
- Save the current settings: record the device's current autocast state, data type (dtype), and cache setting.
- Enable autocast: torch.set_autocast_enabled() turns autocast on for the given device, and the data type (such as float16 or bfloat16) is set.
- Configure the cache: enable or disable autocast's caching, which helps computation performance.
def __enter__(self):
    self.prev_cache_enabled = torch.is_autocast_cache_enabled()
    self.prev = torch.is_autocast_enabled(self.device)
    self.prev_fastdtype = torch.get_autocast_dtype(self.device)
    torch.set_autocast_enabled(self.device, self._enabled)
    torch.set_autocast_dtype(self.device, self.fast_dtype)  # type: ignore[arg-type]
    torch.autocast_increment_nesting()
    torch.set_autocast_cache_enabled(self._cache_enabled)
If you are curious about how functions such as set_autocast_dtype used inside __enter__ are implemented, see my other post: PyTorch中的__init__.pyi文件:作用与C++实现关系解析 (on the role of the __init__.pyi file and how it relates to the C++ implementation).
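The global state that __enter__ saves and sets can be observed directly with the same query functions that appear in the source above. A small sketch (assuming a CUDA-enabled build; otherwise autocast silently disables itself for "cuda"):

import torch

print(torch.is_autocast_enabled("cuda"))   # False outside any region
print(torch.get_autocast_dtype("cuda"))    # torch.float16, the default fast dtype for CUDA

with torch.autocast(device_type="cuda"):
    print(torch.is_autocast_enabled("cuda"))  # True inside the region

print(torch.is_autocast_enabled("cuda"))   # False again once __exit__ has restored the state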
The __exit__ method:
When the autocast context is exited, the __exit__ method is called:
- Restore the previous settings: the earlier autocast state, data type, and cache setting are restored.
- Clear the cache: once the nesting level drops back to 0, torch.clear_autocast_cache() is called to release the cached casts and free memory.
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
    if torch.autocast_decrement_nesting() == 0:
        torch.clear_autocast_cache()
    torch.set_autocast_enabled(self.device, self.prev)
    torch.set_autocast_dtype(self.device, self.prev_fastdtype)
    torch.set_autocast_cache_enabled(self.prev_cache_enabled)
    return False
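Because __enter__ saves the previous state and __exit__ restores it, autocast regions nest cleanly; the docstring in the appendix demonstrates this with an enabled=False subregion. A reduced sketch (CUDA assumed):

import torch

a = torch.rand(8, 8, device="cuda")
b = torch.rand(8, 8, device="cuda")

with torch.autocast(device_type="cuda"):
    c = torch.mm(a, b)                       # runs in float16
    with torch.autocast(device_type="cuda", enabled=False):
        # autocast is locally disabled here, so cast the inputs explicitly
        d = torch.mm(a, c.float())           # runs in float32
    e = torch.mm(a, d)                       # float16 again after the inner __exit__

print(c.dtype, d.dtype, e.dtype)             # float16, float32, float16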
The __call__ method:
In addition, autocast can be used as a decorator:

def __call__(self, func):
    return autocast_decorator(self, func)

This lets you apply autocast directly to a function, which keeps the code structure simple.
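The docstring in the appendix shows the decorator form applied to a model's forward method; a minimal sketch (CUDA assumed, with a toy module for illustration):

import torch
from torch import nn

class AutocastModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    # The decorated forward runs inside an autocast region automatically.
    @torch.autocast(device_type="cuda")
    def forward(self, x):
        return self.fc(x)

model = AutocastModel().cuda()
out = model(torch.randn(8, 16, device="cuda"))
print(out.dtype)  # typically torch.float16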
4. Mixed precision training with autocast
During training, autocast automatically chooses the appropriate precision so that training stays both fast and stable:
- Forward pass: most operations run in low precision (float16 or bfloat16) to speed up computation.
- Backward pass: backward() is called outside the autocast region; each backward op runs in the same dtype that autocast chose for the corresponding forward op, and the gradients accumulated on the float32 parameters come out as float32, avoiding instability from insufficient precision.
- Weight update: the optimizer updates the float32 weights, which keeps convergence stable.
A sketch that verifies these dtypes follows this list.
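A quick way to confirm the split described above is to inspect the dtypes after a backward pass. A minimal sketch (CUDA assumed; GradScaler is omitted here for brevity):

import torch
from torch import nn

model = nn.Linear(10, 1).cuda()
x = torch.randn(4, 10, device="cuda")
target = torch.randn(4, 1, device="cuda")

with torch.autocast(device_type="cuda"):
    out = model(x)                            # the linear layer runs in float16
    loss = nn.functional.mse_loss(out, target)

loss.backward()                               # backward outside the autocast region

print(out.dtype)                # torch.float16
print(model.weight.dtype)       # torch.float32
print(model.weight.grad.dtype)  # torch.float32: gradients match the parameter dtype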
5. Combining autocast with GradScaler
To further improve the stability of mixed precision training, PyTorch also provides GradScaler for loss scaling. Its purpose is to prevent gradient overflow or underflow in low-precision computation:
- GradScaler multiplies the loss by a scale factor before the backward pass, keeping gradient values in a representable range; the gradients are unscaled again before the optimizer update.
- scaler.scale(loss).backward() and scaler.step(optimizer) run the scaled backward pass and the optimizer step, and scaler.update() adjusts the scale factor for the next iteration. A sketch that also unscales the gradients for clipping follows this list.
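When you need to inspect or clip the real (unscaled) gradients, the usual recipe from the AMP examples is to call scaler.unscale_(optimizer) first. The sketch below extends the training loop from Section 2 with gradient clipping (shapes and hyperparameters are illustrative):

import torch
from torch import nn, optim
from torch.amp import autocast, GradScaler

model = nn.Linear(10, 1).cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler("cuda")

for step in range(10):
    optimizer.zero_grad()
    with autocast(device_type="cuda"):
        inputs = torch.randn(32, 10, device="cuda")
        targets = torch.randn(32, 1, device="cuda")
        loss = nn.MSELoss()(model(inputs), targets)

    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                                # grads are now unscaled float32 values
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)   # clip on the real gradient values
    scaler.step(optimizer)                                    # skips the step if grads contain inf/nan
    scaler.update()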
Summary
- autocast is PyTorch's automatic mixed precision tool: it selects appropriate data types automatically during training, speeding up computation and reducing GPU memory use.
- autocast runs the forward pass in low precision (such as float16 or bfloat16), while the float32 weights and the optimizer update remain in float32 for numerical stability.
- Used together with GradScaler, mixed precision training saves resources while avoiding precision loss and gradient underflow.
Using autocast makes training large deep learning models more efficient while maintaining high accuracy and stability, which makes it particularly well suited to training workloads in high-performance computing environments.
Appendix: PyTorch source code
class autocast:
    r"""
    Instances of :class:`autocast` serve as context managers or decorators that
    allow regions of your script to run in mixed precision.

    In these regions, ops run in an op-specific dtype chosen by autocast
    to improve performance while maintaining accuracy.
    See the :ref:`Autocast Op Reference<autocast-op-reference>` for details.

    When entering an autocast-enabled region, Tensors may be any type.
    You should not call ``half()`` or ``bfloat16()`` on your model(s) or inputs when using autocasting.

    :class:`autocast` should wrap only the forward pass(es) of your network, including the loss
    computation(s).  Backward passes under autocast are not recommended.
    Backward ops run in the same type that autocast used for corresponding forward ops.

    Example for CUDA Devices::

        # Creates model and optimizer in default precision
        model = Net().cuda()
        optimizer = optim.SGD(model.parameters(), ...)

        for input, target in data:
            optimizer.zero_grad()

            # Enables autocasting for the forward pass (model + loss)
            with torch.autocast(device_type="cuda"):
                output = model(input)
                loss = loss_fn(output, target)

            # Exits the context manager before backward()
            loss.backward()
            optimizer.step()

    See the :ref:`Automatic Mixed Precision examples<amp-examples>` for usage (along with gradient scaling)
    in more complex scenarios (e.g., gradient penalty, multiple models/losses, custom autograd functions).

    :class:`autocast` can also be used as a decorator, e.g., on the ``forward`` method of your model::

        class AutocastModel(nn.Module):
            ...
            @torch.autocast(device_type="cuda")
            def forward(self, input):
                ...

    Floating-point Tensors produced in an autocast-enabled region may be ``float16``.
    After returning to an autocast-disabled region, using them with floating-point
    Tensors of different dtypes may cause type mismatch errors.  If so, cast the Tensor(s)
    produced in the autocast region back to ``float32`` (or other dtype if desired).
    If a Tensor from the autocast region is already ``float32``, the cast is a no-op,
    and incurs no additional overhead.

    CUDA Example::

        # Creates some tensors in default dtype (here assumed to be float32)
        a_float32 = torch.rand((8, 8), device="cuda")
        b_float32 = torch.rand((8, 8), device="cuda")
        c_float32 = torch.rand((8, 8), device="cuda")
        d_float32 = torch.rand((8, 8), device="cuda")

        with torch.autocast(device_type="cuda"):
            # torch.mm is on autocast's list of ops that should run in float16.
            # Inputs are float32, but the op runs in float16 and produces float16 output.
            # No manual casts are required.
            e_float16 = torch.mm(a_float32, b_float32)
            # Also handles mixed input types
            f_float16 = torch.mm(d_float32, e_float16)

        # After exiting autocast, calls f_float16.float() to use with d_float32
        g_float32 = torch.mm(d_float32, f_float16.float())

    CPU Training Example::

        # Creates model and optimizer in default precision
        model = Net()
        optimizer = optim.SGD(model.parameters(), ...)

        for epoch in epochs:
            for input, target in data:
                optimizer.zero_grad()

                # Runs the forward pass with autocasting.
                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                    output = model(input)
                    loss = loss_fn(output, target)

                loss.backward()
                optimizer.step()

    CPU Inference Example::

        # Creates model in default precision
        model = Net().eval()

        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            for input in data:
                # Runs the forward pass with autocasting.
                output = model(input)

    CPU Inference Example with Jit Trace::

        class TestModel(nn.Module):
            def __init__(self, input_size, num_classes):
                super().__init__()
                self.fc1 = nn.Linear(input_size, num_classes)

            def forward(self, x):
                return self.fc1(x)

        input_size = 2
        num_classes = 2
        model = TestModel(input_size, num_classes).eval()

        # For now, we suggest to disable the Jit Autocast Pass,
        # As the issue: https://github.com/pytorch/pytorch/issues/75956
        torch._C._jit_set_autocast_mode(False)

        with torch.cpu.amp.autocast(cache_enabled=False):
            model = torch.jit.trace(model, torch.randn(1, input_size))
        model = torch.jit.freeze(model)
        # Models Run
        for _ in range(3):
            model(torch.randn(1, input_size))

    Type mismatch errors *in* an autocast-enabled region are a bug; if this is what you observe,
    please file an issue.

    ``autocast(enabled=False)`` subregions can be nested in autocast-enabled regions.
    Locally disabling autocast can be useful, for example, if you want to force a subregion
    to run in a particular ``dtype``.  Disabling autocast gives you explicit control over
    the execution type.  In the subregion, inputs from the surrounding region
    should be cast to ``dtype`` before use::

        # Creates some tensors in default dtype (here assumed to be float32)
        a_float32 = torch.rand((8, 8), device="cuda")
        b_float32 = torch.rand((8, 8), device="cuda")
        c_float32 = torch.rand((8, 8), device="cuda")
        d_float32 = torch.rand((8, 8), device="cuda")

        with torch.autocast(device_type="cuda"):
            e_float16 = torch.mm(a_float32, b_float32)
            with torch.autocast(device_type="cuda", enabled=False):
                # Calls e_float16.float() to ensure float32 execution
                # (necessary because e_float16 was created in an autocasted region)
                f_float32 = torch.mm(c_float32, e_float16.float())

            # No manual casts are required when re-entering the autocast-enabled region.
            # torch.mm again runs in float16 and produces float16 output, regardless of input types.
            g_float16 = torch.mm(d_float32, f_float32)

    The autocast state is thread-local.  If you want it enabled in a new thread, the context manager or decorator
    must be invoked in that thread.  This affects :class:`torch.nn.DataParallel` and
    :class:`torch.nn.parallel.DistributedDataParallel` when used with more than one GPU per process
    (see :ref:`Working with Multiple GPUs<amp-multigpu>`).

    Args:
        device_type(str, required):  Device type to use. Possible values are: 'cuda', 'cpu', 'xpu' and 'hpu'.
            The type is the same as the `type` attribute of a :class:`torch.device`.
            Thus, you may obtain the device type of a tensor using `Tensor.device.type`.
        enabled(bool, optional):  Whether autocasting should be enabled in the region.
            Default: ``True``
        dtype(torch_dtype, optional):  Data type for ops run in autocast. It uses the default value
            (``torch.float16`` for CUDA and ``torch.bfloat16`` for CPU), given by
            :func:`~torch.get_autocast_dtype`, if :attr:`dtype` is ``None``.
            Default: ``None``
        cache_enabled(bool, optional):  Whether the weight cache inside autocast should be enabled.
            Default: ``True``
    """
    def __init__(
        self,
        device_type: str,
        dtype: Optional[_dtype] = None,
        enabled: bool = True,
        cache_enabled: Optional[bool] = None,
    ):
        if not isinstance(device_type, str):
            raise ValueError(
                f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
            )
        if dtype is None:
            dtype = torch.get_autocast_dtype(device_type)
        if torch._jit_internal.is_scripting():
            self._enabled = enabled
            self.device = device_type
            self.fast_dtype = dtype
            assert dtype is not None
            return
        self.device = device_type
        if not is_autocast_available(self.device):
            raise RuntimeError(
                f"User specified an unsupported autocast device_type '{self.device}'"
            )
        self.custom_backend_name = torch._C._get_privateuse1_backend_name()
        self.fast_dtype = torch.get_autocast_dtype(self.device)
        if self.device == self.custom_backend_name:
            necessary_funcs = [
                "get_amp_supported_dtype",
            ]
            message = f"Tried to use AMP with the `{self.custom_backend_name}` backend, but the backend has not "
            message += "registered a module or the module miss some necessary funcs. The backend should register "
            message += "a module by `torch._register_device_module`, and the module must have these funcs: \n"
            message += "`get_amp_supported_dtype() -> List[torch.dtype]`. \n"

            assert hasattr(torch, self.custom_backend_name), message
            self.custom_device_mod = getattr(torch, self.custom_backend_name)
            for func in necessary_funcs:
                assert hasattr(self.custom_device_mod, func), (
                    message + f"But the func `{func}` is missing. \n"
                )

        self._cache_enabled = torch.is_autocast_cache_enabled()
        if (
            enabled
            and torch.cuda.amp.common.amp_definitely_not_available()
            and self.device == "cuda"
        ):
            warnings.warn(
                "User provided device_type of 'cuda', but CUDA is not available. Disabling"
            )
            enabled = False
        if dtype is not None:
            self.fast_dtype = dtype
        if cache_enabled is not None:
            self._cache_enabled = cache_enabled

        if self.device == "cpu":
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype and enabled:
                error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
                error_message += "CPU Autocast only supports dtype of "
                error_message += (
                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
                )
                warnings.warn(error_message)
                enabled = False
        elif self.device == "xpu":
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
                error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
                warnings.warn(error_message)
                enabled = False
        elif self.device == "ipu":
            supported_dtypes = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtypes:
                error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
                error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
                warnings.warn(error_message)
                enabled = False
        elif self.device == "hpu":
            supported_dtype = [torch.bfloat16, torch.float16]
            if self.fast_dtype not in supported_dtype:
                error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
                error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
                warnings.warn(error_message)
                enabled = False
        elif self.device == self.custom_backend_name:
            supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
            if self.fast_dtype not in supported_dtype:
                error_message = f"In {self.custom_backend_name} autocast, but the target dtype is not supported. "
                error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
                error_message += (
                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
                )
                warnings.warn(error_message)
                enabled = False
        elif self.device == "cuda":
            if (
                enabled
                and self.fast_dtype == torch.bfloat16
                and not torch.cuda.is_bf16_supported()
            ):
                raise RuntimeError(
                    "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
                )
        elif self.device == "xla":
            supported_dtype = [torch.float16, torch.bfloat16]
            if self.fast_dtype not in supported_dtype:
                error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
                error_message += (
                    "XLA Autocast only supports dtype of torch.bfloat16 currently."
                )
                warnings.warn(error_message)
                enabled = False
        self._enabled = enabled
    def __enter__(self):
        if torch._jit_internal.is_scripting():
            assert self.fast_dtype is not None
            return self

        self.prev_cache_enabled = torch.is_autocast_cache_enabled()
        self.prev = torch.is_autocast_enabled(self.device)
        self.prev_fastdtype = torch.get_autocast_dtype(self.device)
        torch.set_autocast_enabled(self.device, self._enabled)
        torch.set_autocast_dtype(self.device, self.fast_dtype)  # type: ignore[arg-type]
        torch.autocast_increment_nesting()
        torch.set_autocast_cache_enabled(self._cache_enabled)

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
        if torch._jit_internal.is_scripting():
            return

        # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        torch.set_autocast_enabled(self.device, self.prev)
        torch.set_autocast_dtype(self.device, self.prev_fastdtype)
        torch.set_autocast_cache_enabled(self.prev_cache_enabled)
        return False

    def __call__(self, func):
        if torch._jit_internal.is_scripting():
            return func
        return autocast_decorator(self, func)
Postscript
Written in Shanghai at 21:10 on December 31, 2024, with the assistance of the GPT-4o model.