当前位置: 首页 > article >正文

《GBDT 算法的原理推导》 11-15更新决策树的叶子节点值 公式解析

本文是将文章《GBDT 算法的原理推导》中的公式单独拿出来做一个详细的解析,便于初学者更好的理解。


公式(11-15)出现在GBDT算法推导的过程中,用于更新决策树的叶子节点值

公式(11-15)如下:

c m j = arg ⁡ min ⁡ c ∑ x i ∈ R m j L ( y i , f m − 1 ( x i ) + c ) c_{mj} = \arg \min_c \sum_{x_i \in R_{mj}} L(y_i, f_{m-1}(x_i) + c) cmj=argcminxiRmjL(yi,fm1(xi)+c)

其中:

  • c m j c_{mj} cmj 表示第 m m m 棵树在叶子节点 R m j R_{mj} Rmj 上的输出值,也就是叶子节点的预测值。
  • R m j R_{mj} Rmj 表示第 m m m 棵树的第 j j j 个叶子节点区域。
  • L ( y i , f ( x i ) ) L(y_i, f(x_i)) L(yi,f(xi)) 是损失函数,衡量样本 x i x_i xi 的真实值 y i y_i yi 和当前模型预测值 f ( x i ) f(x_i) f(xi) 之间的误差。
  • f m − 1 ( x i ) f_{m-1}(x_i) fm1(xi) 是前 m − 1 m-1 m1 轮构建的模型在 x i x_i xi 处的预测值。

1. 公式(11-15)的背景

在GBDT中,每一棵树的任务是对当前模型的误差进行拟合和修正。我们通过新增一棵树 T ( x ; Θ m ) T(x; \Theta_m) T(x;Θm) 来改善当前模型的预测能力。

这棵新树的结构(分裂方式)已经确定,它会将输入样本分配到不同的叶子节点区域 R m j R_{mj} Rmj。接下来,我们需要为每个叶子节点分配一个值,使得这棵树能够最好地拟合该节点区域内的样本误差。

2. 目标是最小化损失

在叶子节点 R m j R_{mj} Rmj 上,我们希望找到一个最优的输出值 c m j c_{mj} cmj,使得它能够最小化该节点区域内所有样本的损失。具体来说,给定前 m − 1 m-1 m1 棵树的预测值 f m − 1 ( x i ) f_{m-1}(x_i) fm1(xi),新树的叶子节点值 c m j c_{mj} cmj 需要满足以下优化目标:

c m j = arg ⁡ min ⁡ c ∑ x i ∈ R m j L ( y i , f m − 1 ( x i ) + c ) c_{mj} = \arg \min_c \sum_{x_i \in R_{mj}} L(y_i, f_{m-1}(x_i) + c) cmj=argcminxiRmjL(yi,fm1(xi)+c)

这意味着,我们为每个叶子节点选择一个最优的常数 c c c,使得该节点区域内所有样本的损失之和最小。

3. 损失函数 L ( y i , f m − 1 ( x i ) + c ) L(y_i, f_{m-1}(x_i) + c) L(yi,fm1(xi)+c)

在GBDT算法中,不同的损失函数 L ( y , f ( x ) ) L(y, f(x)) L(y,f(x)) 会影响叶子节点的最优输出值计算方式。常见的损失函数包括:

  • 平方损失:用于回归任务。
  • 对数损失:用于二分类任务。

不同的损失函数会导致不同的叶子节点值计算方式。下面以平方损失为例来说明如何求解。

例子:平方损失

假设损失函数是平方损失:

L ( y i , f ( x i ) ) = 1 2 ( y i − f ( x i ) ) 2 L(y_i, f(x_i)) = \frac{1}{2} (y_i - f(x_i))^2 L(yi,f(xi))=21(yif(xi))2

代入公式(11-15)中的 L ( y i , f m − 1 ( x i ) + c ) L(y_i, f_{m-1}(x_i) + c) L(yi,fm1(xi)+c)

c m j = arg ⁡ min ⁡ c ∑ x i ∈ R m j 1 2 ( y i − ( f m − 1 ( x i ) + c ) ) 2 c_{mj} = \arg \min_c \sum_{x_i \in R_{mj}} \frac{1}{2} (y_i - (f_{m-1}(x_i) + c))^2 cmj=argcminxiRmj21(yi(fm1(xi)+c))2

我们对 c c c 求导,并让导数等于零,以找到最优的 c c c

∂ ∂ c ∑ x i ∈ R m j 1 2 ( y i − f m − 1 ( x i ) − c ) 2 = 0 \frac{\partial}{\partial c} \sum_{x_i \in R_{mj}} \frac{1}{2} (y_i - f_{m-1}(x_i) - c)^2 = 0 cxiRmj21(yifm1(xi)c)2=0

这等价于:

∑ x i ∈ R m j ( y i − f m − 1 ( x i ) − c ) = 0 \sum_{x_i \in R_{mj}} (y_i - f_{m-1}(x_i) - c) = 0 xiRmj(yifm1(xi)c)=0

解这个方程,可以得到:

c = ∑ x i ∈ R m j ( y i − f m − 1 ( x i ) ) ∣ R m j ∣ c = \frac{\sum_{x_i \in R_{mj}} (y_i - f_{m-1}(x_i))}{|R_{mj}|} c=RmjxiRmj(yifm1(xi))

即:

c m j = ∑ x i ∈ R m j ( y i − f m − 1 ( x i ) ) ∣ R m j ∣ c_{mj} = \frac{\sum_{x_i \in R_{mj}} (y_i - f_{m-1}(x_i))}{|R_{mj}|} cmj=RmjxiRmj(yifm1(xi))

这表明,在平方损失的情况下,叶子节点的输出值 c m j c_{mj} cmj 是该叶子节点区域内所有样本残差( y i − f m − 1 ( x i ) y_i - f_{m-1}(x_i) yifm1(xi))的平均值。

总结

公式(11-15)表示了GBDT算法中如何确定每棵树的叶子节点值。通过最小化叶子节点区域内的损失,可以找到一个最优的输出值 c m j c_{mj} cmj,使得该节点区域的样本预测误差最小化。在不同的损失函数下,最优值 c m j c_{mj} cmj 的计算方式可能有所不同,但原理都是基于最小化损失来确定最佳的叶子节点值。


http://www.kler.cn/a/378544.html

相关文章:

  • 基于 Spring Boot 和 Vue.js 的全栈购物平台开发实践
  • 服务器日志自动上传到阿里云OSS备份
  • 可视化-numpy实现线性回归和梯度下降法
  • HarmonyOS NEXT:华为分享-碰一碰开发分享
  • 数据结构——栈
  • 【面试题】JVM部分[2025/1/13 ~ 2025/1/19]
  • mac 系统下载 vscode
  • 如何设置使PPT的画的图片导出变清晰
  • 自动驾驶-端到端大模型
  • 三层交换实现不同VLAN之间设备的互通
  • SQL 常用语句
  • 【系统架构设计师】2024年上半年真题论文: 论云上自动化运维级其应用(包括解题思路和素材)
  • 项目模块十四:HttpRequest模块
  • 六西格玛项目助力,手术机器人零部件国产化稳中求胜——张驰咨询
  • LLaMA系列一直在假装开源...
  • 基于YOLO11/v10/v8/v5深度学习的危险驾驶行为检测识别系统设计与实现【python源码+Pyqt5界面+数据集+训练代码】
  • 【p2p、分布式,区块链笔记 Torrent】通过网络编程库net集成bittorrent-protocol协议
  • ps技巧,来源于网络
  • Linux -- 信号的常见产生方式
  • MySQL日志——针对实习面试
  • 聚观早报 | 苹果推出新款iMac;华为Mate 70系列将上市
  • 并发编程中的CAS思想
  • 富格林:曝光欺诈陷阱纠正误区
  • ssm042在线云音乐系统的设计与实现+jsp(论文+源码)_kaic
  • 筛选Excel数据
  • 显卡服务器的作用都有哪些?