当前位置: 首页 > article >正文

从零深度学习:(2)最小二乘法

今天我们从比较简单的线性回归开始讲起,还是一样我们先导入包

import numpy as np
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
a = torch.arange(1,5).reshape(2,2).float()
a

我们利用刚刚导入的画图的包将这两个点画出来,将1和3先索引出来作为横坐标,2和4作为纵坐标传入给plot,'o'表示画的是点而不是线

#画出上面两个点
plt.plot(a[:,0],a[:,1],'o')

                 

现在我们希望找到一条直线去穿过拟合这两个点,也就是所谓的线性回归,不妨设方程如下:

y = ax+b

我们在初中就学过两个点能够带入两个方程进行求解,将a和b通过解方程的形式求解出来。除了这种矩阵求解以外,我们还可以转化为一个优化问题来进行求解。其中优化问题最关键的两个就是优化指标和优化目标函数。我们现在的任务是找到一条直线拟合这两个点,所以显然目标就是将这两个点横坐标带进方程解析式的预测值y和实际的y(2,4)之间的误差变小。

我们可以在markdown中渲染得到如下表格,右侧是预测值和真实值的差值:

为了让差距变小,一个很朴素的想法就是求和变的最小,但是由于这里有正有负,可能会出现正负抵消的情况,所以这里我们采用先平方再求和,也就是所谓的误差平方和SSE:

至此我们已经完成了优化问题的转化,我们现在的目标就是找到a和b为何值的时候这个差值函数最小,因此上面这个函数也叫做目标函数。

我们导入画图工具包将这个函数图像画出来:

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
x = np.arange(-1, 3, 0.1)  # 增加步长,减少数据点数量
y = np.arange(-1, 3, 0.1)
a, b = np.meshgrid(x, y)
SSE = (2 - a - b)**2 + (4 - 3*a - b)**2

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(a, b, SSE, cmap='rainbow')
plt.show()

我们不难看出这是一个凸函数,而对于一个凸函数来说,最小值显然存在,所以根据这一点,我们可以给出求解凸函数最小值的一般方法,也就是最小二乘法。关于这个最小二乘法,我们在高数和概率论的学习中都有涉及,如果有不懂的宝子可以去补一下。当然凸函数优化方法还有很多,我们会在后续的学习中陆续提及。

所以这里就是对a和b分别求偏导并令其等于0,即可求解出来。

求解得出方程为:y = x +1,也就是说当(a,b)等于(1,1)的时候函数取得最小值。

在求解完之后我们也可以通过借助autograd模块来帮助我们验证导数是否为零。

autograd

我们可以在jupyter中输入如上代码,会发现如果你的张量requires_grad属性等于True,你每计算一步都会记录在grad_fn当中。例如这里的y是通过x乘法得到的,所以下面是Mul也就是乘法的缩写,同理z的Pow是power的缩写。

我们也可以通过.grad_fn来查看具体内容:

grad_fn 存储了当前张量的计算源和操作类型,用于梯度计算。具体来说,它指向一个与该张量相关的操作对象,操作对象是由上次计算生成的,这些对象的存在是为了在反向传播时提供梯度计算的方法,并且它们是由 PyTorch 自动生成并维护的。

同时,这是链式存储的一部分。在反向传播中,PyTorch 会按照计算图的反向顺序计算每个张量的梯度。这些 grad_fn 实际上是梯度计算的链条,记录了张量是如何从前一个操作得到的,并允许在反向传播时依赖于这些操作生成梯度。

所以根据这个回溯机制,我们可以画出输出张量是怎么一步一步得来的并画出张量计算图,如下:

PyTorch的计算图是动态计算图,会根据可微分张量的计算过程自动生成,并且伴随着新张量或运算的加入不断更新,这使得PyTorch的计算图更加灵活高效,并且更加易于构建,动态图也更加适用于面向对象编程。


http://www.kler.cn/a/507007.html

相关文章:

  • c++ 中的容器 vector、deque 和 list 的区别
  • ElasticSearch下
  • 如何异地远程访问本地部署的Web-Check实现团队远程检测与维护本地站点
  • 机器学习第一道菜(一):线性回归的理论模型
  • MDX语言的数据库交互
  • 使用 Charles 调试 Flutter 应用中的 Dio 网络请求
  • 网安——CSS
  • [Linux]——进程(2)
  • “AI智能服务平台系统,让生活更便捷、更智能
  • list的模拟实现详解
  • 核心前端技术详解
  • Jupyter notebook中运行dos指令运行方法
  • Java进阶-在Ubuntu上部署SpringBoot应用
  • 微软开源AI Agent AutoGen 详解
  • Docker部署Spring Boot + Vue项目
  • ParcelFileDescriptor+PdfRenderer在Android渲染显示PDF文件
  • Spring Boot中使用AOP实现权限管理
  • Python 的时间处理模块 datetime 详解
  • 图论1-问题 B: 算法7-4,7-5:图的遍历——深度优先搜索
  • 博图 linucx vmware
  • css 实现自定义虚线
  • QT 通过QAxObject与本地应用程序读取Excel内容
  • 汽车故障码U100187 LIN1Communication time out 解析和处理方法
  • 【50个服务器常见端口】
  • 【Linux】sed编辑器二
  • 基于华为云车牌识别服务设计的停车场计费系统【华为开发者空间-鸿蒙】