当前位置: 首页 > article >正文

【初学人工智能原理】【4】梯度下降和反向传播:能改(下)

前言

本文教程均来自b站【小白也能听懂的人工智能原理】,感兴趣的可自行到b站观看。

本文【原文】章节来自课程的对白,由于缺少图片可能无法理解,故放到了最后,建议直接看代码(代码放到了前面)。

代码实现

任务

  1. 在引入b后绘制代价函数界面,看看到底是不是一个碗
  2. 在w和b两个方向上分别求导,得到这个曲面某点的梯度进行梯度下降,拟合数据

绘制三维的方差代价函数

dataset.py

import numpy as np

def get_beans(counts):
	xs = np.random.rand(counts)
	xs = np.sort(xs)
	ys = np.array([(0.7*x+(0.5-np.random.rand())/5+0.5) for x in xs])
	return xs,ys


cost_function_w.py

import dataset
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
m = 100
xs, ys = dataset.get_beans(m)

# 配置图像
plt.title("Size-Toxicity Function'", fontsize=12)  # 设置图像名称
plt.xlabel("Bean size")  # 设置横坐标的名字
plt.ylabel("Toxicity")  # 设置纵坐标的名字
plt.xlim(0, 1)
plt.ylim(0, 1.5)
plt.scatter(xs, ys)

w = 0.1
b = 0.1
y_pre = w * xs * b

plt.plot(xs, y_pre)
plt.show()

fig=plt.figure()
ax=Axes3D(fig)
ax.set_zlim(0,2) # 限制垂直方向坐标轴取值范围

ws=np.arange(-1,2,0.1)
bs=np.arange(-2,2,0.01)

for b in bs:
    es=[]
    for w in ws:
        y_pre=w*xs+b
        e=np.sum((ys-y_pre)**2)*(1/m)
        es.append(e)
    # plt.plot(ws,es)
    ax.plot(ws,es,b,zdir='y') # 3D绘图

plt.show()

在这里插入图片描述

在这里插入图片描述

关于w和b的梯度下降算法

import dataset
import numpy as np
from matplotlib import pyplot as plt


## Create a dataset
n = 100
xs, ys = dataset.get_beans(n)

# 配置图像
plt.title("Size-Toxicity Function",fontsize=12)
plt.xlabel("Bean Size")
plt.ylabel("Toxicity")
plt.scatter(xs,ys)

w=0.1
b=0.1
y_pre=w*xs+b
plt.plot(xs,y_pre)
plt.show()

# 随机梯度下降法
def gsd(w=0.1,b=0.1):
    # 在全部样本上做50次梯度下降
    for _ in range(50):
        for i in range(100):
            x = xs[i]
            y = ys[i]
            # a=x^2
            # b=-2*x*y
            # c=y^2
            # 斜率k=2aw+b
            dw =  2*w**2*w+2*x*b-2*x*y  # e对w求导
            db = 2 * b + 2 * x * w - 2 * y # e对b求导
            alpha = 0.1
            w = w - alpha * dw  # w根据梯度下降的方向走,如w此时的k<0,则w处于抛物线左端,应该往右边走,相反则往左边走
            b=b-alpha*db

            # 绘制动态变化的曲线
            plt.clf()  # 清空窗口
            plt.scatter(xs, ys)
            y_pre = w * xs+b

            # 限制x轴和y轴的范围,使之不自动调整,避免图像抖动
            plt.xlim(0, 1)
            plt.ylim(0, 1.2)
            plt.plot(xs, y_pre)
            plt.pause(0.01)  # 暂停0.01s,因为不暂停的话会无法显示



# 随机梯度下降
gsd()

原文

通过前面的学习,我们已然了解到现在神经网络精髓之一的梯度下降算法,但是如果仔细观察我们设计的预测函数,你就会发现这是一个非常危险和不完善的模型。比如在另外一片海域里,豆豆的大小和毒性的关系是这样的,有些太小的豆豆是不存在的,我们发现不论怎样去调整w都无法得到理想的预测函数。当然更加糟糕的情况是豆豆越大,毒性越低,原因很简单,我们的预测函数y等于w乘以x很明显是一个必须经过原点的直线。

换句话说,这个预测函数直线的自由度被限制住了,只能旋转而不能移动。因为大家很清楚,一个直线完整的函数应该是y等于wx加b。​

之前我们为了遵循如无必要物增新知的理念,一直在刻意的避免这个截距参数b。直到现在我们终于避无可避是时候增加新的知识。截距参数b的作用大家很清楚,可以让直线在平面内自由的平移,而斜率w可以让直线自由的旋转,当我们把直线的平移的自由度还给它之后,这两者的结合才能让直线在整个平面内这真正的自由起来。

我们来看一下加入截距参数b后发生的改变。首先我们带入b重新推演一次预测和梯度下降的过程,当然为了简单起见,我们还是先看单个豆豆样本的情况,这是预测函数。豆豆的大小是x0,毒性是y0,那么预测就是w乘以x0+b那么根据方差代价函数得到方差代价是这样的,你会发现没有b的时候或者说b=0的时候,代价函数就是我们前几节课中的样子,这其实是b=0的一种特殊情况。

那现在既然有了b接下来我们就要看这个b取不同值的时候会对代价函数造成什么样子的影响?

这里我们需要把代价函数的图像从二维变成三维,给b留出一个维度,b=0就是我们之前讨论的抛物线,b=0.1,e和w的关系还是一个标准的开口向上的抛物线,因为b的改变只影响这个抛物线的系数,换句话说改变的只是抛物线的具体样子,而不会让它变成其他形状。同样的道理,b=0.2也是,b=0.3也是等等等,我们好像已经看出一些眉目了,这好像是一个曲面。没错,这里我们的b取值间隔是0.1,描绘出来的效果似乎还是不太明显。当我们把b的取值间隔弄小一点,再小一点,直到无限的小下去,这时候你就会发现果真是一个三维空间中的曲面,那我们该如何去看待这个曲面?
在这里插入图片描述
在有些教程和书籍中,很多时候为了让他看着明显把它画成了一个鼓鼓的碗状。其实对于线性回归问题中,这种豆豆数据形成的代价函数实际上并没有那么的鼓,而是一个扁扁的碗,扁的几乎看不出来是个碗,但是当我们把这个曲面的等高线画出来就可以看出来,这确实也是一个碗。很明显这个碗状曲面的最低点肯定是问题的关键所在。我们回忆一下,在没有b出现的时候,曲线的最低点代表着w取值造成的预测误差最小,那这个曲面最低点意味着什么?

首先我们想一想这个最低点是怎么形成的呢?​

没错,我们每次取不同的w和b都会导致误差,e不相同,这个局面也就是我们带入b后得到的代价函数的图像,而它的最低点也就意味着这里的w和b的取值会让预测误差最小。而如果我们能得到这个最低点的w和b的值放回到预测函数中,那么此时此刻恰如彼时必克预测也就是最好的。现在我们的目标就很明确了,如何在这个曲面上取得最低点处的w和b的值?在没有b出现的美好时刻,也就是说在b=0处,我们沿着w的方向切上一刀,我们知道这将形成一个关于e和w的开口向上的抛物线,然后不断的通过梯度下降算法去调整w最后到达最低点。

但是你会发现此刻曲线的最低点却并不是这个曲面的最低点,换句话说b=0的取值并不是最好的,那么关于b套路其实还是一样的,我们在这一点上,如果沿着b的方向给曲线来上,一刀会怎样呢?你会发现切口形成的曲线似乎也是一个开口向上的抛物线,如果是这样的话那就很nice了,我们在这个抛物线上也向最低点挪动即可,但果真如此吗?我们之前已经分析出来e和w的关系是一个抛物线,现在我们不妨再看一下e和b的关系,这是方差代价函数,要研究b那我们围绕b重新整理一下这个式子,是这样的。​

当w确定的时候,也就是我们沿着b的方向切下一刀,比如当前这个点的w值为w cut,这时候代价函数是这样的,你看当w取固定值的时候,也就是把w看作一个确定值的时候,e和b的关系又是一个标准的开口向上的一元二次函数。所以面对现在这个误差代价函数曲面,我们还可以换个角度去理解它的形成方式。除了可以像一开始那样认为是e关于w的一元二次函数曲线在b取不同值的时候形成的以外,也可以认为是e关于b的一元二次函数曲线,在每次w取不同值的时候形成的。

现在我们在b上要做的事情和在w上一模一样,不断的去调整b仍然向这个曲线的最低点挪动,而具体的方法也是一样的,根据斜率进行下降。​​

我们完整的来看一下这个过程,假设一开始我们的w等于0.1,b也等于0.1,对应的e是这么多,在曲面的这个位置,我们画一个球来显示这个点,正所谓横看成岭侧成峰,我们横看此处看见的是b确定的时候e和w形成的一个曲线,根据此处的斜率调整w大小是斜率乘以学习率阿尔法,方向是根据斜率的正负确定的。我们侧看此处看见的是w确定的时候,e和b形成的一个曲线,根据此处的斜率调整,b大小是斜率乘以学习率阿尔法方向根据斜率振幅确定,把这两个方向上的调整运动合成一个合成的调整运动,这样我们就完成了一次调整,到达下一个点之后,我们继续横看调整w侧看调整b当我们反复进行这个过程的时候,也就逐渐的向这个曲面的最低点挪动了。​

所以说这里同时有w和b的代价函数曲面和只有w的代价函数曲线相比,这个下降的过程本质上是一样的,换汤不换药,或者说只是从w一味药换成了w和b两位味药。​

但是有一点,我们的代价函数已经是一个曲面了,那这个下降的过程,如果我们再说是斜率下降就有点不太合适了,毕竟一个曲面上的某点的斜率是个什么东西,是关于w的还是关于b的呢?要回答这个问题,需要发散一下思维,换一个角度来看这个下降的过程。我们在代价函数的w和b两个方向上分别求得斜率或者说倒数。​​

对于有两个自变量的代价函数,我们先偏向w求导数,再偏向b求导数,为了区分只有一个自变量的情况,我们把在某一个变量上的导数也称之为偏导数。如果我们把对w和对b的偏导数看作向量,把这两个向量合在一起,形成一个新的和向量,沿着这个和向量进行了下降,是这个曲面在该点下降最快的方式,这个和向量在数学里称之为梯度,到此为止。你也就理解了为什么我们说梯度是比斜率更加广泛的一个概念,它是把各个方向上的偏导数当做向量,合起来形成一个总向量,代表了这个点下降最快的方向,当然在二维曲线中因为没有其他方向,梯度和斜率也可以认为是一回事,而为了让这个下降算法的名字更具有广泛性,所以我们一般称之为梯度下降,而不是斜率下降。​

子曰学而时习之,我们已经完整的讲述了梯度下降的过程,那么现在就来回顾总结一下目前为止我们所学到的东西。我们从环境中观察到了一个问题,豆豆的毒性和它的大小有关系,那现在想要准确的去预测这个关系到底是什么。按照McCulloch-Pitts神经元模型,我们使用一个一元一次线性函数去模拟神经元的树突和轴突的行为,这就是预测函数模型,而把我们统计观测而来的数据送入预测函数进行预测的过程就称之为前向传播。因为计算从前往后数据通过预测函数完成一次前向传播,就会得到一个预测值,预测值和统计观测而来的真实值之间存在着误差。

我们选择平方误差作为评估的手段,你会发现这个误差和预测函数中的参数又会形成一种函数关系,我们把这个函数称之为代价函数,因为采用方差去评估预测误差,所以也称之为方差代价函数,描述了预测函数的参数取不同值的时候预测的不同的误差代价,而用这个代价函数去修正预测函数参数的过程也称之为反向传播
在这里插入图片描述
因为计算从后往前,而这个反向传播参数修正的方法,我们使用梯度下降算法噢对了,我的老伙计。在没有截距b的二维代价函数中叫他斜率下降也未尝不可,而在调整的过程中用来调和下降幅度的l法称之为学习率,他的选择影响了调整的速度太大了容易反复横跳,过大的时候甚至不会收敛,而是发散太小了又容易磨磨唧唧,他是设计者根据经验选择出来的,而不断的经历前向传播和反向传播,最后到达代价函数最低点的过程,我们称之为训练或者学习。

这就是所谓的机器学习中的神经网络,但把一个神经元称之为网络似乎不太恰当,因为没有哪一个网络只有一个节点,但以后我们不断的添加神经元,并把它们连接起来,共同工作的时候,也就能称之为神经网络。

而我们所说的前向传播和反向传播,其实也是在多层神经网络出现后才引入了概念,对于单个神经元如此称呼似乎有点别扭,但这些概念在单个神经元上已经初具雏形,面对网络那只是不断的重复而已,我们会在后面学习多层神经网络时候详细说明相传播更一般的行为,这就是目前人工智能机器学习领域独领风骚的连接主义在干的事情。至于为什么这样多个神经元组合成神经网络后就能达到智能的效果,别着急,我们会在接下来的课程中慢慢到来。
在这里插入图片描述


http://www.kler.cn/a/17505.html

相关文章:

  • 【mysql】使用宝塔面板在云服务器上安装MySQL数据库并实现远程连接
  • 2019年下半年试题二:论软件系统架构评估及其应用
  • 微信小程序=》基础=》常见问题=》性能总结
  • 黄色校正电容102j100
  • 【p2p、分布式,区块链笔记 DAM】GUN/SEA(Security, Encryption, Authorization) 模块genkey
  • mysql5.7安装SSL报错解决(2),总结
  • 算法设计与分析期末复习
  • 判断密码判断密码
  • 删除游戏-类似打家劫舍
  • Canvas和SVG有什么区别?
  • java基础知识——26.反射
  • 架构集群部署
  • 深度学习 -- PyTorch学习 torchvision工具学习 Transforms模块 Normalize用法
  • Db2 hardcode一个CTE
  • 科研人必看入门攻略(收藏版)
  • B017_群函数篇
  • ( 数组和矩阵) 287. 寻找重复数 ——【Leetcode每日一题】
  • Python JSON
  • 网络安全合规-数据安全风险评估
  • 【数据结构】图笔记
  • 【泛函分析】区间上的单调有界函数必存在左右极限,间断点必为第一类间断点
  • 抖音营销策略:新手如何利用抖音提高品牌曝光度
  • 多媒体API
  • Mysql 设置 sort_buffer_size
  • Lenovo MORFFHL鼠标对码教程
  • 【软考备战·希赛网每日一练】2023年5月2日