【深度学习系统】Lecture 4 - Automatic Differentiation
1.General introduction to different differentiation methods
在深入介绍自动微分方法之前,我们先来认识一件事情:为什么我们要提出不同的微分方法呢?
这就不得不提到微分在机器学习中所发挥作用的板块——计算当前需要优化的损失函数
l
(
h
θ
(
x
)
,
y
)
l(h_{\theta}(x),y)
l(hθ(x),y)关于参数
θ
\theta
θ的梯度。计算出的梯度继而用于参数
θ
\theta
θ的更新,如下图所示:
计算梯度信息会是一个非常频繁的过程,这就牵扯到一个关键的问题——计算的效率。我们来看传统的 数值微分方法(Numerical differentiation) 的弊端:
- 用定义计算偏导数(保留二阶误差): ∂ f ( θ ) ∂ θ i = lim ϵ → 0 f ( θ + ϵ e i ) − f ( θ ) ϵ \frac{\partial f(\theta)}{\partial \theta_i} = \lim\limits_{\epsilon\to 0}\frac{f(\theta+\epsilon e_i)-f(\theta)}{\epsilon} ∂θi∂f(θ)=ϵ→0limϵf(θ+ϵei)−f(θ)
- 用泰勒展开消除二阶误差(更精确): ∂ f ( θ ) ∂ θ i = f ( θ + ϵ e i ) − f ( θ − ϵ e i ) 2 ϵ + o ( ϵ 2 ) \frac{\partial f(\theta)}{\partial \theta_i} = \frac{f(\theta+\epsilon e_i)-f(\theta-\epsilon e_i)}{2\epsilon}+o(\epsilon^2) ∂θi∂f(θ)=2ϵf(θ+ϵei)−f(θ−ϵei)+o(ϵ2)
可以看到从定义出发,虽然它能够计算任意函数的导数,但是计算机的精度有限,依然存在误差,而且计算效率并不高,因为要对每一个参数都做一遍这样的计算,代价是十分昂贵的。但好处是它非常的简单直观,从而比较适合用来检查我们的自动微分框架的梯度计算是否被正确设计,比如作类似如下的单元测试:
另一种方法是 符号微分(Symbolic differentiation),如下图所示:
其实也就是从公式法的角度计算函数微分,这样就保证了梯度计算的精确性。但同样有一个致命的缺陷:可能存在大量重复计算。举个例子:
这个例子中
θ
\theta
θ是一个n维的向量,计算梯度时每一个维度实际上都进行了n-2次的重复计算,这些计算并不能被简单地重用。
而 自动微分方法(Automatic Differentiation) 是对符号微分的一种创新性改进:引入了计算图(Computational graph),以节点(node)的方式保存中间步骤从而避免了重复计算:
下面我们来介绍第一代自动微分方法:Forward mode automatic differentiation (AD)
2.Forward mode automatic differentiation (AD)
我们需要盯住目标梯度:
∂
y
∂
x
i
\frac{\partial y}{\partial x_i}
∂xi∂y,其中
x
i
x_i
xi是我们关心的参数,我们通过符号微分的公式法注册中间步骤的简单微分算子。然后在前向传播的过程中能够同时计算出相应节点关于某一参数的微分。这样做乍一看很合理,但有一个毛病:一次AD trace只能计算出关于一个参数的梯度信息,如果有n个参数便要这样重复n次。
为了解决上述缺陷,我们将介绍第二代自动微分方法:Reverse mode automatic differentiation
3.Reverse mode automatic differentiation
这次我们的目标不再盯着 ∂ y ∂ x i \frac{\partial y}{\partial x_i} ∂xi∂y,而是盯着 ∂ y ∂ v i \frac{\partial y}{\partial v_i} ∂vi∂y。从后往前推导,直至最后一步推导出 y y y关于所有 x i x_i xi的偏导数,反映在图中为: v 2 ‾ = ∂ y ∂ x 2 \overline{ v_2}=\frac{\partial y}{\partial x_2} v2=∂x2∂y、 v 1 ‾ = ∂ y ∂ x 1 \overline{ v_1}=\frac{\partial y}{\partial x_1} v1=∂x1∂y。这样我们整个反向传播只需遍历一次,即可获得所有我们想要的信息,这非常好。
这里我们需要注意这样一件事情:Derivation for the multiple pathway case
我们根据网络拓扑关系,做这样的记号:
v
i
→
j
‾
=
v
j
‾
∂
v
j
∂
v
i
\overline{v_{i\to j}}=\overline{v_j}\frac{\partial v_j}{\partial v_i}
vi→j=vj∂vi∂vj,于是就有:
v
1
‾
=
v
1
→
2
‾
+
v
1
→
3
‾
\overline{v_1}=\overline{v_{1\to 2}}+\overline{v_{1\to 3}}
v1=v1→2+v1→3从而可以总结出如下规律:
v
i
‾
=
∑
j
∈
next
(
i
)
v
i
→
j
‾
\overline{v_i}=\sum_{j\in \text{next}(i)}\overline {v_{i\to j}}
vi=j∈next(i)∑vi→j
算法实现逻辑伪代码:
完整代码实现:
import numpy as np
# 节点类
class Node:
def __init__(self, value=None, grad=0, name=None):
self.value = value
self.grad = grad
self.inputs = []
self.outputs = []
self.op = None
self.op_prime = None # 导数函数
self.name = name # 添加节点名字属性
def __repr__(self):
return f"Node(name={self.name}, value={self.value}, grad={self.grad}, op={self.op}, op_prime={self.op_prime})"
# 操作注册器
class OperationRegistry:
def __init__(self):
self.operations = {}
def register_operation(self, name, operation):
self.operations[name] = operation
def get_operation(self, name):
return self.operations.get(name)
# 拓扑排序(DFS)
def topological_sort(nodes):
visited = set()
sorted_nodes = []
def visit(node):
if node not in visited:
visited.add(node)
for input_node in node.inputs:
visit(input_node)
sorted_nodes.append(node)
for node in nodes:
visit(node)
return sorted_nodes
# 逆拓扑排序
def reverse_topological_sort(nodes):
return list(reversed(topological_sort(nodes)))
# 正向传播
def forward_pass(output_nodes):
sorted_nodes = topological_sort(output_nodes)
print("forward_pass:", sorted_nodes)
for node in sorted_nodes:
if not node.inputs:
continue
inputs = [input_node.value for input_node in node.inputs]
operation = node.op
node.value = operation(*inputs)
return output_nodes[-1].value
def reverse_pass(output_nodes):
sorted_nodes = reverse_topological_sort(output_nodes)
print("reverse_pass:", sorted_nodes)
output_nodes[0].grad = 1 # y的梯度为1
for node in sorted_nodes[1:]:
print(f"Node {node.name}: Initial grad = {node.grad}")
node.grad = 0
for output_node in node.outputs:
prime_func = registry.get_operation(output_node.op.__name__ + "_prime")
if prime_func.__code__.co_argcount == 1:
print(f"Node {node.name}: Applying {output_node.op.__name__}_prime with value {node.value}")
node.grad += output_node.grad * prime_func(node.value)
elif prime_func.__code__.co_argcount >= 2:
input_values = [input_node.value for input_node in output_node.inputs]
derivative_result = prime_func(*input_values)
if isinstance(derivative_result, tuple):
relevant_derivative = derivative_result[output_node.inputs.index(node)]
print(f"Node {node.name}: Applying {output_node.op.__name__}_prime with values {input_values}, relevant derivative = {relevant_derivative}")
node.grad += output_node.grad * relevant_derivative
else:
print(f"Node {node.name}: Applying {output_node.op.__name__}_prime with values {input_values}")
node.grad += output_node.grad * derivative_result
print(f"Node {node.name}: Final Grad = {node.grad}")
# 恒等函数
def identity(x):
return x
# 恒等函数的导数,prime符号是f'(x)中的(')
def identity_prime(x):
return 1
def add(x, y):
return x + y
def add_prime(x, y):
return 1, 1
def mul(x, y):
return x * y
def mul_prime(x, y):
return y, x
def ln(x):
return np.log(x)
def ln_prime(x):
return 1 / x
def sin(x):
return np.sin(x)
def sin_prime(x):
return np.cos(x)
def sub(x, y):
return x - y
def sub_prime(x, y):
return 1, -1
# 测试
if __name__ == "__main__":
registry = OperationRegistry()
registry.register_operation("add", add)
registry.register_operation("add_prime", add_prime)
registry.register_operation("mul", mul)
registry.register_operation("mul_prime", mul_prime)
registry.register_operation("ln", ln)
registry.register_operation("ln_prime", ln_prime)
registry.register_operation("sin", sin)
registry.register_operation("sin_prime", sin_prime)
registry.register_operation("sub", sub)
registry.register_operation("sub_prime", sub_prime)
registry.register_operation("identity", identity)
registry.register_operation("identity_prime", identity_prime)
# 定义输入节点
x1 = Node(2, name="x1")
x2 = Node(5, name="x2")
# 定义中间节点
v1 = Node(name="v1")
v2 = Node(name="v2")
v3 = Node(name="v3")
v4 = Node(name="v4")
v5 = Node(name="v5")
v6 = Node(name="v6")
v7 = Node(name="v7")
# 定义输出节点
y = Node(name="y")
# 设置节点关系和操作
v1.inputs = [x1]
v1.op = registry.get_operation("identity")
v1.op_prime = registry.get_operation("identity_prime")
x1.outputs.append(v1)
v2.inputs = [x2]
v2.op = registry.get_operation("identity")
v2.op_prime = registry.get_operation("identity_prime")
x2.outputs.append(v2)
v3.inputs = [v1]
v3.op = registry.get_operation("ln")
v3.op_prime = registry.get_operation("ln_prime")
v1.outputs.append(v3)
v4.inputs = [v1, v2]
v4.op = registry.get_operation("mul")
v4.op_prime = registry.get_operation("mul_prime")
v1.outputs.append(v4)
v2.outputs.append(v4)
v5.inputs = [v2]
v5.op = registry.get_operation("sin")
v5.op_prime = registry.get_operation("sin_prime")
v2.outputs.append(v5)
v6.inputs = [v3, v4]
v6.op = registry.get_operation("add")
v6.op_prime = registry.get_operation("add_prime")
v3.outputs.append(v6)
v4.outputs.append(v6)
v7.inputs = [v6, v5]
v7.op = registry.get_operation("sub")
v7.op_prime = registry.get_operation("sub_prime")
v6.outputs.append(v7)
v5.outputs.append(v7)
y.inputs = [v7]
y.op = registry.get_operation("identity")
y.op_prime = registry.get_operation("identity_prime")
v7.outputs.append(y)
# 正向传播
result = forward_pass([y])
print("Forward pass result:", result)
# 反向传播
reverse_pass([y])
print("Gradient of x1:", x1.grad)
print("Gradient of x2:", x2.grad)
第二代微分方法相比第一代微分方法最大的特点是对计算图进行了扩展,使得新扩展出来的计算图的节点本身就是梯度信息。进而可以继续扩展出梯度的梯度计算图,从而计算二阶梯度。
还有一个原因是由于扩展出来的这部分仍然是一个计算图,可以进行计算上的优化。比如将图中由Identify(id)连接的两个节点合并为一个节点。
- 扩展到高维矩阵: