当前位置: 首页 > article >正文

梯度下降法求解局部最小值深入讨论以及 Python 实现

文章目录

0. 前期准备

在开始讲梯度下降法求解函数的局部最小值之前, 你需要有梯度下降法求解函数的最小值的相关知识。

如果你还不是很了解,可以参看我的另外一篇文章:

梯度下降法以及 Python 实现(小提示:小手点一点就可以跳转了)

1. 局部最小值

对于一些函数而言,在某个自变量区间范围内存在一个最小值,且这个最小值比全局最小值大。那么我们就称这个最小值为局部最小值。

数学定义:

如果存在一个 ε > 0 \varepsilon > 0 ε>0,使得所有满足 ∣ x − x ∗ ∣ < ε \lvert x-x^* \rvert < \varepsilon xx<ε x x x 都有 f ( x ∗ ) ≤ f ( x ) f(x^*) \leq f(x) f(x)f(x),我们就把点 x ∗ x^* x 对应的函数值 f ( x ∗ ) f(x^*) f(x) 称为函数 f ( x ) f(x) f(x) 的一个局部最小值。

2. 例子

上面的定义可能不好理解,用图来表示就很非常直观了。接下来,直接给出一个例子:

使用梯度下降法来计算下面函数 f ( x ) f(x) f(x) 的最小值。
f ( x ) = x 4 + 2 x 3 − 3 x 2 − 2 x . f(x) = x^4 + 2x^3 - 3x^2 - 2x. f(x)=x4+2x33x22x.


f ( x ) f(x) f(x) x x x 的微分为:
f ′ ( x ) = d f ( x ) d x = 4 x 3 + 6 x 2 − 6 x − 2 f'(x) = \frac{\mathrm{d}f(x)}{\mathrm{d}x} = 4x^3 + 6x^2 - 6x - 2 f(x)=dxdf(x)=4x3+6x26x2

设置学习率 η = 0.01 \eta=0.01 η=0.01,初始值 x 0 = 2.0 x^0=2.0 x0=2.0

eta = 0.01
x = 2.0

Python 代码实现:

import numpy as np
import matplotlib.pyplot as plt


def my_func(x):
    """
    $y = x^4 + 2x^3 - 3x^2 - 2x$
    :param x: 变量
    :return: 函数值
    """
    return x ** 4 + 2 * x ** 3 - 3 * x ** 2 - 2 * x


def grad_func(x):
    """
    函数 $y = 4x^3 + 6x^2 - 6x - 2$ 的导数
    :param x: 变量
    :return: 导数值
    """
    return 4 * x ** 3 + 6 * x ** 2 - 6 * x - 2
  

eta = 0.01
x = 2.0
record_x = []
record_y = []

for i in range(20):
    y = my_func(x)
    record_x.append(x)
    record_y.append(y)

    x -= eta * grad_func(x)


print(np.round(record_x, 4))
print(np.round(record_y, 4))

x_f = np.linspace(-3, 2)
y_f = my_func(x_f)

plt.plot(x_f, y_f, linestyle='--', color='red')
plt.scatter(record_x, record_y)

plt.xlabel('x', size=14)
plt.ylabel('y', size=14)
plt.grid()
plt.show()

x x x 的变化过程:

[2. 1.58 1.3872 1.2682 1.1862 1.1262 1.0805 1.0449 1.0164 0.9934 0.9746 0.959 0.9461 0.9353 0.9262 0.9185 0.912 0.9065 0.9018 0.8978]

f ( x ) f(x) f(x) 的变化过程:

[16. 3.4714 0.495 -0.6951 -1.2755 -1.5919 -1.7774 -1.8916 -1.9647 -2.0128 -2.0451 -2.0672 -2.0826 -2.0933 -2.101 -2.1064 -2.1103 -2.1132 -2.1152 -2.1167]

在这里插入图片描述

可以看到,初始值 x 0 = 2.0 x^0=2.0 x0=2.0 在图像的右上角。经过梯度下降法求解, f ( x ) f(x) f(x) 的值落在了一个局部最小值上。

3. 讨论

我们来讨论一下,在知晓函数图像的情形下,如何才能得到正确的全局最小值。

3.1 增加迭代次数

在上面的这种情况下,我们继续迭代,也不无法得到最小值的。也即是说 x x x 被右侧的局部最小值抓取并无法脱离,陷入了局部最小值陷阱。

为什么说是陷阱呢?因为假如你的目的是求解这个函数的最小值,现在梯度下降法停止了,得到了一个局部最小值。但是你并不清楚这是局部最小值,还开心的以为是全局最小值。你就将这个局部最小值当做是全局最小值进行处理了。你就陷入了这个陷阱中。

使用 Python 模拟一下:
将迭代的次数设置为 100 次(200 次也可以,甚至更多次都行)。
只需要修改

for i in range(100):

或者设置一个变量

num_max_iterations = 100
for i in range(num_max_iterations):

x x x 的值会在 0.8733 处停止变化:

[2. 1.58 1.3872 1.2682 1.1862 1.1262 1.0805 1.0449 1.0164 0.9934
0.9746 0.959 0.9461 0.9353 0.9262 0.9185 0.912 0.9065 0.9018 0.8978
0.8943 0.8914 0.8889 0.8867 0.8848 0.8832 0.8819 0.8807 0.8797 0.8788
0.878 0.8774 0.8768 0.8763 0.8759 0.8756 0.8752 0.875 0.8747 0.8745
0.8744 0.8742 0.8741 0.874 0.8739 0.8738 0.8737 0.8737 0.8736 0.8736
0.8735 0.8735 0.8735 0.8734 0.8734 0.8734 0.8734 0.8734 0.8734 0.8734
0.8733 0.8733 0.8733 0.8733 0.8733 0.8733 0.8733 0.8733 0.8733 0.8733
0.8733 0.8733 0.8733 0.8733 0.8733 0.8733 0.8733 0.8733 0.8733 0.8733
0.8733 0.8733 0.8733 0.8733 0.8733 0.8733 0.8733 0.8733 0.8733 0.8733
0.8733 0.8733 0.8733 0.8733 0.8733 0.8733 0.8733 0.8733 0.8733 0.8733]

f ( x ) f(x) f(x) 的值会在 -2.1209 处停止变化:

[16. 3.4714 0.495 -0.6951 -1.2755 -1.5919 -1.7774 -1.8916 -1.9647
-2.0128 -2.0451 -2.0672 -2.0826 -2.0933 -2.101 -2.1064 -2.1103 -2.1132
-2.1152 -2.1167 -2.1178 -2.1186 -2.1192 -2.1196 -2.12 -2.1202 -2.1204
-2.1205 -2.1206 -2.1207 -2.1207 -2.1208 -2.1208 -2.1208 -2.1208 -2.1208
-2.1208 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209
-2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209
-2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209
-2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209
-2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209
-2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209
-2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209 -2.1209
-2.1209]

在这里插入图片描述
结果和上面迭代次数为 20 的情形是一致的。说明在这种情形下,增加迭代次数是无效的。

3.2 修改学习率 η \eta η

将这个局部最小值想象成一个“小水坑”,那么我们能不能不要踩到“水坑”里,一下子“跳”过去呢?

答案是可以的。学习率 η \eta η 就可以实现这个功能。

同样的 x 0 = 2.0 x^0=2.0 x0=2.0,学习率 η \eta η 设置为 0.1。

eta = 0.1
x = 2.0

x x x 的变化:

[ 2. -2.2 -1.9648 -2.2259 -1.9227 -2.2513 -1.879 -2.2711 -1.8428
-2.2828 -1.8207 -2.2879 -1.811 -2.2896 -1.8076 -2.2901 -1.8066 -2.2903
-1.8063 -2.2903]

f ( x ) f(x) f(x) 的变化:

[16. -7.9904 -7.9187 -7.9206 -7.7945 -7.8352 -7.6367 -7.7556 -7.4858
-7.7035 -7.3855 -7.6799 -7.3396 -7.6716 -7.3235 -7.6691 -7.3186 -7.6683
-7.3172 -7.6681]

图像求解过程:
在这里插入图片描述

可以看到当 η \eta η 变大后, x x x 直接就跳到了 f ( x ) f(x) f(x) 的图像左侧,进入了全局最小值的范围内。

3.3 修改初始值 x 0 x^0 x0

设置合适的初始值,就可以得到正确的全局最小值。

比如观察图像,可以看到全局最小值的点, x x x 落在区间 ( − 3 , − 2 ] (-3, -2] (3,2] 之间。

所以我们设置 x 0 = − 3 x^0 = -3 x0=3,学习率 η = 0.01 \eta=0.01 η=0.01

x = -3.0
eta = 0.01

x x x

[-3. -2.62 -2.4497 -2.3487 -2.2824 -2.2363 -2.2032 -2.1788 -2.1607
-2.1469 -2.1365 -2.1285 -2.1223 -2.1175 -2.1138 -2.1109 -2.1087 -2.1069
-2.1056 -2.1045]

f ( x ) f(x) f(x)

[ 6. -4.2027 -6.493 -7.3339 -7.7058 -7.8879 -7.9828 -8.0345 -8.0635
-8.0801 -8.0897 -8.0954 -8.0988 -8.1008 -8.102 -8.1028 -8.1032 -8.1035
-8.1036 -8.1037]

在这里插入图片描述

那假如同样的初始值,但是将学习率变化呢?

  • 合适的初始值,搭配过大的学习率

设置 x 0 = − 3.0 x^0=-3.0 x0=3.0 η = 0.1 \eta=0.1 η=0.1

x = -3.0
eta = 0.1

在这里插入图片描述
从图像可以看出,这种情形下,仍然无法得到正确的全局最小值。

  • 合适的初始值,搭配过小的学习率

设置 x 0 = − 3.0 x^0=-3.0 x0=3.0 η = 0.001 \eta=0.001 η=0.001

x = -3.0
eta = 0.001

在这里插入图片描述
从图像可以看出,这种情形下, f ( x ) f(x) f(x) 向着全局最小值收敛,但在固定迭代次数内,还没有达到。这需要增加迭代次数,消耗更多的资源。

4. 总结

初始值 x 0 x^0 x0、学习率 η \eta η、迭代次数,共同影响梯度下降法的求解过程。

其中, x 0 x^0 x0 η \eta η 决定了方向,起了关键性作用。

参考

-《用Python编程和实践!数学教科书》


http://www.kler.cn/a/427451.html

相关文章:

  • Android H5调起微信支付宝支付
  • ip所属地址是什么意思?怎么改ip地址归属地
  • CmakeLists学习刨根问底
  • 洛谷P4913 【深基16.例3】二叉树深度(c嘎嘎)
  • 普及组集训图论--判断负环
  • HarmonyOS 5.0应用开发——UIAbility跳转
  • Java --- 注解(Annotation)
  • 【SpringBoot】StopWatch工具类的使用
  • 【MySQL】视图详解
  • labview使用modbus library实现modbus通信
  • flask+pyecharts实现可登录可视化大屏
  • RT Thread Studio新建STM32F407IG工程文件编译提示错误
  • MYSQL PARTITIONING分区操作和性能测试
  • 志愿服务管理系统设计与实现
  • 网络安全基本原则
  • 原型模式(Prototype Pattern)——对象克隆、深克隆与浅克隆及适用场景
  • 排序算法入门:分类与基本概念详解
  • 单链表---回文结构
  • 静态路由与交换机配置实验
  • springboot的 nacos 配置获取不到导致启动失败及日志不输出问题