当前位置: 首页 > article >正文

反向传播算法

反向传播算法的数学解释

反向传播算法是深度学习中用于训练神经网络的核心算法。它通过计算损失函数相对于网络权重的梯度来更新权重,从而最小化损失。

反向传播的基本原理

反向传播算法基于链式法则,它按层反向传递误差,从输出层开始,逐层向后至输入层。

1. 损失函数

  • 假设损失函数为 L L L,用于衡量预测输出 y ^ \hat{y} y^ 和实际标签 y y y 之间的差异。

2. 链式法则

  • 链式法则用于计算损失函数相对于网络中每个权重的梯度。对于每个权重 W W W

    ∂ L ∂ W = ∂ L ∂ y ^ × ∂ y ^ ∂ W \frac{\partial L}{\partial W} = \frac{\partial L}{\partial \hat{y}} \times \frac{\partial \hat{y}}{\partial W} WL=y^L×Wy^

3. 梯度传播

  • 在多层网络中,梯度需要通过每一层反向传播。对于层 l l l 的权重 W l W_l Wl

    ∂ L ∂ W l = ∂ L ∂ y ^ × ∂ y ^ ∂ a l × ∂ a l ∂ W l \frac{\partial L}{\partial W_l} = \frac{\partial L}{\partial \hat{y}} \times \frac{\partial \hat{y}}{\partial a_l} \times \frac{\partial a_l}{\partial W_l} WlL=y^L×aly^×Wlal

    其中 a l a_l al 是层 l l l 的激活输出。

4. 权重更新

  • 权重通过梯度下降法更新:

    W new = W old − η × ∂ L ∂ W W_{\text{new}} = W_{\text{old}} - \eta \times \frac{\partial L}{\partial W} Wnew=Woldη×WL

    其中 η \eta η 是学习率。

反向传播的步骤

  1. 前向传播:计算每层的激活输出直至输出层。
  2. 损失计算:计算预测输出与实际标签的损失。
  3. 反向传播:从输出层开始,逐层向后计算损失函数相对于每个权重的梯度。
  4. 更新权重:根据计算得到的梯度更新网络的权重。

反向传播使得深度神经网络能够通过学习数据中的复杂模式来优化其性能,这是现代深度学习应用的基石。

代码

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

# 创建一个简单的神经网络
model = Sequential([
    Dense(10, activation='relu', input_shape=(784,)),
    Dense(10, activation='softmax')
])

# 编译模型,使用交叉熵损失函数和SGD优化器
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])

# 假设有训练数据 X_train, y_train
# X_train = ... # 输入数据
# y_train = ... # 标签数据

# 训练模型
# model.fit(X_train, y_train, epochs=10)

# 在这个过程中,TensorFlow 自动执行前向传播、损失计算、反向传播和权重更新

在这个示例中,我们定义了一个含有两层的简单神经网络,并使用随机梯度下降(SGD)作为优化器。在训练过程中,TensorFlow 会自动处理前向传播、损失计算、反向传播和权重更新的步骤


http://www.kler.cn/a/163170.html

相关文章:

  • mysql 更改 字段长度
  • 生成模型——PixelRNN与PixelCNN
  • 16008.行为树(五)-自定义数据指针在黑板中的传递
  • luckfox-pico-max学习记录
  • LLMs之PDF:zeroX(一款PDF到Markdown 的视觉模型转换工具)的简介、安装和使用方法、案例应用之详细攻略
  • Sql server 备份还原方法
  • 分析阿里巴巴的微服务依赖图和性能
  • 生产上线需要注意的安全漏洞
  • 【优选算法系列】【专题二滑动窗口】第四节.30. 串联所有单词的子串和76. 最小覆盖子串
  • 详解Keras3.0 Models API: Model class
  • Linux gtest单元测试
  • 基于Java医院挂号管理系统
  • sql2005日志文件过大如何清理
  • C/C++,优化算法——双离子推销员问题(Bitonic Travelling Salesman Problem)的计算方法与源代码
  • 二分查找|前缀和|滑动窗口|2302:统计得分小于 K 的子数组数目
  • linux常用命令-pip命令详解(超详细)
  • 判断css文字发生了截断,增加悬浮提示
  • 一. 初识数据结构和算法
  • StoneDB-8.0-V2.2.0 企业版正式发布!性能优化,稳定性提升,持续公测中!
  • 十七、FreeRTOS之FreeRTOS事件标志组
  • 麒麟系统进入救援模式或者是crtl D界面排查方法
  • Linux下通过find找文件---通过修改时间查找(-mtime)
  • 网络工程师【目录】
  • Python 潮流周刊#29:Rust 会比 Python 慢?!
  • 初识人工智能,一文读懂人工智能概论(1)
  • win10 笔记本卡顿优化