当前位置: 首页 > article >正文

人工智能-优化算法之梯度下降

梯度下降

尽管梯度下降(gradient descent)很少直接用于深度学习, 但了解它是理解下一节随机梯度下降算法的关键。 例如,由于学习率过大,优化问题可能会发散,这种现象早已在梯度下降中出现。 同样地,预处理(preconditioning)是梯度下降中的一种常用技术, 还被沿用到更高级的算法中。 让我们从简单的一维梯度下降开始。

下面我们来展示如何实现梯度下降。为了简单起见,我们选用目标函数f(x)=x^{2}。 尽管我们知道x=0f(x)能取得最小值, 但我们仍然使用这个简单的函数来观察(x)的变化。

%matplotlib inline
from mxnet import np, npx
from d2l import mxnet as d2l

npx.set_np()

def f(x):  # 目标函数
    return x ** 2

def f_grad(x):  # 目标函数的梯度(导数)
    return 2 * x
%matplotlib inline
import numpy as np
import torch
from d2l import torch as d2l

def f(x):  # 目标函数
    return x ** 2

def f_grad(x):  # 目标函数的梯度(导数)
    return 2 * x
%matplotlib inline
import numpy as np
import tensorflow as tf
from d2l import tensorflow as d2l

def f(x):  # 目标函数
    return x ** 2

def f_grad(x):  # 目标函数的梯度(导数)
    return 2 * x
%matplotlib inline
import warnings
from d2l import paddle as d2l

warnings.filterwarnings("ignore")
import numpy as np
import paddle

def f(x):  # 目标函数
    return x ** 2

def f_grad(x):  # 目标函数的梯度(导数)
    return 2 * x

接下来,我们使用x=10作为初始值,并假设eta=0.2。 使用梯度下降法迭代\(x\)共10次,我们可以看到,(x)的值最终将接近最优解。

def gd(eta, f_grad):
    x = 10.0
    results = [x]
    for i in range(10):
        x -= eta * f_grad(x)
        results.append(float(x))
    print(f'epoch 10, x: {x:f}')
    return results

results = gd(0.2, f_grad)

 epoch 10, x: 0.060466

def gd(eta, f_grad):
    x = 10.0
    results = [x]
    for i in range(10):
        x -= eta * f_grad(x)
        results.append(float(x))
    print(f'epoch 10, x: {x:f}')
    return results

results = gd(0.2, f_grad)

 epoch 10, x: 0.060466

def gd(eta, f_grad):
    x = 10.0
    results = [x]
    for i in range(10):
        x -= eta * f_grad(x)
        results.append(float(x))
    print(f'epoch 10, x: {x:f}')
    return results

results = gd(0.2, f_grad)

 epoch 10, x: 0.060466

def gd(eta, f_grad):
    x = 10.0
    results = [x]
    for i in range(10):
        x -= eta * f_grad(x)
        results.append(float(x))
    print(f'epoch 10, x: {float(x):f}')
    return results

results = gd(0.2, f_grad)

 epoch 10, x: 0.060466

对进行x优化的过程可以绘制如下:

def show_trace(results, f):
    n = max(abs(min(results)), abs(max(results)))
    f_line = np.arange(-n, n, 0.01)
    d2l.set_figsize()
    d2l.plot([f_line, results], [[f(x) for x in f_line], [
        f(x) for x in results]], 'x', 'f(x)', fmts=['-', '-o'])

show_trace(results, f)

 [07:12:32] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU

def show_trace(results, f):
    n = max(abs(min(results)), abs(max(results)))
    f_line = torch.arange(-n, n, 0.01)
    d2l.set_figsize()
    d2l.plot([f_line, results], [[f(x) for x in f_line], [
        f(x) for x in results]], 'x', 'f(x)', fmts=['-', '-o'])

show_trace(results, f)

 

def show_trace(results, f):
    n = max(abs(min(results)), abs(max(results)))
    f_line = tf.range(-n, n, 0.01)
    d2l.set_figsize()
    d2l.plot([f_line, results], [[f(x) for x in f_line], [
        f(x) for x in results]], 'x', 'f(x)', fmts=['-', '-o'])

show_trace(results, f)

 

def show_trace(results, f):
    n = max(abs(min(results)), abs(max(results)))
    f_line = paddle.arange(-n, n, 0.01, dtype='float32')
    d2l.set_figsize()
    d2l.plot([f_line, results], [[f(x) for x in f_line], [
        f(x) for x in results]], 'x', 'f(x)', fmts=['-', '-o'])

show_trace(results, f)

 


http://www.kler.cn/news/149548.html

相关文章:

  • RLHF:强化学习结合大预言模型的训练方式
  • 在Mysql中,什么是回表,什么是覆盖索引,索引下推?
  • Qt 软件调试(一) Log日志调试
  • MapReduce概念
  • 简化文件上传流程:学习如何封装Vue2拖拽上传组件
  • 4.ORB-SLAM3中如何实现稠密建图(二):稠密建图如何控制三大线程与稠密建图代码解析
  • 额,收到阿里云给的赔偿了!
  • OpenCV | 傅里叶变换——低通滤波器与高通滤波器
  • 西南科技大学C++程序设计实验二(类与对象一)
  • 做到这一点,运维可高枕无忧
  • 读天下杂志读天下杂志社读天下编辑部简介
  • 王者荣耀游戏制作
  • 从零构建属于自己的GPT系列2:预训练中文模型加载、中文语言模型训练、逐行代码解读
  • During handling of the above exception, another exception occurred解决方案
  • vue项目实现生成一个简单二维码
  • 前端面试灵魂提问
  • 浅析智慧社区建设趋势及AI大数据监管平台方案设计
  • wsj0数据集原始文件.wv1.wv2转换成wav文件
  • Kanna库编写数据抓取代码示例
  • C# 线程(1)
  • 分布式运用之ELK企业级日志分析系统
  • 【C 语言经典100例】C 练习实例14 - 将一个正整数分解质因数
  • java基础之集合概览
  • 模拟退火算法应用——求解TSP问题
  • Django回顾2
  • 流畅的Python (节选)
  • VC++、MFC中操作excel时,Rang和Rangs的区别是什么?
  • 贪心算法(新坑)
  • 数据收集和准备:打造高质量的数据集
  • 【SpringBoot篇】登录校验 — JWT令牌