当前位置: 首页 > article >正文

昇思MindSpore进阶教程--高级自动微分

大家好,我是刘明,明志科技创始人,华为昇思MindSpore布道师。
技术上主攻前端开发、鸿蒙开发和AI算法研究。
努力为大家带来持续的技术分享,如果你也喜欢我的文章,就点个关注吧

正文开始

mindspore.ops模块提供的grad和value_and_grad接口可以生成网络模型的梯度。grad计算网络梯度,value_and_grad同时计算网络的正向输出和梯度。本文主要介绍如何使用grad接口的主要功能,包括一阶、二阶求导,单独对输入或网络权重求导,返回辅助变量,以及如何停止计算梯度。

一阶求导

计算一阶导数方法:mindspore.grad,其中参数使用方式为:

  • fn:待求导的函数或网络。
  • grad_position:指定求导输入位置的索引。若为int类型,表示对单个输入求导;若为tuple类型,表示对tuple内索引的位置求导,其中索引从0开始;若是None,表示不对输入求导,这种场景下,weights非None。默认值:0。
  • weights:训练网络中需要返回梯度的网络变量。一般可通过weights = net.trainable_params()获取。默认值:None。
  • has_aux:是否返回辅助参数的标志。若为True,fn输出数量必须超过一个,其中只有fn第一个输出参与求导,其他输出值将直接返回。默认值:False。
    下面先构建自定义网络模型Net,再对其进行一阶求导,通过这样一个例子对grad接口的使用方式做简单介绍,即公式:
    f ( x , y ) = x ∗ x ∗ y ∗ z (1) f(x, y)=x * x * y * z \tag{1} f(x,y)=xxyz(1)
    首先定义网络模型Net、输入x和输入y:
import numpy as np
from mindspore import ops, Tensor
import mindspore.nn as nn
import mindspore as ms

# 定义输入x和y
x = Tensor([3.0], dtype=ms.float32)
y = Tensor([5.0], dtype=ms.float32)


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.z = ms.Parameter(ms.Tensor(np.array([1.0], np.float32)), name='z')

    def construct(self, x, y):
        out = x * x * y * self.z
        return out

对输入求一阶导

对输入x, y进行求导,需要将grad_position设置成(0, 1):
∂ f ∂ x = 2 ∗ x ∗ y ∗ z (2) \frac{\partial f}{\partial x}=2 * x * y * z \tag{2} xf=2xyz(2)
∂ f ∂ y = x ∗ x ∗ z (3) \frac{\partial f}{\partial y}=x * x * z \tag{3} yf=xxz(3)

net = Net()
grad_fn = ms.grad(net, grad_position=(0, 1))
gradients = grad_fn(x, y)
print(gradients)

对权重进行求导

对权重z进行求导,这里不需要对输入求导,将grad_position设置成None:
∂ f ∂ z = x ∗ x ∗ y (4) \frac{\partial f}{\partial z}=x * x * y \tag{4} zf=xxy(4)

params = ms.ParameterTuple(net.trainable_params())

output = ms.grad(net, grad_position=None, weights=params)(x, y)
print(output)

返回辅助变量

同时对输入和权重求导,其中只有第一个输出参与求导,示例代码如下:

net = nn.Dense(10, 1)
loss_fn = nn.MSELoss()


def forward(inputs, labels):
    logits = net(inputs)
    loss = loss_fn(logits, labels)
    return loss, logits


inputs = Tensor(np.random.randn(16, 10).astype(np.float32))
labels = Tensor(np.random.randn(16, 1).astype(np.float32))
weights = net.trainable_params()

# Aux value does not contribute to the gradient.
grad_fn = ms.grad(forward, grad_position=0, weights=None, has_aux=True)
inputs_gradient, (aux_logits,) = grad_fn(inputs, labels)
print(len(inputs_gradient), aux_logits.shape)

停止计算梯度

可以使用stop_gradient来停止计算指定算子的梯度,从而消除该算子对梯度的影响。

在上面一阶求导使用的矩阵相乘网络模型的基础上,再增加一个算子out2并禁止计算其梯度,得到自定义网络Net2,然后看一下对输入的求导结果情况。

示例代码如下:

class Net(nn.Cell):

    def __init__(self):
        super(Net, self).__init__()

    def construct(self, x, y):
        out1 = x * y
        out2 = x * y
        out2 = ops.stop_gradient(out2)  # 停止计算out2算子的梯度
        out = out1 + out2
        return out


net = Net()
grad_fn = ms.grad(net)
output = grad_fn(x, y)
print(output)

从上面的打印可以看出,由于对out2设置了stop_gradient,所以out2没有对梯度计算有任何的贡献,其输出结果与未加out2算子时一致。

下面删除out2 = stop_gradient(out2),再来看一下输出结果。示例代码为:

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()

    def construct(self, x, y):
        out1 = x * y
        out2 = x * y
        # out2 = stop_gradient(out2)
        out = out1 + out2
        return out


net = Net()
grad_fn = ms.grad(net)
output = grad_fn(x, y)
print(output)

打印结果可以看出,把out2算子的梯度也计算进去之后,由于out2和out1算子完全相同,因此它们产生的梯度也完全相同,所以可以看到,结果中每一项的值都变为了原来的两倍(存在精度误差)。

高阶求导

高阶微分在AI支持科学计算、二阶优化等领域均有应用。如分子动力学模拟中,利用神经网络训练势能时,损失函数中需计算神经网络输出对输入的导数,则反向传播便存在损失函数对输入、权重的二阶交叉导数。

此外,AI求解微分方程(如PINNs方法)还会存在输出对输入的二阶导数。又如二阶优化中,为了能够让神经网络快速收敛,牛顿法等需计算损失函数对权重的二阶导数。


http://www.kler.cn/a/325404.html

相关文章:

  • mysql中mvcc如何处理纯读事务的?
  • CPU执行指令的过程
  • 用 Python 从零开始创建神经网络(五):损失函数(Loss Functions)计算网络误差
  • Android - Pixel 6a 手机OS 由 Android 15 降级到 Android 14 操作记录
  • Cuda和Pytorch的兼容性
  • 【C++】构造函数
  • 基于springboot+小程序的儿童预防接种预约管理系统(疫苗1)(源码+sql脚本+视频导入教程+文档)
  • 依赖倒置原则(学习笔记)
  • PostgreSQL的表碎片
  • 学习Java (五)
  • Go Sonyflake学习与使用
  • 新能源汽车充电桩怎么选?
  • Linux基础(二):磁盘分区
  • js替换css主题变量并切换iconfont文件
  • uniapp中h5环境添加console.log输出
  • 2024年7月大众点评沈阳美食店铺基础信息
  • 数据结构和算法之树形结构(4)
  • springframework Ordered接口学习
  • BOE(京东方)携故宫博物院举办2024“照亮成长路”公益项目落地仪式以创新科技赋能教育可持续发展
  • 计算机网络--TCP、UDP抓包分析实验
  • 2024年配置YOLOX运行环境+windows+pycharm24.0.1+GPU
  • [C语言]--自定义类型: 结构体
  • 【C/C++】错题记录(一)
  • pdf页面尺寸裁减
  • uni-app+vue3开发微信小程序使用本地图片渲染不出来报错[渲染层网络层错误]Failed to load local image resource
  • 黑马智数Day2