当前位置: 首页 > article >正文

基于PyTorch的深度学习3——非标量反向传播

以下内容由大语言模型生成:

张量,边代表从输入到输出的操作。当我们对一个标量调用backward()函数时,PyTorch能够直接计算该标量关于参与运算的所有张量的梯度。这是因为导数的概念最开始是针对标量值定义的。

为什么需要传入gradient参数?

当你有一个非标量(即形状不为1的张量)并希望对其求导时,实际上你是在尝试计算这个张量关于某些变量的雅可比矩阵(Jacobian Matrix)。然而,PyTorch并不直接支持这种操作,因为它复杂化了自动微分的流程。为了简化这个问题,PyTorch要求在这种情况下提供一个额外的gradient参数,这个参数同样是一个张量,并且其形状必须与调用backward()的那个张量相同。这个gradient实际上扮演了一个“权重”的角色,用于将多维导数的问题转化为标量导数的问题。

如何理解这个过程?

考虑你的例子:假设你有一个损失向量loss=(y_1, y_2, ..., y_m),你想要根据它来更新一些模型参数。由于loss不是一个标量,你不能直接对它调用backward()。这时,你可以引入一个向量v=(v_1, v_2, ..., v_m),然后将lossv进行点乘得到一个新的标量loss*v^T(这里v^T表示v的转置,虽然在实际代码中我们不会这样写,这只是为了表达数学概念)。这个新生成的标量可以被用来调用backward()方法,从而触发梯度的计算。

具体来说,这样做实际上是计算了loss的雅可比矩阵与v的乘积。换句话说,原本你需要计算的是雅可比矩阵,但现在通过点乘转换后,你只需计算一个标量关于所需变量的梯度。这使得PyTorch的自动微分机制能够处理这种情况,而不需要直接支持张量对张量的求导。

1)定义叶子节点及计算节点。

import torch

#定义叶节点张量x,形状为1x2
x=torch.tensor([[2.3]],dtype=torch.float,requires_grad=True)

#初始化Jacobian矩阵
J=torch.zeros(2,2)

#初始化目标张量,形状为1x2
y=torch.zeros(1,2)

#定义y与x之间的映射关系:
#y1=x1**2+3*x2,y2=x2**2+2*x1
y[0,0]=x[0,0]**2+3*x[0,1]
y[0,1]=x[0,1]**2+2*x[0,0]

2)手工计算y对x的梯度

y对x的梯度是一个雅可比矩阵,可以通过手动计算值

#生成y1对x的梯度
y.backward(torch.Tensor([[1, 0]]),retain_graph=True)
##gradient的作用:传入的gradient张量扮演了一个权重的角色,它决定了每个元素在最终梯度计算中的重要
##本质上,这是将雅可比矩阵乘以这个gradient向量,从而将多维导数的问题简化为一维标量导数的问题。

J[0]=x.grad

#梯度是累加的,故需要对x的梯度清零
x.grad = torch.zeros_like(x.grad)

#生成y2对x的梯度
y.backward(torch.Tensor([[0, 1]]))

J[1]=x.grad
#显示jacobian矩阵的值
print(J)


http://www.kler.cn/a/578857.html

相关文章:

  • Linux下的shell指令(二)
  • 计算机三级网络技术知识点汇总【7】
  • 打破界限!家电行业3D数字化营销,线上线下无缝对接
  • 【计算机网络】确认家庭网络是千兆/百兆带宽并排查问题
  • 深度解析:视频软编码与硬编码的优劣对比
  • 第十五届蓝桥杯省赛电子类单片机学习过程记录(客观题)
  • 旋转位置编码 (2)
  • 利用PHP爬虫根据关键词获取17网(17zwd)商品列表:实战指南
  • ABeam 德硕 | 中国汽车市场(1)——正在推进电动化的中国汽车市场
  • Android View设置圆角方式大全
  • LDR6500:革新手机OTG充电体验的关键芯片
  • MySQL快速使用Windows压缩包创建测试数据库
  • 基于javaweb的SpringBoot+MyBatis自习室座位管理系统设计和实现(源码+文档+部署讲解)
  • C#编译自动增加文件的版本号
  • 在Vue中 使用 Web Worker
  • 基于大数据的电影情感分析推荐系统
  • JVM_八股场景题
  • 从技术角度看大语言模型进化技术路线与落地应用详解:未来的最佳实践方向是什么?
  • mybaties中使用的设计模式
  • 介绍如何基于现有的可运行STGCN(Spatial-Temporal Graph Convolutional Network)模型代码进行交通流预测的改动