当前位置: 首页 > article >正文

深度学习 自动求梯度

 代码示例:

import torch

# 创建一个标量张量 x,并启用梯度计算
x = torch.tensor(3.0, requires_grad=True)

# 计算 y = x^2
y = torch.pow(x, 2)

# 判断 x 和 y 是否需要梯度计算
print(x.requires_grad)  # 输出 x 的 requires_grad 属性
print(y.requires_grad)  # 输出 y 的 requires_grad 属性

# 反向传播,计算 y 对 x 的导数
y.backward()

# 查看 x 的梯度
print(x.grad)  # 输出 x 的梯度

 代码详解:

代码一:
x = torch.tensor(3.0, requires_grad=True)

在这行代码中,创建了一个 PyTorch 张量 x。

解释

  1. torch.tensor(3.0):

    • 这部分创建了一个张量,其值为 3.0。这是一个标量张量,数据类型为浮点数(float)。
  2. requires_grad=True:

    • 这个参数指定 PyTorch 需要跟踪该张量的所有操作,以便在后续的反向传播过程中计算梯度。换句话说,设置 requires_grad=True 使得这个张量在执行任何操作后,能够计算其梯度。
 代码二:
y = torch.pow(x, 2)

解释

  1. torch.pow(x, 2):

    • 这是 PyTorch 中的一个函数,用于计算 x 的幂。在这里,x 被提高到 2 的幂,即计算 (x^2)。
    • 由于之前我们已经定义了 x = torch.tensor(3.0, requires_grad=True),所以 torch.pow(x, 2) 实际上会计算 (3.0^2),得到的结果是 9.0
  2. y = ...:

    • 将计算得到的结果 9.0 存储在张量 y 中。由于 x 的 requires_grad 属性为 True,PyTorch 会自动设置 y 的 requires_grad 属性为 True,使得 y 也可以用于梯度计算。
代码三: 
y.backward()

 y.backward() 是 PyTorch 中用于计算梯度的重要方法,它在反向传播过程中发挥着关键作用。

解释

  1. y.backward():
    • 这行代码触发反向传播,以计算损失函数 y 相对于输入张量 x 的梯度。
    • 在调用 backward() 方法之前,计算图已经构建完毕,y 是通过某些操作(例如 torch.pow(x, 2))生成的张量。
    • 当 backward() 被调用时,PyTorch 会从 y 开始,沿着计算图向后传播,计算所有需要计算梯度的张量的梯度。

自动微分

  • 在 PyTorch 中,backward() 使用自动微分(automatic differentiation)来计算梯度。这意味着系统会自动根据张量间的运算关系,利用链式法则来计算每个张量的梯度。

计算过程

  • 在之前的例子中,我们定义了 ( y = x^2 ) 并且 ( x = 3.0 )。
  • 在反向传播过程中,PyTorch 计算 ( \frac{dy}{dx} ) 的值:
    • 根据导数公式,( \frac{dy}{dx} = 2x )。
    • 对于 ( x = 3.0 ),因此 ( \frac{dy}{dx} = 2 \times 3.0 = 6.0 )。
  • 这个值会被存储在 x.grad 中,方便后续使用。 

print(x.grad) 语句用于输出张量 x 的梯度值。 

为了更好的理解什么是梯度,看下面示例代码: 

示例二:

import torch
x=torch.tensor(3.0,requires_grad=True)
y=torch.tensor(4.0,requires_grad=False)
z=torch.pow(x,2)+torch.pow(y,2)
print("x.requires_grad:",x.requires_grad)
print("y.requires_grad:",y.requires_grad)
print("z.requires_grad:",z.requires_grad)
z.backward()
print("x.grad:",x.grad)
print("y.grad:",y.grad)
print("z.grad:",z.grad)
print(z)

输出:

x.requires_grad: True
y.requires_grad: False
z.requires_grad: True
x.grad: tensor(6.)
y.grad: None
z.grad: None
tensor(25., grad_fn=<AddBackward0>)

输出解释

  1. 可求导性检查:

    • x.requires_grad: True 表示 x 是一个可求导的张量。
    • y.requires_grad: False 表示 y 不是可求导的张量。
    • z.requires_grad: True 表示 z 是可求导的,因为它是由可求导的张量 x 计算得出的。
  2. 梯度计算:

    • 调用 z.backward() 时,计算了 z 关于 x 的梯度。
    • y.grad 输出 None,因为 y 不可求导。
  3. 关于 z 的梯度:

    • z.grad 输出 None,这是因为 z 不是叶子节点。只有叶子节点的 grad 属性会被自动设置。
我们在运行此段代码时会遇到一个警告:

 大致意思是:

你在访问 z.grad 时遇到的警告提示你正在访问一个非叶子张量的梯度属性。此警告说明 z 不是一个叶子张量,因此其 .grad 属性在执行 backward() 时不会被填充。

叶子张量与非叶子张量

在 PyTorch 中,叶子张量(leaf tensors)是指那些没有任何历史计算的张量,通常是由用户直接创建的张量(例如通过 torch.tensor() 创建)。而 非叶子张量 是由其他张量经过操作计算得出的张量(例如加法、乘法等操作生成的结果)。

为了使非叶子张量的 .grad 属性被填充,你可以在计算图中使用 .retain_grad() 方法。这将允许你在调用 backward() 后访问非叶子张量的梯度。

请看修改后的示例三:

 示例三:

import torch

# 创建一个可求导的张量 x 和一个不可求导的张量 y
x = torch.tensor(3.0, requires_grad=True)  # x 可求导
y = torch.tensor(4.0, requires_grad=False) # y 不可求导

# 定义函数 z = f(x, y) = x^2 + y^2
z = torch.pow(x, 2) + torch.pow(y, 2)

# 让 z 保留梯度
z.retain_grad()

# 打印每个张量的 requires_grad 属性
print("x.requires_grad:", x.requires_grad)  # 输出: True
print("y.requires_grad:", y.requires_grad)  # 输出: False
print("z.requires_grad:", z.requires_grad)  # 输出: True

# 反向传播以计算梯度
z.backward()

# 打印 x 和 y 的梯度
print("x.grad:", x.grad)  # 输出: tensor(6.)
print("y.grad:", y.grad)  # 输出: None
print("z.grad:", z.grad)  # 输出: tensor(1.)

为什么z的梯度为1或者z的导为1?

z 对自身的导数为1

举个例子:

y=x**2;

y对于x的导为2*x;

y对于自身的导为1。


http://www.kler.cn/news/359525.html

相关文章:

  • kubernetes(k8s)面试之2024
  • Spring Boot:中小型医院网站开发新趋势
  • react18中如何实现同步的setState来实现所见即所得的效果
  • 【C语言】文件操作(2)(文件缓冲区和随机读取函数)
  • 当物理学奖遇上机器学习:创新融合的里程碑
  • Unity修改鼠标指针大小
  • nginx中的HTTP 负载均衡
  • 【python+Redis】hash修改
  • 真空探针台选型需知
  • Spring Boot:如何实现JAR包的直接运行
  • 首个统一生成和判别任务的条件生成模型框架BiGR:专注于增强生成和表示能力,可执行视觉生成、辨别、编辑等任务
  • Android Studio Ladybug指定ndk版本
  • python excel如何转成json,并且如何解决excel转成json时中文汉字乱码的问题
  • Mac 安装 Telnet 工具
  • Maven - Assembly实战
  • ubuntu 虚拟机将linux文件夹映射为windows网络位置
  • Openlayers高级交互(2/20):清除所有图层的有效方法
  • 01 springboot-整合日志(logback-config.xml)
  • 【H2O2|全栈】JS入门知识(五)
  • 前端报错:‘vue-cli-service‘ 不是内部或外部命令,也不是可运行的程序(node_modules下载不下来)