
PyTorch Automatic Differentiation: Points to Note

backward()

backward() is the key function through which the PyTorch framework implements automatic differentiation. It is usually invoked as loss.backward(), where loss is generally a scalar tensor.

import torch
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device)
data1 = torch.randint(0,10,(1,3),dtype=torch.float,requires_grad=True,device=device)
print(data1)
y = data1.pow(2)+100
print('y =',y)
loss = y.mean()
print('loss =',loss)
loss.backward()
print("data1's grad =",data1.grad)
# mps
# tensor([[1., 9., 0.]], device='mps:0', requires_grad=True)
# y = tensor([[101., 181., 100.]], device='mps:0', grad_fn=<AddBackward0>)
# loss = tensor(127.3333, device='mps:0', grad_fn=<MeanBackward0>)
# data1's grad = tensor([[0.6667, 6.0000, 0.0000]], device='mps:0')


This run simulates a small computation: mean() stands in for the loss calculation, and its purpose is to reduce the vector tensor to a scalar tensor.

After back-propagating through loss, the leaf node data1 has its grad attribute populated, namely 2*data1/3 (the factor 1/3 comes from mean() averaging over the 3 elements).
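
To double-check the 2*data1/3 claim, here is a minimal sketch on CPU, reusing the values printed in the run above (the names expected and the use of numel() are mine); it compares autograd's result against the hand-derived gradient:

import torch

data1 = torch.tensor([[1., 9., 0.]], requires_grad=True)
loss = (data1.pow(2) + 100).mean()
loss.backward()
# d(loss)/d(data1) = 2 * data1 / n, where n is the number of averaged elements
expected = 2 * data1.detach() / data1.numel()
print(torch.allclose(data1.grad, expected))  # True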

A device-transfer pitfall

import torch
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device)
data1 = torch.randint(0,10,(1,3),dtype=torch.float,requires_grad=True).to(device)
print(data1)
y = data1.pow(2)+100
print('y =',y)
loss = y.mean()
print('loss =',loss)
loss.backward()
print("data1's grad =",data1.grad)
# mps
# tensor([[0., 9., 0.]], device='mps:0', grad_fn=<ToCopyBackward0>)
# y = tensor([[100., 181., 100.]], device='mps:0', grad_fn=<AddBackward0>)
# loss = tensor(127., device='mps:0', grad_fn=<MeanBackward0>)
# /Users/jinhouji/PycharmProjects/pythonProject/lesson01.py:13: UserWarning:
# 
# The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/core/TensorBody.h:494.)
# 
# data1's grad = None


Here the tensor is moved with to(device) after it has been created, and the printed gradient comes out as None. The reason is that to(device) is itself recorded as an operation: it returns a new tensor with grad_fn=&lt;ToCopyBackward0&gt; (visible in the output above), so data1 is no longer a leaf node and backward() will not populate its .grad. A leaf node can be rebuilt with detach() + requires_grad_(True).
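
To see the problem directly, inspect is_leaf and grad_fn right after the transfer. A minimal sketch (it uses a dtype cast so that a copy is made even on a CPU-only machine; to(device) behaves the same way whenever it actually moves data):

import torch

t = torch.ones(3, requires_grad=True)  # a leaf: created directly by the user
moved = t.to(torch.float64)            # to() is differentiable, so the copy is a non-leaf
print(t.is_leaf, t.grad_fn)            # True None
print(moved.is_leaf, moved.grad_fn)    # False <ToCopyBackward0 object at ...>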

detach()

detach() creates a new tensor that shares memory with the original but is cut out of the computation graph, so no gradients are tracked through it.
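
A minimal sketch of the memory sharing (the tensor names are mine): an in-place write through the detached tensor shows up in the original, while the detached tensor itself carries no graph history.

import torch

a = torch.ones(3, requires_grad=True)
b = a.detach()                     # shares storage with a, but is cut from the graph
print(b.requires_grad, b.grad_fn)  # False None
b[0] = 5.0                         # an in-place write through b ...
print(a)                           # ... is visible in a: tensor([5., 1., 1.], requires_grad=True)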

clone()

clone() copies a new tensor that has the same values as the original and stays attached to the same computation graph, so gradients flow back through the copy to the original.
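
In contrast to detach(), gradients do flow back through a clone(). A minimal sketch (names are mine):

import torch

a = torch.ones(3, requires_grad=True)
c = a.clone()        # new storage, but still attached to a's graph
print(c.grad_fn)     # <CloneBackward0 object at ...>
c.sum().backward()
print(a.grad)        # tensor([1., 1., 1.]): the gradient reaches the original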

requires_grad_()

requires_grad_() switches gradient tracking on for a tensor, in place (a short demo follows the is_leaf section below).

is_leaf

is_leaf is a boolean attribute that tells whether a tensor is a leaf node. Note that it is an attribute rather than a method: the code below reads data1.is_leaf, without parentheses.
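
A minimal sketch of requires_grad_() and is_leaf together (the tensor names are mine):

import torch

t = torch.zeros(3)                 # gradient tracking is off by default
print(t.requires_grad, t.is_leaf)  # False True
t.requires_grad_(True)             # switch tracking on, in place
print(t.requires_grad, t.is_leaf)  # True True
y = t * 2
print(y.is_leaf)                   # False: y is the result of an operation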

Putting these together, we can try to rebuild a leaf node:

import torch
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device)
data1 = torch.randint(0,10,(1,3),dtype=torch.float,requires_grad=True).to(device).detach().requires_grad_(True)
print(data1)
print(data1.is_leaf)
y = data1.pow(2)+100
print('y =',y)
loss = y.mean()
print('loss =',loss)
loss.backward()
print("data1's grad =",data1.grad)
# mps
# tensor([[7., 8., 8.]], device='mps:0', requires_grad=True)
# True
# y = tensor([[149., 164., 164.]], device='mps:0', grad_fn=<AddBackward0>)
# loss = tensor(159., device='mps:0', grad_fn=<MeanBackward0>)
# data1's grad = tensor([[4.6667, 5.3333, 5.3333]], device='mps:0')


The run above uses detach() and requires_grad_(); this chain is the flow to remember for turning a non-leaf tensor back into a leaf node.

Ways to pause gradient computation

Gradient computation is usually paused during model inference and evaluation. Three ways to stop it are listed below.

with torch.no_grad():
import torch
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device)
data1 = torch.randint(0,10,(1,3),dtype=torch.float,requires_grad=True).to(device).clone()
print(data1)
print(data1.is_leaf)
with torch.no_grad():
    y = data1.pow(2)+100
print(y.requires_grad)
# mps
# tensor([[0., 5., 5.]], device='mps:0', grad_fn=<CloneBackward0>)
# False
# False


Any statement inside the with block that creates a tensor will do so without gradient tracking.
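
Note that no_grad() only disables tracking for the operations run inside the block; if a small region inside it does need gradients, tracking can be re-enabled locally with torch.enable_grad(). A minimal sketch (names are mine):

import torch

x = torch.ones(3, requires_grad=True)
with torch.no_grad():
    a = x * 2                 # created with tracking disabled
    with torch.enable_grad():
        b = x * 2             # tracking re-enabled for this inner block
print(a.requires_grad, b.requires_grad)  # False True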

@torch.no_grad()

Tensors created inside a function decorated with @torch.no_grad() are likewise built without gradient tracking.

import torch
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device)
data1 = torch.randint(0,10,(1,3),dtype=torch.float,requires_grad=True).to(device).clone()
print(data1)
print(data1.is_leaf)
@torch.no_grad()
def func11(data1):
    return data1.pow(2)+100
print(func11(data1).requires_grad)
# mps
# tensor([[3., 4., 4.]], device='mps:0', grad_fn=<CloneBackward0>)
# False
# False


torch.set_grad_enabled()

Calling torch.set_grad_enabled() with False turns gradient computation off globally; to turn it back on, call the function again with True. It also works as a context manager, as sketched after the example.

import torch

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(device)
data1 = torch.randint(0,10,(1,3),dtype=torch.float,requires_grad=True).to(device).clone()
print(data1)
print(data1.is_leaf)
torch.set_grad_enabled(False)
y=data1.pow(2)+100
print(y.requires_grad)
torch.set_grad_enabled(True)
y=data1.pow(2)+100
print(y.requires_grad)
# mps
# tensor([[7., 8., 6.]], device='mps:0', grad_fn=<CloneBackward0>)
# False
# False
# True
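
set_grad_enabled() also works as a context manager, which restores the previous mode automatically instead of relying on a second call with True. A minimal sketch (names are mine):

import torch

x = torch.ones(3, requires_grad=True)
with torch.set_grad_enabled(False):  # used as a context manager
    y = x * 2
print(y.requires_grad)               # False
print((x * 2).requires_grad)         # True: tracking is restored after the block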


