当前位置: 首页 > article >正文

AI学习记录 - 模型训练中怎么反向传播以及学习率的影响

画图不易,有用点赞

解释反向传播中具体的运行

假设我们有个简易的神经元连接,如下

在这里插入图片描述

目的是求解 绿色w = 5 权重对结果 60的影响程度,因为我们得知道 绿色w = 5 对结果的影响程度和影响方向,才好将 绿色w = 5 是增大还是缩小,增大多少或者减少多少。
我们知道神经元其实y = w * x + b,多层之后,如图所示

在这里插入图片描述

根据模型直接得出的结果是 60,但是假设真实值是 90,那么60太小了,我们需要增大 绿色w=5 的权重,让 60慢慢接近90,怎么做呢?

我们知道了求影响程度,也是一级一级的求,求的是当前公式对下一级的影响程度也就是导数,最后相乘叠加起来,上图知道 w = 5 对 60 的影响程度是 500,当 绿色w = 5 增加 1 , 60 变成 560, 在实际训练过程中,500太大了,也就是直接求出来的导数太大,我们会先使用学习率进行缩小计算:

500 * 0.01(学习率) = 5 (变化程度)
5(原权重) + 5 (变化程度) = 10 (新权重)

我们将 10 覆盖绿色框的 5 ,这不就更新权重了吗,更新权重之后,又将 红色框 2 输入,得出 1060, 哇。。。。。。,学习率太大了,可以变成0.00001试试。

在这里插入图片描述

在实际的场景中,计算公式种类铁定不会这么单一,会融合了超级多的复杂的数学公式,这就涉及到复合多元函数求导数了,我们要依据现行的导数公式进行拆分,拆分成n个小公式,分别对n个小公式求导,然后将所有的导数相乘,就可以得出任意一个权重和偏置对结果的影响程度,然后调整它。

分割线——————————

重磅结论:大模型本质是个巨大的复合的多元的数学公式,数学公式是由无数个简单的《单元公式》组成,单元公式通过树形结构组成巨大公式,最顶部的节点就是所有小公式的汇合,通过将某个叶子结点的常量变成变量,从而求出这个常量所在的位置对于结果的影响,也就是所在位置每增加1或者减少1,对对顶部节点的影响程度。

怎么求解其中一个变量对结果的影响程度?

再庞大的数学公式,也是由一个个单元公式组成,单元公式由《计算方式》和《单个或多个值》两部分组成,计算方式是固定的,值有个特性就是可以变,把值当成是一个可以变的东西,就是变量例如1+1=y 变成x+1=y,把它变成可变的x,那y就可以对x求导,求导就是求x对y的影响程度。
假设对其中一个单元公式求导,每一个单元公式都有其对应的导数公式,x输入进单元公式,输出y,那么将x输入进单元公式对应的导数公式,就得到了一个导数,也就是得到了当前公式对下一级的影响程度。一般一个公式会有很多的变量但是不需要拆分的情况,我们就直接把当前公式的上一级的输出当成是当前的公式的常量值即可。
        
        sum_h1 = self.w1 * x[0] + self.w2 * x[1] + self.b1
        h1 = sigmoid(sum_h1)

        sum_h2 = self.w3 * x[0] + self.w4 * x[1] + self.b2
        h2 = sigmoid(sum_h2)

        sum_o1 = self.w5 * h1 + self.w6 * h2 + self.b3
        o1 = sigmoid(sum_o1)
        y_pred = o1

        # --- Calculate partial derivatives.
        # --- Naming: d_L_d_w1 represents "partial L / partial w1"
        d_L_d_ypred = -2 * (y_true - y_pred)

         # Neuron o1
        d_ypred_d_w5 = h1 * deriv_sigmoid(sum_o1)
        d_ypred_d_w6 = h2 * deriv_sigmoid(sum_o1)
        d_ypred_d_b3 = deriv_sigmoid(sum_o1)

        d_ypred_d_h1 = self.w5 * deriv_sigmoid(sum_o1)
        d_ypred_d_h2 = self.w6 * deriv_sigmoid(sum_o1)

        # Neuron h1
        d_h1_d_w1 = x[0] * deriv_sigmoid(sum_h1)
        d_h1_d_w2 = x[1] * deriv_sigmoid(sum_h1)
        d_h1_d_b1 = deriv_sigmoid(sum_h1)
``












# 画图不易,有用点赞











http://www.kler.cn/a/281710.html

相关文章:

  • Springboot开发常见问题及对应的解决方案
  • 刷题强训(day09)【C++】添加逗号、跳台阶、扑克牌顺子
  • OSRM docker环境启动
  • 领海基点的重要性-以黄岩岛(民主礁)的领海及专属经济区时空构建为例
  • Android CPU核分配关联进程
  • AI工具百宝箱|任意选择与Chatgpt、gemini、Claude等主流模型聊天的Anychat,等你来体验!
  • CSS 的font-synthesis属性与中文体验增强
  • 手机号码归属地查询如何用PHP进行调用
  • zoom 会议 javascript 转录例子
  • 第四十篇-TeslaP40+Ollama+Ollama-WebUI(自编译)
  • Python-MNE-源定位和逆问题01:源估计(SourceEstimate)数据结构
  • Nginx 部署前端 Vue 项目全攻略
  • Spring WebFlux – CVE-2023-34034 – 撰写和概念验证
  • Jmeter下载、配置环境变量
  • 【vue3】wangEditor 5在vue3中的使用
  • 【KDD2024】大数据基础工程技术集群异常检测论文入选
  • 【netty系列-08】深入Netty组件底层原理和基本实现
  • stable-diffusion-webui 部署 ,启用 api 服务
  • TPM管理培训究竟需要多少天?完整攻略在此
  • 光伏设计中组串逆变和微型逆变是什么意思?有什么区别?
  • 433 国乒启发式:一切方法的尽头都是本能反应
  • 提升广告效果:Facebook广告投放步骤与实用工具解析
  • GraphRAG论文阅读笔记
  • 构建开发全能型档期预约系统
  • spring整合redis(常用数据类型操作)
  • java 实现文本转音频