当前位置: 首页 > article >正文

【课堂笔记】线性回归梯度下降的矩阵求导推导

参考文章
参考文章
参考文章

线性回归

给定数据集 D = {   X , y   } \mathcal{D}=\set{X, y} D={X,y},可学习参数 w ∈ R D w \in \mathbb{R}^D wRD
y = ( y 1 y 2 ⋮ y N ) ∈ R N y = \begin{pmatrix} y_1 \\ y_2 \\ \vdots \\ y_N \end{pmatrix} \in \mathbb{R}^N y= y1y2yN RN X = ( x 11 x 12 ⋯ x 1 D x 21 x 22 ⋯ x 2 D ⋮ ⋮ ⋱ ⋮ x N 1 x N 2 ⋯ x N D ) ∈ R N × D X = \begin{pmatrix} x_{11} & x_{12} & \cdots & x_{1D}\\ x_{21} & x_{22} & \cdots & x_{2D}\\ \vdots & \vdots & \ddots & \vdots\\ x_{N1} & x_{N2} & \cdots & x_{ND} \end{pmatrix} \in \mathbb{R}^{N \times D} X= x11x21xN1x12x22xN2x1Dx2DxND RN×D
定义损失向量 e = y − X w = ( e 1 e 2 ⋮ e N ) ∈ R N e = y - Xw = \begin{pmatrix} e_1 \\ e_2 \\ \vdots \\ e_N \end{pmatrix} \in \mathbb{R}^N e=yXw= e1e2eN RN,其中 e i = y i − x i T w e_i = y_i - x_i^Tw ei=yixiTw
M S E MSE MSE L ( w ) = 1 2 N ∑ N n = 1 ( y n − x n T w ) 2 = 1 2 N e T e \mathcal{L}(w) = \frac{1}{2N}\underset{n=1}{\overset{N}{\sum}}(y_n - x_n^Tw)^2=\frac{1}{2N}e^Te L(w)=2N1n=1N(ynxnTw)2=2N1eTe

然后计算 ∂ L ( w ) ∂ w \frac{\partial \mathcal{L}(w)}{\partial w} wL(w)

L ( w ) = 1 2 N e T e = 1 2 N ( y − X w ) T ( y − X w ) = 1 2 N ( y T − w T X T ) ( y − X w ) \mathcal{L}(w) = \frac{1}{2N}e^Te = \frac{1}{2N}(y - Xw)^T(y - Xw)=\frac{1}{2N}(y^T-w^TX^T)(y - Xw) L(w)=2N1eTe=2N1(yXw)T(yXw)=2N1(yTwTXT)(yXw)
= 1 2 N ( y T y − y T X w − w T X T y + w T X T X w ) =\frac{1}{2N}(y^Ty-y^TXw-w^TX^Ty+w^TX^TXw) =2N1(yTyyTXwwTXTy+wTXTXw)

∂ y T X w ∂ w = X T y \frac{\partial y^TXw}{\partial w} = X^Ty wyTXw=XTy

∂ w T X T y ∂ w = ∂ y T X w ∂ w = X T y \frac{\partial w^TX^Ty}{\partial w}=\frac{\partial y^TXw}{\partial w}=X^Ty wwTXTy=wyTXw=XTy,这里是因为标量转置等于自己。

∂ w T X T X w ∂ w = 2 X T X w \frac{\partial w^TX^TXw}{\partial w}=2X^TXw wwTXTXw=2XTXw

因此 ∇ L ( w ) = ∂ L ( w ) ∂ w = − 1 2 N ( 2 X T y − 2 X T X w ) = − 1 N X T e \nabla\mathcal{L}(w) = \frac{\partial \mathcal{L}(w)}{\partial w} = -\frac{1}{2N}(2X^Ty - 2X^TXw) = -\frac{1}{N}X^Te L(w)=wL(w)=2N1(2XTy2XTXw)=N1XTe


http://www.kler.cn/a/563439.html

相关文章:

  • redis-bitmap使用场景
  • Nacos + Dubbo3 实现微服务的Rpc调用
  • Vue组件间通信的方式
  • 毕业项目推荐:基于yolov8/yolo11的苹果叶片病害检测识别系统(python+卷积神经网络)
  • angular舒尔特方格
  • 医院HIS接入大模型:算力基础设施与训练能力的深度剖析与测算
  • Docker基础-常见命令
  • 面试之《react中,fiber更新时,怎么判断fiber是否应该在当前idle内执行》
  • osgEarth安装总结
  • 将夸克网盘的webdav挂载成本地磁盘驱动器时报错“405“
  • 学习汇编前置知识第二天
  • nginx代理后502
  • seacmsv9 SQL注入漏洞
  • DFS:二叉树的深搜与回溯
  • 冒泡排序(Bubble Sort)详细教程:Java实现与优化
  • Python使用列表实现栈、队列学习记录
  • 【C++修炼之路】C++类与对象:面向对象编程的第一步
  • 字节火山引擎-大模型声音复刻,流式语音合成接口
  • Linux网络之传输层协议(UDP,TCP协议)
  • OpenGL ES -> GLSurfaceView绘制点、线、三角形、正方形、圆(顶点法绘制)