当前位置: 首页 > article >正文

机器学习扫盲系列(2)- 深入浅出“反向传播”-1

系列文章目录

机器学习扫盲系列(1)- 序
机器学习扫盲系列(2)- 深入浅出“反向传播”-1


文章目录

  • 前言
  • 一、神经网络的本质
  • 二、线性问题
    • 解析解的不可行性
    • 梯度下降与随机梯度下降
    • 链式法则
  • 三、非线性问题
    • 激活函数


前言

反向传播(Backpropagation) 是神经网络中重要且难以理解的概念之一,甚至可以说如果理解了反向传播的工作机制,你就基本理解了神经网络的工作原理。所以我们从“反向传播”切入,由此揭开神经网络的神秘面纱。学习完之后你会发现,反向传播只是一个非常简单的过程,它告诉我们在神经网络中需要怎么改变参数。


一、神经网络的本质

神经网络的本质其实就是根据数据集(点)拟合预测/推理函数(曲线),所以简单来说神经网络其实是一个“极为复杂的曲线拟合机器”。 如下图所示,左图的点作为训练数据,神经网络经过训练之后会拟合处右图的曲线,这样对新数据x,可以预测/推理出y的值。
在这里插入图片描述

二、线性问题

为了简单起见,我们使用一个线性回归模型作为示例来演示。显然,我们在对训练数据拟合成函数(曲线)之前,需要先假设一个预测的函数,如 y = weight * x + bias。有了预测函数之后,我们就可以对每个数据计算损失(实际值和预测值的偏差)。这里我们使用 MSE(mean squared error loss 均方误差) 作为损失函数, 针对预测函数中参数的不同值,我们都可以计算出这批训练数据的总的损失。
在这里插入图片描述
如下图,直线是我们假设的预测函数,图上的点是训练数据。我们试着调整weight 和 bias 的值,看看损失如何变化。
在这里插入图片描述
上面提到,我们要让损失函数的值最小才能得到最完美的预测函数, 那我们先看看损失函数的图形,如下图。x、y轴分别表示 weight 和 bias, z 轴的值就是对应的损失值。大家想想,怎样才能找到损失值最小的位置(weight 和 bias)?
在这里插入图片描述
由此引出梯度(导数,在某个位置的瞬时变化率)的定义。下面的公式给的是只有一个变量x 的情况。因此我们还是先从简单的图形入手。
在这里插入图片描述
假设损失函数 loss = x^2 - 1, 我们得到下面的图形:
在这里插入图片描述
虽然肉眼很容易能看出来最小值在哪个位置,但是对于计算机来说却很难,尤其是函数非常复杂的时候。那怎么才能找到让 loss 最(极)小的位置呢? 首先,我们观察这个图形,在什么情况下 loss 值最小?是不是当梯度为0或者梯度接近于0的时候 loss 最(极)小。这个时候,有人可能会说,那直接令 梯度=0,解析出这个函数的变量x值就行了吧(求解析解)?
其实不然,为什么?

解析解的不可行性

1. 解析解的不可行性,数学复杂度
对于绝大多数模型(如神经网络、支持向量机等),损失函数是高维非线性的,其梯度方程(∇L(θ)=0)常无法分解为闭式解(closed-form solution)。例如:

  • 线性回归可以求得解析解(θ=(XᵀX)⁻¹Xᵀy),但仅因模型是凸且线性。
  • 对复杂模型(如深度学习),损失函数的高度非线性导致方程求解需要多项式时间之外的计算量。

2. 计算资源限制, 维数灾难,高维参数空间的矩阵运算代价极高。例如:

  • 当参数维度为n时,求解线性方程组的计算复杂度为O(n³),而n=1e⁶时需1e¹⁸次运算。
  • 实际深度学习模型的参数规模可达1e⁹量级(如GPT-3),直接求解完全不现实。

3. 非凸优化问题,局部极小值与鞍点

  • 非凸损失函数(如神经网络的损失函数)的梯度为零点可能是局部极小值或鞍点,而全局极小值难以定位。
  • 鞍点尤其在高维空间中普遍存在(概率随维度指数级增长),直接求解梯度为零可能落入低质量解。

4. 数据驱动动态优化,在线学习与大数据场景

  • 当数据集过大时,直接求解需一次性加载全体数据(内存不足问题)。

梯度下降与随机梯度下降

ok,那我们想想如果不用解析解,你会怎么找到最小值?显然,聪明的你肯定会想到这样做:
Prepare阶段:使用训练数据求出损失函数(正向传播)

  1. 随机初始化一个位置值 x1
  2. 计算x1处的梯度d1
  3. 计算 步长=d1 * 学习率(通过调参设置)(这里你会发现一个很巧妙的地方,离最优点越远,步长越长)
  4. 更新位置 x2 = x1 - 步长 = x1 - d1 * 学习率
  5. 循环执行,直到梯度接近0 或者 步数达到最大值

恭喜,你发明了梯度下降算法。

在继续之前,大家想想这个算法有没有什么问题?如果训练数据量非常大,我们的损失函数就会异常复杂,因为计算损失函数时需要将参数加载到内存/显存中,过大的训练数据显然无法运行。怎么办?

把全量数据(epoch)按固定大小随机分组,每次只拿这一组的数据(batch)计算损失函数。

恭喜,你发明了随机梯度下降算法。

以上只是一个参数的情况,如果涉及两个参数比如 weight 和 bias,该怎么处理?
还是让我们先回到一开始的预测函数 y = weight * x + bias。 这里的两个参数weight 和 bias 都会影响 loss值 ,那我们需要计算每个参数对loss的影响程度,即偏导数(梯度),然后根据偏导数不断迭代更新相应参数的值,从而找到最优解(loss 最小)。我们看看有两个参数时,loss的变化情况,为简单起见,我们用二维等高线来画。

在这里插入图片描述
用loss的颜色深浅表示loss值。从图上能很明显看出当 weight = -1, bias = 5时loss为0,也就是我们的目标优化位置。

举个例子,顺便复习一下刚刚梯度下降的步骤,让我们试着从图上随机取点来优化我们的参数(初始化)。

第二步 需要分别计算两个参数在这个位置的偏导数(梯度)
在这里插入图片描述

第三步:分别计算两个参数的步长,第四步: 更新w 和 b
在这里插入图片描述
回顾以上步骤,大家再想想哪一步是比较难的?

对,第3步,怎么计算损失函数的梯度值?

链式法则

先看下这张图:
在这里插入图片描述
如果你想知道蓝色值(loss)如何被粉色值(参数)影响,你会怎么观察?
先从蓝色值开始:

  1. 观察蓝色值被橘色值影响的程度
  2. 观察橘色值被绿色值影响的程度
  3. 观察绿色值被粉色值影响的程度

恭喜,你发明了链式法则! 看公式:

在这里插入图片描述
注意,这里的neuron就是预测函数y。当你想知道loss被w的影响程度(loss 关于 w的梯度), 你可以先计算loss被neuron(预测值)的影响程度,再计算neuron(预测值)被w的影响程度,两个相乘,你就得到了loss 关于 w的梯度。 再观察下,我们是从后往前计算梯度(loss -> 预测值 -> 参数),这也是反向传播这一名词的由来,这样做的好处是可以利用前向传播过程中的计算结果,而不需要重复计算,节约资源和时间。

ok,来看一个具体的示例。
对于 y=wx + b, 我们假设有一条训练数据(x=2.1, y=4),w=1,b=0
在这里插入图片描述
正向传播:
在这里插入图片描述
反向传播:
在这里插入图片描述
在这里插入图片描述
分别更新参数,这里我们设置学习率为 0.1 (lr = 0.1),学习率的更新也是人工/自动调参的一部分:
在这里插入图片描述
看看更新参数后loss值:
在这里插入图片描述
3.61 -> 2.87, loss 下降了!

三、非线性问题

还是回到一开始的图,如果是这些点,你会怎么选择函数来拟合呢?
在这里插入图片描述
聪明的你可能会想到用一个相对复杂的嵌套函数来拟合这个曲线:
y= log(1 + e^(w11 * x + b11)) * w21 + log(1 + e^(w12 * x + b12)) * w22 + b2

恭喜,你发明了神经网络

也就是说,上面的公式其实可以用一个简单的神经网络结构来表示:
在这里插入图片描述
那么它的正向传播过程就可以表示为:
在这里插入图片描述
而反向传播计算梯度的过程就可以这样表示,这里就以w21和w11为例:

在这里插入图片描述
注意: 如果在计算某个参数w11的梯度时用到了另一个参数w21的值,w21的值应该取当前值,而不是优化后的值。等计算完梯度,再统一对所有参数进行迭代更新。

到这里,我们应该能发现,为了拟合更加复杂的曲线,我们可以在每层添加更多神经元、更多层以及在输出中添加非线性(激活函数)来实现。如果我们把整个神经网络视为一个函数,那么通过添加更多神经元和更多层,我们就可以创建一个嵌套更多的函数。这样做有几大好处:

  1. 创建了更多可以调整的参数来拟合输出结果
  2. 保证可微: 可以计算梯度,也就是没有尖锐拐角或断层
  3. 通过添加具有失活点的非线性激活函数,某些神经元可能对输出没有影响,而其他神经元可能变得更加活跃,从而导致输出不必遵循线性约束。

激活函数

我们这里使用的是softplus作为激活函数
在这里插入图片描述
还有比较常见的激活函数:
常见激活函数:

  • Relu:
    在这里插入图片描述
  • sigmoid:
    在这里插入图片描述

未完待续。。。


http://www.kler.cn/a/590229.html

相关文章:

  • 3.17学习总结 java数组
  • 18.使用读写包操作Excel文件:xlrd、xlwt 和 xlutils 包
  • 浅谈AI落地之-关于数据增广的思考
  • Tomcat线程池详解,为什么SpringBoot最大支持200并发?
  • 从零搭建微服务项目Pro(第6-1章——Spring Security+JWT实现用户鉴权访问与token刷新)
  • 【前端】入门基础(一)html标签
  • Git 面试问题,解决冲突
  • ‌RTSPtoWeb, 一个将rtsp转换成webrtc的开源项目
  • C++之list类及模拟实现
  • Redis 安装详细教程(小白版)
  • 《企业级 Webpack 5 优化实战:构建速度提升 400% 的完整方案》
  • VO和DO在前后端中的对应关系详解
  • 中间件漏洞之weblogic
  • Centos离线安装openssl-devel
  • C/C++蓝桥杯算法真题打卡(Day6)
  • “查找”功能发展到今天,便利了生活哪些地方?
  • Bash语言的堆
  • DNS主从服务器
  • 【Linux篇】:初步理解何为进程--从硬件“原子“到PCB“粒子“的进程管理革命
  • Spring Cloud Stream - 构建高可靠消息驱动与事件溯源架构