当前位置: 首页 > article >正文

自动微分-梯度!

前言背景知识:

梯度下降(Gradient descent,GD)

正文:

        自动微分为机器学习、深度学习神经网络的核心知识之一,若想更深一步使用神经网络进行具体问题研究,那么自动微分不得不了解。   “工欲善其事,必先利其器”,事为我们的研究问题,器乃神经网络、自动微分工具,这就不得不提一提现深度学习框架TensorFlow、PyTorch,还有即时编译、自动微分工具JAX。本文将使用Pytorch和Jax实现自动微分的基本案例。方便大家学习,入门自动微分工具,提供案例模板为Pytorch版和Jax版,为什么没有Tensorflow,因为觉得它不好用!

        注:这些框架、工具在CPU、GPU,又或是TPU的运行效率可执行查阅,作者认为Jax可能会成为潮流。

        已知基础单层线性回归模型如下:

                h_{\theta }(x)=\theta_{0}+\theta_{1}x

                J(\theta )=\frac{1}{2m}\sum_{i=1}^{m}(h_{\theta} (x^{(i)})-y^{(i)})^{2},\theta =\left \{ \theta_{0},\theta_{1} \right \}

        则有J关于theta偏导如下:

                  \frac{\partial J}{\partial \theta_{0}}=\frac{1}{m}\sum_{i=1}^{m}(h_{\theta)}(x^{(i))})-y^{(i)})

                 \frac{\partial J}{\partial \theta_{1}}=\frac{1}{m}\sum_{i=1}^{m}(h_{\theta)}(x^{(i))})-y^{(i)})

原生代码实现计算:

def h_x(theta0,theta1,x):
	return theta0+theta1*x

def SG(m,theta0,theta1,X,Y):
	sum = 0
	for i in range(0,m):
		sum += (h_x(theta0,theta1,X[i])-Y[i])
	theta0_grad = 1.0/m*sum
	
	sum = 0
	for i in range(0,m):
		sum += (h_x(theta0,theta1,X[i])-Y[i])*X[i]
	theta1_grad = 1.0/m*sum
    
	print("O_SG_grad_caculate : {} , {} ".format(theta0_grad,theta1_grad))

#损失函数
def loss(m,theta0,theta1,X,Y):
	result = 0.0
	for i in range(0,m):
		result += (h_x(theta0,theta1,X[i])-Y[i])**2

	return result/(2*m)

X = [1,2,3,4,5,6]
Y = [13,14,20,21,25,30]

theta0 = 0.0
theta1 = 0.0

m = 6

y_pre = h_x(theta0,theta1,X)
loss = loss(m,theta0,theta1,X,Y)

print(loss)
SG(m,theta0,theta1,X,Y)

输出:

loss : 227.58333333333334

O_SG_grad_caculate : -20.5 , -81.66666666666666

Pytorch自动微分:

        1)torch.autograd.grad计算微分:

import torch

def h_x(theta0,theta1,x):
	return theta0+theta1*x

def SG_Torch(theta0,theta1,loss):
	theta0_grad,theta1_grad = torch.autograd.grad(loss,[theta0,theta1])
	print("T_SG_grad_caculate : {} , {} ".format(theta0_grad,theta1_grad))

X = torch.tensor([1,2,3,4,5,6])
Y = torch.tensor([13,14,20,21,25,30])

theta0 = torch.tensor(0.0,requires_grad=True)
theta1 = torch.tensor(0.0,requires_grad=True)

y_pre = h_x(theta0,theta1,X)
loss = torch.mean((y_pre - Y)**2/2)

#print loss res
print("loss : {}".format(loss))
#print grad res
SG_Torch(theta0,theta1,loss)

       输出:

loss : 227.5833282470703

T_SG_grad_caculate : -20.500001907348633 , -81.66667175292969

         2)loss.backward()实现计算微分,回传theta0和theta1两叶子节点(设置要求grad)

        

import torch

def h_x(theta0,theta1,x):
	return theta0+theta1*x

X = torch.tensor([1,2,3,4,5,6])
Y = torch.tensor([13,14,20,21,25,30])

theta0 = torch.tensor(0.0,requires_grad=True)
theta1 = torch.tensor(0.0,requires_grad=True)

y_pre = h_x(theta0,theta1,X)
loss = torch.mean((y_pre - Y)**2/2)

loss.backward()

#print loss res
print("loss : {}".format(loss))
#print grad res
print("T_SG_grad_caculate : {} , {} ".format(theta0.grad,theta1.grad))

        输出:

loss : 227.5833282470703

T_SG_grad_caculate : -20.500001907348633 , -81.66667175292969

Jax自动微分:

import jax
import jax.numpy as np
from jax import grad

def h_x(theta0,theta1,x):
	return theta0+theta1*x

def loss(theta0,theta1,X,Y):
    y_pre = h_x(theta0,theta1,X)
    loss = np.mean((y_pre - Y)**2/2)
    return loss

def SG_Jax(theta0,theta1,l,X,Y):
    g_L_theta0 = grad(loss,argnums = 0)
    g_L_theta1 = grad(loss,argnums = 1)
    
    theta0_grad = g_L_theta0(theta0,theta1,X,Y)
    theta1_grad = g_L_theta1(theta0,theta1,X,Y)
    
    print("J_SG_grad_caculate : {} , {} ".format(theta0_grad,theta1_grad))
    
X = np.array([1,2,3,4,5,6])
Y = np.array([13,14,20,21,25,30])

theta0 = 0.0
theta1 = 0.0

l = loss(theta0,theta1,X,Y)

#print loss res
print("loss : {}".format(l))
#print grad res
SG_Jax(theta0,theta1,l,X,Y)

        输出:

loss : 227.58334350585938
J_SG_grad_caculate : -20.500001907348633 , -81.66667175292969

http://www.kler.cn/news/329934.html

相关文章:

  • 大数据-153 Apache Druid 案例 从 Kafka 中加载数据并分析
  • JavaScript---BOM,DOM 对象
  • QT系统学习篇(3)- Qt开发常用算法及控件原理
  • 森林火灾检测数据集 7400张 森林火灾 带标注 voc yolo
  • 【计算机网络】传输层UDP和TCP协议
  • HarmonyOS鸿蒙 Next 实现协调布局效果
  • MySQL踩坑点:字符集和排序规则
  • 架构视图和视角
  • 【重学 MySQL】四十六、创建表的方式
  • 2024 全新体验:国学心理 API 接口来袭
  • ES索引生命周期管理
  • 一次oracle迁移11g到19c后用到的对象数量对比脚本
  • Golang 服务器虚拟化应用案例
  • Django学习笔记四:urls配置详解
  • Geoserver关于忘记密码的解决方法
  • 无头双向不循环链表的模拟
  • 千兆网络变压器HX84801SP POE应用主板
  • 秋招|面试|群面|求职
  • 服务架构的演进之路:从单体应用到Serverless
  • 【初阶数据结构】排序——归并排序
  • Stable Diffusion绘画 | 来训练属于自己的模型:打标处理与优化
  • 接口测试入门:深入理解接口测试!【电商API接口测试】
  • 【Qt】系统相关学习--底层逻辑--代码实践
  • 【Redis】主从复制(上)
  • linux文件编程_进程通信
  • 《中安未来护照阅读器 —— 机场高效通行的智慧之选》
  • 一、前后端分离及drf的概念
  • 15 种高级 RAG 技术 从预检索到生成
  • Linux开发讲课45--- 链表
  • 音视频入门基础:FLV专题(8)——FFmpeg源码中,解码Tag header的实现