当前位置: 首页 > article >正文

pytorch retain_grad vs requires_grad

requires_grad大家都挺熟悉的,因此穿插在retain_grad的例子里进行捎带讲解就行。下面看一个代码片段:

import torch

# 创建一个标量 tensor,并开启梯度计算
x = torch.tensor(2.0, requires_grad=True)

# 中间计算:y 依赖于 x,是非叶子节点
y = x * 3

# 继续计算,得到 z
z = y * 4

# 反向传播
z.backward()

# 查看梯度
print("x.grad:", x.grad)  
print("y.grad:", y.grad)  

输出结果为:

x.grad: tensor(12.)
y.grad: None
/tmp/ipykernel_219007/1060175670.py:17: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at aten/src/ATen/core/TensorBody.h:489.)
  print("y.grad:", y.grad)

警告的大致意思是:访问了非叶子节点的.grad属性,但非叶子节点的.grad属性并不会在反向传播的过程中被自动保存下来(这是为了节省内存,毕竟我们只需要计算那些手动设置.requires_gradTrue的张量的梯度,并进行梯度更新,对吧?)

因此,我们只需要添加一行代码y.retain_grad(),修改后的代码如下:

import torch

# 创建一个标量 tensor,并开启梯度计算
x = torch.tensor(2.0, requires_grad=True)

# 中间计算:y 依赖于 x,是非叶子节点
y = x * 3
y.retain_grad()

# 继续计算,得到 z
z = y * 4

# 反向传播
z.backward()

# 查看梯度
print("x.grad:", x.grad)  
print("y.grad:", y.grad)  

输出结果为:

x.grad: tensor(12.)
y.grad: tensor(4.)

可以看到,现在非叶子节点y的梯度也在反向传播以后被正确保存了!


http://www.kler.cn/a/578744.html

相关文章:

  • 电路研究9.3.1——合宙Air780EP中的AT开发指南:TCP 使用 SSL 示例
  • 关于VScode终端无法识别外部命令
  • mysql安装(演示为mac安装流程)
  • 使用 Python 批量提取 PDF 书签:一款实用工具的实现
  • Hadoop集群搭建(一)安装jdk
  • Nacos高频面试题10个
  • 深度学习与数据挖掘题库:401-500题精讲
  • 技术领域,有许多优秀的博客和网站
  • 基于PaddleNLP使用DeepSeek-R1搭建智能体
  • 【Linux篇】:Linux常用工具全解析--探索高效的工具宝藏
  • 生活反思公园散步与小雨遇记
  • Opencv之掩码实现答题卡识别及正确率判断
  • 《从零开始构建视频同步字幕播放软件》
  • React:Redux
  • Deeplabv3+改进1:添加CBAM注意力机制|有效涨点
  • 大道至简:道法自然的应用秘诀
  • Python实例:PyMuPDF实现PDF翻译,英文翻译为中文,并按段落创建中文PDF
  • 整理一下arcGis desktop版本软件, 从入门到精通需要学习的知识点
  • 苦瓜书盘官网,免费pdf/mobi电子书下载网站
  • PawSQL for MSSQL:PawSQL 支持 SQL Server 的SQL优化、SQL审核、性能巡检