当前位置: 首页 > article >正文

Torch 自动求导

文章目录

  • Torch 自动求导
    • 一、链式求导法则
      • 1.1 原理
      • 1.2 示例
    • 二、传统求导
    • 三、自动求导
      • 3.1 原理
      • 3.2 代码实现

Torch 自动求导

一、链式求导法则

1.1 原理

标量链式法则:

y = f ( u ) , u = g ( x ) , ∂ y ∂ x = ∂ y ∂ u ∂ u ∂ x y=f(u),u=g(x),\frac{\partial y}{\partial x} =\frac{\partial y}{\partial u}\frac{\partial u}{\partial x} y=f(u),u=g(x),xy=uyxu

扩展到向量:

∂ y ∂ x ⃗ 1 × n = ∂ y ∂ u ∂ u ∂ x ⃗ m × n ∂ y ∂ x ⃗ 1 × n = ∂ y ∂ u ⃗ 1 × m ∂ u ⃗ ∂ x ⃗ m × n ∂ y ⃗ ∂ x ⃗ m × n = ∂ y ⃗ ∂ u ⃗ m × k ∂ u ⃗ ∂ x ⃗ k × n \begin{array}{l} \frac{\partial y}{\partial \vec{x}}_{1\times n} =\frac{\partial y}{\partial u}\frac{\partial u}{\partial \vec{x}} _{m\times n} \\ \frac{\partial y}{\partial \vec{x}}_{1\times n} =\frac{\partial y}{\partial \vec{u}} _{1\times m}\frac{\partial \vec{u}}{\partial \vec{x}} _{m\times n} \\ \frac{\partial \vec{y}}{\partial \vec{x}}_{m\times n} =\frac{\partial \vec{y}}{\partial \vec{u}} _{m\times k}\frac{\partial \vec{u}}{\partial \vec{x}} _{k\times n} \\ \end{array} x y1×n=uyx um×nx y1×n=u y1×mx u m×nx y m×n=u y m×kx u k×n

1.2 示例

如计算:

x ⃗ , w ⃗ ∈ R n , y ∈ R z = ( < x ⃗ , w ⃗ > − y ) 2 \begin{array}{l} \vec{x},\vec{w}\in \mathbb{R}^{n},y\in \mathbb{R} \\ z=(<\vec{x},\vec{w}>-y)^{2} \\ \end{array} x ,w Rn,yRz=(<x ,w >y)2
我们要计算: ∂ z ∂ w ⃗ \frac{\partial z}{\partial \vec{w}} w z

即:先进行换元
a = < x ⃗ , w ⃗ > b = a − y z = b 2 \begin{array}{l} \\ a=<\vec{x},\vec{w}> \\ b=a-y \\ z=b^{2} \end{array} a=<x ,w >b=ayz=b2
再进行链式求导:

∂ z ∂ w ⃗ = ∂ z ∂ b ∂ b ∂ a ∂ a ∂ w ⃗ = ∂ b 2 ∂ b ∂ ( a − y ) ∂ a ∂ < x ⃗ , w ⃗ > ∂ w ⃗ = 2 b ⋅ 1 ⋅ x T = 2 ( < x ⃗ , w ⃗ > − y ) x ⃗ T \begin{array}{l} \frac{\partial z}{\partial \vec{w}} &=\frac{\partial z}{\partial b} \frac{\partial b}{\partial a}\frac{\partial a}{\partial \vec{w}} \\ &=\frac{\partial b^{2}}{\partial b}\frac{\partial (a-y)}{\partial a}\frac{\partial <\vec{x},\vec{w}>}{\partial \vec{w}} \\ &=2b \cdot 1 \cdot x^{T} \\ &=2(<\vec{x},\vec{w}>-y)\vec{x}^{T} \end{array} w z=bzabw a=bb2a(ay)w <x ,w >=2b1xT=2(<x ,w >y)x T

二、传统求导

自动求导计算一个函数在指定值上的导数

符号求导,以及使用数值求导 ∂ f ( x ) ∂ x = lim ⁡ h → 0 f ( x + h ) − f ( x ) h \frac{\partial f(x)}{\partial x} =\lim_{ h \to 0 } \frac{f(x+h)-f(x)}{h} xf(x)=limh0hf(x+h)f(x)

sympy库

from sympy import symbols, diff

x = symbols("x")

f = x**2 + 3*x + 2

df_dx = diff(f, x)
print(df_dx)

print("f在x=1处的导数值为:", df_dx.subs(x, 1))

1739445549_9qotv15a78.png1739445549064.png

scipy 库

from scipy.differentiate  import derivative

# 定义函数
def f(x):
    return x**2 + 3*x + 2

# 计算函数在某点的导数
x = 1
df_dx = derivative(f, x)
print("f(x)在x=1处的导数为:", df_dx["df"])
print(df_dx)

1739445983_ou03by5ock.png1739445982266.png

使用极限的方法计算导数:

# 定义函数
def f(x):
    return x**2 + 3*x + 2

# 计算函数在某点的导数
x = 1
h = 1e-6
df_dx = (f(x + h) - f(x)) / h
print(df_dx)

三、自动求导

3.1 原理

将代码分解成操作子
将计算表示成一个无环图

1739446340_aslpa4dflw.png1739446339416.png

自动求导的两种模式:

  • 链式法则: ∂ y ∂ x = ∂ y ∂ u n ∂ u n ∂ u n − 1 ⋯ ∂ u 2 ∂ u 1 ∂ u 1 ∂ x \frac{\partial y}{\partial x} =\frac{\partial y}{\partial u_{n}}\frac{\partial u_{n}}{\partial u_{n-1}}\cdots\frac{\partial u_{2}}{\partial u_{1}}\frac{\partial u_{1}}{\partial x} xy=unyun1unu1u2xu1
  • 正向积累: ∂ y ∂ x = ∂ y ∂ u n ( ∂ u n ∂ u n − 1 ( ⋯ ( ∂ u 2 ∂ u 1 ∂ u 1 ∂ x ) ) ) \frac{\partial y}{\partial x} =\frac{\partial y}{\partial u_{n}}(\frac{\partial u_{n}}{\partial u_{n-1}}(\cdots(\frac{\partial u_{2}}{\partial u_{1}}\frac{\partial u_{1}}{\partial x}))) xy=uny(un1un((u1u2xu1)))
  • 反向积累: ∂ y ∂ x = ( ( ( ∂ y ∂ u n ∂ u n ∂ u n − 1 ) ⋯   ) ∂ u 2 ∂ u 1 ) ∂ u 1 ∂ x \frac{\partial y}{\partial x} =(((\frac{\partial y}{\partial u_{n}}\frac{\partial u_{n}}{\partial u_{n-1}})\cdots)\frac{\partial u_{2}}{\partial u_{1}})\frac{\partial u_{1}}{\partial x} xy=(((unyun1un))u1u2)xu1

1739446704_opgxitxzl5.png1739446703970.png

3.2 代码实现

假设我们对函数 y = 2 X T X y=2X^{T}X y=2XTX进行关于 x x x的求导:

import torch

x = torch.arange(4.0, requires_grad=True)  # 单独开辟一个空间进行梯度的存储
print("x的值为:", x.tolist())  
y = 2 * x@x  # 进行y的构建
y.backward()  # 进行反向传播求导数  dy = 4 x dx
print("y的梯度为:", x.grad.tolist())  # x.grad 获得梯度
x.grad.zero_()  # 在默认情况下,pytorch会累积梯度,我们需要清除之前的值,可以自己试一下有无这一行代码的效果
k = x@x
k.backward()
x.grad

1739447576_nsxrpzcmaw.png1739447575444.png


http://www.kler.cn/a/547923.html

相关文章:

  • 复杂电磁环境下无人机自主导航增强技术研究报告——地磁匹配与多源数据融合方法,附matlab代码
  • docker 基础命令使用(ubuntu)
  • Python说课内容介绍
  • 微软 Microsoft Windows Office Professional LTSC 2024 专业增强版
  • vue开发时,用localStorage常用方法及存储数组方法。
  • SpringBoot如何配置开发环境(JDK、Maven、IDEA等)
  • React - 高阶函数-函数柯里化
  • EasyX学习笔记1:线条
  • 2025年如何选择合适的微服务工具
  • Hive的动态分区的原理
  • 【C++干货分享】集合 位运算
  • C++ references
  • SQLMesh 系列教程4- 详解模型特点及模型类型
  • TongETLV3.0安装指引(by lqw)
  • 1-8 gitee码云的注册与使用
  • OpenAI发布新模型及会员订阅计划:o3-mini、GPT-4.5与GPT-5的全新体验
  • 如何学习Elasticsearch(ES):从入门到精通的完整指南
  • 【读点论文】Rewrite the Stars将svm的核技巧映射到高维空间,从数理逻辑中丰富特征维度维度
  • 【MySQL】第五弹---数据类型全解析:从基础到高级应用
  • 标贝科技参编国内首个AIGC大模型功能测试标准