当前位置: 首页 > article >正文

【机器学习】Zygote.jl

Zygote.jl

Zygote 是 Julia语言的微分项目,并且被 Flux.jl 收录作为机器学习训练环节中,执行所有梯度,雅可比矩阵等计算过程。

安装

]add Zygote

计算导数

f ( x ) = x 2 f(x)=x^2 f(x)=x2, x = 5 x=5 x=5

using Zygote
gradient(x -> x^2, 5)

(10.0,)

计算梯度

f ( a , b ) = a b f(a,b)=ab f(a,b)=ab, a = 2 , b = 3 a=2,b=3 a=2,b=3

gradient((a, b) -> a*b, 2, -3)

(-3.0, 2.0)
f ( W ) = W b f(W)=Wb f(W)=Wb

W = rand(2, 3); x = rand(3);
gradient(W -> sum(W*x), W)[1]

2×3 Array{Float64,2}:
0.0462002 0.817608 0.979036
0.0462002 0.817608 0.979036

f ( x ) = 3 x 2 + 2 x + 1 f(x)=3x^2+2x+1 f(x)=3x2+2x+1, x = 1 4 x=\frac{1}{4} x=41 要求有理数.

gradient(x -> 3x^2 + 2x + 1, 1//4)

(7//2,)

控制流结构

f ( x ) = x 3 f(x)=x^3 f(x)=x3, x = 5 x=5 x=5

 function pow(x, n)
         r = 1
         for i = 1:n
           r *= x
         end
         return r
       end

pow (generic function with 1 method)

gradient(x -> pow(x, 3), 5)

(75.0,)

字典结构

f ( x ) = x 2 f(x)=x^2 f(x)=x2, x = 5 x=5 x=5

d = Dict()

Dict{Any, Any}()

gradient(5) do x
         d[:x] = x
         d[:x] * d[:x]
       end

(10.0,)

d[:x]

5

结构与类型

定义结构 “Point”, 包括两个分量 Point.x, Point.y
并定义向量的加法与减法运算,以向量的欧式范数的平方作为目标函数
f ( P ) = P . x 2 + P . y 2 f(P)=\sqrt{P.x^2+P.y^2} f(P)=P.x2+P.y2

import Base: +, -
struct Point
  x::Float64
  y::Float64
end
a::Point + b::Point = Point(a.x + b.x, a.y + b.y)
a::Point - b::Point = Point(a.x - b.x, a.y - b.y)
dist(p::Point) = sqrt(p.x^2 + p.y^2)
a = Point(1, 2)

Point(1.0, 2.0)

b = Point(3, 4)

Point(3.0, 4.0)

dist(a + b)

7.211102550927978

gradient(a -> dist(a + b), a)[1]

(x = 0.5547001962252291, y = 0.8320502943378437)

雅可比矩阵

考虑映射 f ( a ) = 100 ( a 1 2 , a 2 2 , a 3 2 , 0 , 0 , 0 , 0 ) f(a)=100(a_1^2,a_2^2,a_3^2,0,0,0,0) f(a)=100(a12,a22,a32,0,0,0,0), 在点 ( 1 , 2 , 3 , 4 , 5 , 6 , 7 ) (1,2,3,4,5,6,7) (1,2,3,4,5,6,7) 处的雅克比矩阵。

jacobian(a -> 100*a[1:3].^2, 1:7)[1] 

3×7 Matrix{Int64}:
200 0 0 0 0 0 0
0 400 0 0 0 0 0
0 0 600 0 0 0 0
考虑映射 f ( a , x ) = ( a 1 2 x , a 2 2 x , a 3 2 x ) f(a,x)=(a_1^2x,a_2^2x,a_3^2x) f(a,x)=(a12x,a22x,a32x), 在点 a = ( 1 , 2 , 3 ) , x = 1 a=(1,2,3),x=1 a=(1,2,3),x=1 处的雅克比矩阵。

jacobian((a,x) -> a.^2 .* x, [1,2,3], 1)

([2 0 0; 0 4 0; 0 0 6], [1, 4, 9])

考虑映射 f ( a , d ) = ( a 11 a 12 a 21 a 22 a 31 a 32 ) f(a,d)=\left(\begin{matrix} a_{11}a_{12}\\ a_{21}a_{22}\\ a_{31}a_{32}\end{matrix}\right) f(a,d)=a11a12a21a22a31a32 a = ( 1 2 3 4 5 6 ) a=\left(\begin{matrix} 1& 2\\3& 4\\5&6\end{matrix}\right) a=135246, d = 2 d=2 d=2

jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2)

([2 0 … 0 0; 0 4 … 3 0; 0 0 … 0 5], [0, 0, 0])

除了数与抽象数组,其他类型的微分结果都是 Nothing, 其他类型包括字符,字符串,元组等等。

Hesse 矩阵 (Hessian)

求二元函数 f ( x ) = x 1 x 2 f(x)=x_1x_2 f(x)=x1x2 的 Hesse 矩阵

hessian(x -> x[1]*x[2], randn(2))

2×2 Matrix{Float64}:
0.0 1.0
1.0 0.0


http://www.kler.cn/a/320376.html

相关文章:

  • OceanBase 分区表详解
  • H.265流媒体播放器EasyPlayer.js视频流媒体播放器关于直播流播放完毕是否能监听到
  • 2024-11-17 -MATLAB三维绘图简单实例
  • Ubuntu 18.04 配置sources.list源文件(无法安全地用该源进行更新,所以默认禁用该源)
  • Istio分布式链路监控搭建:Jaeger与Zipkin
  • Solana 区块链的技术解析及未来展望 #dapp开发#公链搭建
  • ollydbg 小记
  • 每天一道面试题(17):服务网格学习笔记
  • 社区团购的创新与变革——融合开源链动 2+1 模式、AI 智能名片及 S2B2C 商城小程序
  • 2024一线大厂网络安全面试题+答案,看完offe拿到手软!
  • .NET 反序列化加载哥斯拉内存马的工具
  • 计算机毕业设计 基于Python医院预约挂号系统 Django+Vue 前后端分离 附源码 讲解 文档
  • 大语言模型之LlaMA系列- LlaMA 2及LLaMA2_chat(上)
  • 【OSS安全最佳实践】对OSS表格文件中的敏感数据进行脱敏
  • 3分钟,教你判断自己适不适合做项目管理!
  • 前端开发之原型模式
  • FPGA题目记录2
  • 【RDMA】mlxconfig修改和查询网卡(固件)配置--驱动工具
  • 双十一好货推荐有哪些?五大双十一种草好物推荐!
  • chatgpt复旦大学张奇老师《自然语言处理导论》AI好书PDF分享,不看后悔一辈子!
  • 【计算机网络 - 基础问题】每日 3 题(二十三)
  • Java 序列化:为什么你应该手动定义 serialVersionUID?@Serial 注解有什么作用?
  • python基础:函数、模块、库
  • AI篮球投篮分析与投篮姿势的机器学习应用
  • PHP 函数
  • 山西农业大学20240925