with torch.no_grad()
是 PyTorch 中用于临时禁用梯度计算的上下文管理器,其核心作用是优化推理或评估阶段的性能,具体特点如下:
1. 基本功能
• 禁用梯度追踪:在代码块内,所有操作不会被记录到计算图中
,因此不会产生梯度信息
,节省内存并加速计算。
• 恢复梯度状态:退出
上下文后,张量的 requires_grad
属性自动恢复
到进入前的状态,不影响后续梯度计算。
2. 典型应用场景
• 模型评估/推理:在验证或测试时,仅需前向传播,无需计算梯度
,可提升效率并减少显存占用。例如:
with torch.no_grad(