当前位置: 首页 > article >正文

大语言模型---通过数值梯度的方式计算损失值L对模型权重矩阵W的梯度;数值梯度的公式;数值梯度计算过程

文章目录

  • 概要
  • 1. 数值梯度的公式
  • 2. 数值梯度计算过程
  • 3. 数值梯度的特点

概要

前文已经简单介绍梯度,本文主要介绍大语言模型中使用数值梯度的方法实现 损失值 L L L 对模型权重矩阵的梯度计算,而不是传统的链式法则进行梯度计算。如果想要理解整体计算方式,先明白损失值 L L L的计算方式,通过公式了解其和权重矩阵 W V W_V WV的关系。然后再理解损失值 L L L对权重矩阵 W V W_V WV的梯度计算。

1. 数值梯度的公式

数值梯度通过有限差分法近似计算梯度,对权重矩阵 W V W_V WV 中每个元素的梯度 ∂ L ∂ W V i j \frac{\partial L}{\partial W_{V_{ij}}} WVijL
∇ L W V i j = L p l u s − L c u r r e n t h \nabla L_{W_{V_{ij}}} = \frac{L_{plus}-L_{current}}{h} LWVij=hLplusLcurrent

其中,每个参数的含义在下文中有讲解。

2. 数值梯度计算过程

(1) 初始化

  • 给定权重矩阵 W V ∈ F m × n W_V \in \mathbb{F}^{m \times n} WVFm×n,与 W V W_V WV大小相同的梯度矩阵 ∇ L W V = zeros ( m , n ) \nabla L_{W_V} = \text{zeros}(m, n) LWV=zeros(m,n)
  • 确定增量 h h h 的值(如 h = 1 0 − 5 h=10^{−5} h=105)。

(2) 遍历权重矩阵的每个元素
对于 W V W_V WV中的每个元素 W V i j W_{V_{ij}} WVij

  1. 创建一个单位矩阵 E i j E_{ij} Eij,大小与 W V W_V WV相同,且 E i j = 1 E_{ij}=1 Eij=1
  2. 计算损失值:
  • L p l u s = L ( W v + h ∗ E i j ) L_{plus}=L(W_v+h*E_{ij}) Lplus=L(Wv+hEij)
    • W V W_V WV的第 ( i , j ) (i,j) (i,j) 元素增加一个微小值 h h h,得到新的权重矩阵,然后计算损失值 L p l u s L_{plus} Lplus.
  • L c u r r e n t = L ( W v ) L_{current}=L(W_v) Lcurrent=L(Wv):
    • 使用当前的权重矩阵 W V W_V WV计算损失值 L c u r r e n t L_{current} Lcurrent

(3) 梯度估算
通过有限差分公式,计算第 ( i , j ) (i,j) (i,j)元素的梯度:
∇ L W V i j = L p l u s − L c u r r e n t h \nabla L_{W_{V_{ij}}} = \frac{L_{plus}-L_{current}}{h} LWVij=hLplusLcurrent
这个公式的含义是:通过观察 W V i j W_{V_{ij}} WVij 增加 h h h 后损失函数的变化,我们可以估算出损失函数对该参数的敏感程度(梯度)。

3. 数值梯度的特点

优点:

  • 简单直观:无需解析推导梯度公式,直接利用损失函数计算。
  • 适合验证解析梯度:可以作为解析梯度的参考标准,用于检测实现是否正确。

缺点:

  1. 计算效率低
  • 对于权重矩阵 W V ∈ F m × n W_V \in \mathbb{F}^{m \times n} WVFm×n,需要计算 m × n m×n m×n 次损失。
  • 如果网络规模较大,数值梯度的计算会非常耗时。
  1. 数值误差:
  • 梯度近似的精度取决于 h h h 的选择。
  • h h h 太大会导致误差较大, h h h 太小可能引入浮点数精度问题。

http://www.kler.cn/a/410411.html

相关文章:

  • docker run m3e 配置网络,自动重启,GPU等 配置渠道要点
  • 软件/游戏提示:mfc42u.dll没有被指定在windows上运行如何解决?多种有效解决方法汇总分享
  • c++视频图像处理
  • redis的map底层数据结构 分别什么时候使用哈希表(Hash Table)和压缩列表(ZipList)
  • java 老矣,尚能饭否?
  • 【数据结构】—— 线索二叉树
  • macOS上进行Ant Design Pro实战教程(一)
  • 【51单片机】程序实验56.独立按键-矩阵按键
  • 【初阶数据与算法】线性表之顺序表的定义与实现
  • 跨平台开发_RTC程序设计:实时音视频权威指南 2
  • Web day02 Js Vue Ajax
  • Java的字符串操作(二)(代码示例)
  • spring的事务隔离?
  • IEC61850读服务器目录命令——GetServerDirectory介绍
  • Gitlab有趣而实用的功能
  • Ajax学习笔记,第一节:语法基础
  • 电影风格城市夜景旅拍Lr调色教程,手机滤镜PS+Lightroom预设下载!
  • 杂项驱动开发
  • 【JavaEE】Servlet:表白墙
  • CSS 样式入门:属性全知晓
  • Leetcode 组合
  • STM32WB55RG开发(5)----监测STM32WB连接状态
  • C#里怎么样访问文件时间
  • 《Shader入门精要》透明效果
  • Qt笔记-获取HTTP的POST请求提交的数据时需要注意的地方(2024-09-02)
  • 加菲工具 - 好用免费的在线工具集合