[Machine Learning] Zygote.jl
Zygote.jl
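Zygote is the automatic differentiation package for the Julia language. Flux.jl adopts it for the training stage of machine learning, where it carries out the gradient, Jacobian, and related computations.
Installation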
]add Zygote
Computing derivatives
$f(x)=x^2$ at $x=5$:
using Zygote
gradient(x -> x^2, 5)
(10.0,)
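As a quick sanity check, the result can be compared against a central finite difference. This is only a sketch: the step size `h` and the tolerance are arbitrary choices for illustration, not anything prescribed by Zygote.

```julia
using Zygote

f(x) = x^2                       # same function as in the example above

# gradient returns a tuple with one entry per argument, so take the first entry
g = gradient(f, 5.0)[1]

# central finite difference as an independent numerical check
h = 1e-6                         # step size, an arbitrary choice for this sketch
fd = (f(5.0 + h) - f(5.0 - h)) / (2h)

println(isapprox(g, fd; atol = 1e-6))
```

Indexing with `[1]` is needed because `gradient` always returns a tuple, even for a single argument.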
Computing gradients
$f(a,b)=ab$ at $a=2$, $b=-3$:
gradient((a, b) -> a*b, 2, -3)
(-3.0, 2.0)
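The returned tuple can be verified by hand. Each partial derivative of the product is simply the other factor, evaluated at the arguments passed in the call above ($a=2$, $b=-3$):

```latex
\[
\frac{\partial f}{\partial a} = b, \qquad
\frac{\partial f}{\partial b} = a
\quad\Longrightarrow\quad
\nabla f(2,-3) = (-3,\ 2).
\]
```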
$f(W)=\sum_i (Wx)_i$, the sum of the entries of $Wx$, for a random $2\times 3$ matrix $W$ and a random vector $x$ of length 3:
W = rand(2, 3); x = rand(3);
gradient(W -> sum(W*x), W)[1]
2×3 Array{Float64,2}:
0.0462002 0.817608 0.979036
0.0462002 0.817608 0.979036
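Both rows of the printed gradient are identical because the derivative of `sum(W*x)` with respect to `W[i,j]` is `x[j]`, independent of `i`. A minimal sketch of that check follows; the name `expected` is introduced here for illustration only.

```julia
using Zygote

W = rand(2, 3); x = rand(3)
g = gradient(W -> sum(W * x), W)[1]

# d/dW[i,j] of sum(W*x) equals x[j], so every row of the gradient is x'
expected = ones(2) * x'          # 2×3 outer product of a ones vector with x

println(g ≈ expected)
```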
$f(x)=3x^2+2x+1$ at $x=\frac{1}{4}$, passing the point as a rational number so that the result is also an exact rational:
gradient(x -> 3x^2 + 2x + 1, 1//4)
(7//2,)
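The exact rational result agrees with the derivative computed by hand:

```latex
\[
f'(x) = 6x + 2, \qquad
f'\!\left(\tfrac{1}{4}\right) = \tfrac{3}{2} + 2 = \tfrac{7}{2}.
\]
```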
Control flow
$f(x)=x^3$ at $x=5$, implemented with a loop:
function pow(x, n)
    r = 1
    for i = 1:n
        r *= x
    end
    return r
end
pow (generic function with 1 method)
gradient(x -> pow(x, 3), 5)
(75.0,)
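Since the loop length `n` is an ordinary runtime value, the same function can be differentiated for other exponents as well. The sketch below compares the result with the power rule $n x^{n-1}$ for a few values of `n`; the evaluation point `2.0` and the name `ok` are arbitrary choices made for illustration.

```julia
using Zygote

function pow(x, n)
    r = 1
    for i = 1:n
        r *= x
    end
    return r
end

# compare the loop-based derivative with the power rule n * x^(n-1)
ok = all(gradient(x -> pow(x, n), 2.0)[1] ≈ n * 2.0^(n - 1) for n in 1:5)

println(ok)
```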
Dictionaries
$f(x)=x^2$ at $x=5$, with the value stored in and read back from a dictionary:
d = Dict()
Dict{Any, Any}()
gradient(5) do x
    d[:x] = x
    d[:x] * d[:x]
end
(10.0,)
d[:x]
5
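The `do` block used above is ordinary Julia syntax for passing an anonymous function as the first argument. The following sketch writes the same call with an explicit anonymous function; it assumes the same global dictionary `d` as in the example above.

```julia
using Zygote

d = Dict()

# equivalent to the do-block form: the anonymous function is the first argument
g = gradient(x -> (d[:x] = x; d[:x] * d[:x]), 5)

println(g)
println(d[:x])                   # the dictionary was updated as a side effect
```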
Structs and types
Define a struct Point with two fields, Point.x and Point.y, and define addition and subtraction for it. The objective function is the Euclidean norm of the vector, $f(P)=\sqrt{P.x^2+P.y^2}$.
import Base: +, -
struct Point
    x::Float64
    y::Float64
end
a::Point + b::Point = Point(a.x + b.x, a.y + b.y)
a::Point - b::Point = Point(a.x - b.x, a.y - b.y)
dist(p::Point) = sqrt(p.x^2 + p.y^2)
a = Point(1, 2)
Point(1.0, 2.0)
b = Point(3, 4)
Point(3.0, 4.0)
dist(a + b)
7.211102550927978
gradient(a -> dist(a + b), a)[1]
(x = 0.5547001962252291, y = 0.8320502943378437)
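The named tuple returned above matches the hand-computed gradient of the norm, which is the unit vector in the direction of $a+b$. Here $a=(1,2)$ and $b=(3,4)$, so $a+b=(4,6)$:

```latex
\[
\nabla_a \lVert a+b \rVert
= \frac{a+b}{\lVert a+b \rVert}
= \frac{(4,\ 6)}{\sqrt{52}}
\approx (0.5547,\ 0.8321).
\]
```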
Jacobian matrices
Consider the map $f(a)=100\,(a_1^2,a_2^2,a_3^2)$, which reads only the first three components of $a\in\mathbb{R}^7$, and its Jacobian at the point $(1,2,3,4,5,6,7)$:
jacobian(a -> 100*a[1:3].^2, 1:7)[1]
3×7 Matrix{Int64}:
200 0 0 0 0 0 0
0 400 0 0 0 0 0
0 0 600 0 0 0 0
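Row $i$ of a Jacobian is the gradient of the $i$-th output component. The sketch below rebuilds the matrix row by row with `gradient` and compares it with the `jacobian` result; the names `f`, `rows`, and `J_rows` are introduced here for illustration, and the point is converted to `Float64` for convenience.

```julia
using Zygote

f(a) = 100*a[1:3].^2
a = collect(1.0:7.0)             # Float64 copy of the point (1, ..., 7)

J = jacobian(f, a)[1]            # 3×7 matrix

# row i of the Jacobian is the gradient of the i-th output component
rows = [gradient(v -> f(v)[i], a)[1] for i in 1:3]
J_rows = permutedims(reduce(hcat, rows))

println(J ≈ J_rows)
```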
Consider the map $f(a,x)=(a_1^2x,\ a_2^2x,\ a_3^2x)$ and its Jacobians at the point $a=(1,2,3)$, $x=1$:
jacobian((a,x) -> a.^2 .* x, [1,2,3], 1)
([2 0 0; 0 4 0; 0 0 6], [1, 4, 9])
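`jacobian` returns one result per argument. Both can be checked by hand at $a=(1,2,3)$, $x=1$:

```latex
\[
\frac{\partial f_i}{\partial a_j} = 2 a_i x\,\delta_{ij}
\;\Longrightarrow\;
\frac{\partial f}{\partial a} =
\begin{pmatrix} 2 & 0 & 0\\ 0 & 4 & 0\\ 0 & 0 & 6 \end{pmatrix},
\qquad
\frac{\partial f_i}{\partial x} = a_i^2
\;\Longrightarrow\;
\frac{\partial f}{\partial x} =
\begin{pmatrix} 1\\ 4\\ 9 \end{pmatrix}.
\]
```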
Consider the map $f(a,d)=\begin{pmatrix} a_{11}a_{12}\\ a_{21}a_{22}\\ a_{31}a_{32}\end{pmatrix}$ at $a=\begin{pmatrix} 1 & 2\\ 3 & 4\\ 5 & 6\end{pmatrix}$, $d=2$:
jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2)
([2 0 … 0 0; 0 4 … 3 0; 0 0 … 0 5], [0, 0, 0])
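The `…` in the printed result is only display truncation of the $3\times 6$ Jacobian with respect to $a$. As a hand-derived sketch, assuming the columns follow the column-major order of `vec(a)`, that is $(a_{11},a_{21},a_{31},a_{12},a_{22},a_{32})$, and using $\partial f_i/\partial a_{i1}=a_{i2}$ and $\partial f_i/\partial a_{i2}=a_{i1}$:

```latex
\[
\frac{\partial f}{\partial \operatorname{vec}(a)} =
\begin{pmatrix}
2 & 0 & 0 & 1 & 0 & 0\\
0 & 4 & 0 & 0 & 3 & 0\\
0 & 0 & 6 & 0 & 0 & 5
\end{pmatrix}.
\]
```

The first two and last two columns agree with the entries visible in the truncated output.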
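For arguments of any type other than numbers and abstract arrays (characters, strings, tuples, and so on), the differentiation result is `nothing`.
Hessian matrices
Compute the Hessian of the two-variable function $f(x)=x_1x_2$: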
hessian(x -> x[1]*x[2], randn(2))
2×2 Matrix{Float64}:
0.0 1.0
1.0 0.0
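The result does not depend on the random input point, because all second derivatives of $f(x)=x_1x_2$ are constants:

```latex
\[
\frac{\partial^2 f}{\partial x_1^2} = 0, \quad
\frac{\partial^2 f}{\partial x_2^2} = 0, \quad
\frac{\partial^2 f}{\partial x_1\,\partial x_2} = 1
\quad\Longrightarrow\quad
H = \begin{pmatrix} 0 & 1\\ 1 & 0 \end{pmatrix}.
\]
```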