当前位置: 首页 > article >正文

【深度学习系统】Lecture 4 - Automatic Differentiation

1.General introduction to different differentiation methods

在深入介绍自动微分方法之前,我们先来认识一件事情:为什么我们要提出不同的微分方法呢?

这就不得不提到微分在机器学习中所发挥作用的板块——计算当前需要优化的损失函数 l ( h θ ( x ) , y ) l(h_{\theta}(x),y) l(hθ(x),y)关于参数 θ \theta θ的梯度。计算出的梯度继而用于参数 θ \theta θ的更新,如下图所示:
在这里插入图片描述
计算梯度信息会是一个非常频繁的过程,这就牵扯到一个关键的问题——计算的效率。我们来看传统的 数值微分方法(Numerical differentiation) 的弊端:

  • 用定义计算偏导数(保留二阶误差): ∂ f ( θ ) ∂ θ i = lim ⁡ ϵ → 0 f ( θ + ϵ e i ) − f ( θ ) ϵ \frac{\partial f(\theta)}{\partial \theta_i} = \lim\limits_{\epsilon\to 0}\frac{f(\theta+\epsilon e_i)-f(\theta)}{\epsilon} θif(θ)=ϵ0limϵf(θ+ϵei)f(θ)
  • 用泰勒展开消除二阶误差(更精确): ∂ f ( θ ) ∂ θ i = f ( θ + ϵ e i ) − f ( θ − ϵ e i ) 2 ϵ + o ( ϵ 2 ) \frac{\partial f(\theta)}{\partial \theta_i} = \frac{f(\theta+\epsilon e_i)-f(\theta-\epsilon e_i)}{2\epsilon}+o(\epsilon^2) θif(θ)=2ϵf(θ+ϵei)f(θϵei)+o(ϵ2)

可以看到从定义出发,虽然它能够计算任意函数的导数,但是计算机的精度有限,依然存在误差,而且计算效率并不高,因为要对每一个参数都做一遍这样的计算,代价是十分昂贵的。但好处是它非常的简单直观,从而比较适合用来检查我们的自动微分框架的梯度计算是否被正确设计,比如作类似如下的单元测试:
在这里插入图片描述
另一种方法是 符号微分(Symbolic differentiation),如下图所示:在这里插入图片描述
其实也就是从公式法的角度计算函数微分,这样就保证了梯度计算的精确性。但同样有一个致命的缺陷:可能存在大量重复计算。举个例子:在这里插入图片描述
这个例子中 θ \theta θ是一个n维的向量,计算梯度时每一个维度实际上都进行了n-2次的重复计算,这些计算并不能被简单地重用。

自动微分方法(Automatic Differentiation) 是对符号微分的一种创新性改进:引入了计算图(Computational graph),以节点(node)的方式保存中间步骤从而避免了重复计算
在这里插入图片描述
下面我们来介绍第一代自动微分方法:Forward mode automatic differentiation (AD)

2.Forward mode automatic differentiation (AD)

在这里插入图片描述
我们需要盯住目标梯度: ∂ y ∂ x i \frac{\partial y}{\partial x_i} xiy,其中 x i x_i xi是我们关心的参数,我们通过符号微分的公式法注册中间步骤的简单微分算子。然后在前向传播的过程中能够同时计算出相应节点关于某一参数的微分。这样做乍一看很合理,但有一个毛病:一次AD trace只能计算出关于一个参数的梯度信息,如果有n个参数便要这样重复n次。

为了解决上述缺陷,我们将介绍第二代自动微分方法:Reverse mode automatic differentiation

3.Reverse mode automatic differentiation

在这里插入图片描述

这次我们的目标不再盯着 ∂ y ∂ x i \frac{\partial y}{\partial x_i} xiy,而是盯着 ∂ y ∂ v i \frac{\partial y}{\partial v_i} viy。从后往前推导,直至最后一步推导出 y y y关于所有 x i x_i xi的偏导数,反映在图中为: v 2 ‾ = ∂ y ∂ x 2 \overline{ v_2}=\frac{\partial y}{\partial x_2} v2=x2y v 1 ‾ = ∂ y ∂ x 1 \overline{ v_1}=\frac{\partial y}{\partial x_1} v1=x1y。这样我们整个反向传播只需遍历一次,即可获得所有我们想要的信息,这非常好。

这里我们需要注意这样一件事情:Derivation for the multiple pathway case

在这里插入图片描述
我们根据网络拓扑关系,做这样的记号: v i → j ‾ = v j ‾ ∂ v j ∂ v i \overline{v_{i\to j}}=\overline{v_j}\frac{\partial v_j}{\partial v_i} vij=vjvivj,于是就有:
v 1 ‾ = v 1 → 2 ‾ + v 1 → 3 ‾ \overline{v_1}=\overline{v_{1\to 2}}+\overline{v_{1\to 3}} v1=v12+v13从而可以总结出如下规律:
v i ‾ = ∑ j ∈ next ( i ) v i → j ‾ \overline{v_i}=\sum_{j\in \text{next}(i)}\overline {v_{i\to j}} vi=jnext(i)vij

算法实现逻辑伪代码:
在这里插入图片描述
完整代码实现:

import numpy as np


# 节点类
class Node:
    def __init__(self, value=None, grad=0, name=None):
        self.value = value
        self.grad = grad
        self.inputs = []
        self.outputs = []
        self.op = None
        self.op_prime = None  # 导数函数
        self.name = name  # 添加节点名字属性
    
    def __repr__(self):
        return f"Node(name={self.name}, value={self.value}, grad={self.grad}, op={self.op}, op_prime={self.op_prime})"


# 操作注册器
class OperationRegistry:
    def __init__(self):
        self.operations = {}

    def register_operation(self, name, operation):
        self.operations[name] = operation

    def get_operation(self, name):
        return self.operations.get(name)


# 拓扑排序(DFS)
def topological_sort(nodes):
    visited = set()
    sorted_nodes = []

    def visit(node):
        if node not in visited:
            visited.add(node)
            for input_node in node.inputs:
                visit(input_node)
            sorted_nodes.append(node)

    for node in nodes:
        visit(node)
    return sorted_nodes


# 逆拓扑排序
def reverse_topological_sort(nodes):
    return list(reversed(topological_sort(nodes)))


# 正向传播
def forward_pass(output_nodes):
    sorted_nodes = topological_sort(output_nodes)
    print("forward_pass:", sorted_nodes)
    for node in sorted_nodes:
        if not node.inputs:
            continue
        inputs = [input_node.value for input_node in node.inputs]
        operation = node.op
        node.value = operation(*inputs)
    return output_nodes[-1].value


def reverse_pass(output_nodes):
    sorted_nodes = reverse_topological_sort(output_nodes)
    print("reverse_pass:", sorted_nodes)
    output_nodes[0].grad = 1 # y的梯度为1
    for node in sorted_nodes[1:]:
        print(f"Node {node.name}: Initial grad = {node.grad}")
        node.grad = 0
        for output_node in node.outputs:
            prime_func = registry.get_operation(output_node.op.__name__ + "_prime")
            if prime_func.__code__.co_argcount == 1:
                print(f"Node {node.name}: Applying {output_node.op.__name__}_prime with value {node.value}")
                node.grad += output_node.grad * prime_func(node.value)
            elif prime_func.__code__.co_argcount >= 2:
                input_values = [input_node.value for input_node in output_node.inputs]
                derivative_result = prime_func(*input_values)
                if isinstance(derivative_result, tuple):
                    relevant_derivative = derivative_result[output_node.inputs.index(node)]
                    print(f"Node {node.name}: Applying {output_node.op.__name__}_prime with values {input_values}, relevant derivative = {relevant_derivative}")
                    node.grad += output_node.grad * relevant_derivative
                else:
                    print(f"Node {node.name}: Applying {output_node.op.__name__}_prime with values {input_values}")
                    node.grad += output_node.grad * derivative_result
        print(f"Node {node.name}: Final Grad  = {node.grad}")

# 恒等函数
def identity(x):
    return x

# 恒等函数的导数,prime符号是f'(x)中的(')
def identity_prime(x):
    return 1


def add(x, y):
    return x + y


def add_prime(x, y):
    return 1, 1


def mul(x, y):
    return x * y


def mul_prime(x, y):
    return y, x


def ln(x):
    return np.log(x)


def ln_prime(x):
    return 1 / x


def sin(x):
    return np.sin(x)


def sin_prime(x):
    return np.cos(x)


def sub(x, y):
    return x - y


def sub_prime(x, y):
    return 1, -1


# 测试
if __name__ == "__main__":
    registry = OperationRegistry()
    registry.register_operation("add", add)
    registry.register_operation("add_prime", add_prime)
    registry.register_operation("mul", mul)
    registry.register_operation("mul_prime", mul_prime)
    registry.register_operation("ln", ln)
    registry.register_operation("ln_prime", ln_prime)
    registry.register_operation("sin", sin)
    registry.register_operation("sin_prime", sin_prime)
    registry.register_operation("sub", sub)
    registry.register_operation("sub_prime", sub_prime)
    registry.register_operation("identity", identity)
    registry.register_operation("identity_prime", identity_prime)

    # 定义输入节点
    x1 = Node(2, name="x1")
    x2 = Node(5, name="x2")

    # 定义中间节点
    v1 = Node(name="v1")
    v2 = Node(name="v2")
    v3 = Node(name="v3")
    v4 = Node(name="v4")
    v5 = Node(name="v5")
    v6 = Node(name="v6")
    v7 = Node(name="v7")

    # 定义输出节点
    y = Node(name="y")

    # 设置节点关系和操作
    v1.inputs = [x1]
    v1.op = registry.get_operation("identity")
    v1.op_prime = registry.get_operation("identity_prime")
    x1.outputs.append(v1)

    v2.inputs = [x2]
    v2.op = registry.get_operation("identity")
    v2.op_prime = registry.get_operation("identity_prime")
    x2.outputs.append(v2)

    v3.inputs = [v1]
    v3.op = registry.get_operation("ln")
    v3.op_prime = registry.get_operation("ln_prime")
    v1.outputs.append(v3)

    v4.inputs = [v1, v2]
    v4.op = registry.get_operation("mul")
    v4.op_prime = registry.get_operation("mul_prime")
    v1.outputs.append(v4)
    v2.outputs.append(v4)

    v5.inputs = [v2]
    v5.op = registry.get_operation("sin")
    v5.op_prime = registry.get_operation("sin_prime")
    v2.outputs.append(v5)

    v6.inputs = [v3, v4]
    v6.op = registry.get_operation("add")
    v6.op_prime = registry.get_operation("add_prime")
    v3.outputs.append(v6)
    v4.outputs.append(v6)

    v7.inputs = [v6, v5]
    v7.op = registry.get_operation("sub")
    v7.op_prime = registry.get_operation("sub_prime")
    v6.outputs.append(v7)
    v5.outputs.append(v7)

    y.inputs = [v7]
    y.op = registry.get_operation("identity")
    y.op_prime = registry.get_operation("identity_prime")
    v7.outputs.append(y)

    # 正向传播
    result = forward_pass([y])
    print("Forward pass result:", result)

    # 反向传播
    reverse_pass([y])
    print("Gradient of x1:", x1.grad)
    print("Gradient of x2:", x2.grad)

在这里插入图片描述

第二代微分方法相比第一代微分方法最大的特点是对计算图进行了扩展,使得新扩展出来的计算图的节点本身就是梯度信息。进而可以继续扩展出梯度的梯度计算图,从而计算二阶梯度。
在这里插入图片描述
还有一个原因是由于扩展出来的这部分仍然是一个计算图,可以进行计算上的优化。比如将图中由Identify(id)连接的两个节点合并为一个节点。

  • 扩展到高维矩阵:
    在这里插入图片描述

http://www.kler.cn/a/472902.html

相关文章:

  • Openwrt @ rk3568平台 固件编译实践(二)- ledeWRT版本
  • 【Linux】shell脚本编程
  • 网络安全常见的问题
  • 【微服务】7、分布式事务
  • 基于 GEE Sentinel-1 数据集提取水体
  • 高等数学-----极限、函数、连续
  • 左神算法基础巩固--4
  • ESP32 IDF VScode出现头文件“无法打开 源 文件 ”,并有红色下划线警告
  • Docker 容器运行后自动退出的解决方案
  • MySQL 分库分表实战(一)
  • 无网络时自动切换备用网络环境
  • C++二十三种设计模式之迭代器模式
  • Python爬虫基础——XPath表达式
  • ffmpeg之h264格式转yuv
  • WEBRTC前端播放 播放器组件封装
  • 【Linux】深入理解文件系统(超详细)
  • 自动化执行 SQL 脚本解决方案
  • 十六、Vue 组件
  • 《深入浅出HTTPS​​​​​​​​​​​​​​​​​》读书笔记(26):数字签名
  • 【数据结构-堆】【二分】力扣3296. 移山所需的最少秒数
  • 牛客网刷题 ——C语言初阶(5操作符)——BC90 矩阵计算
  • 解决word桌面图标空白
  • UTTracker背景矫正模块详解:解决无人机追踪中的摄像头运动问题
  • Ruby语言的正则表达式
  • WebSocket 设计思路
  • 怎样用云手机进行海外社媒矩阵引流?