当前位置: 首页 > article >正文

详解机器学习经典模型(原理及应用)——GBDT

一、什么是GBDT

        梯度提升决策树(Gradient Boosting Decision Tree, GBDT)是一种集成学习算法,它通过迭代地训练决策树来最小化损失函数,从而提高模型的预测性能。GBDT的核心思想是将多个弱学习器(通常是决策树)的结果累加起来,形成强学习器,即Boosting。与随机森林相同,弱学习器可以是分类树也可以是回归树(有的观点说GBDT只能由回归树构成,这可能是由于GBDT通常使用回归树解决问题从而给人造成了误解,实际上分类指标同样能够用来计算负梯度,具体使用的是回归树还是分类树,主要看损失函数的构造以及任务的目标值类型)。

二、GBDT算法原理

        GBDT的工作原理如下:

1、初始化

        首先,GBDT使用一个初始预测器(通常是一个常数值)对所有样本进行预测。

2、迭代训练

        在每次迭代中,GBDT会计算损失函数的负梯度,这个负梯度将作为残差,用于训练下一棵决策树。损失函数的负梯度可以被视为当前模型预测值与真实值之间差异的度量。对于不同的损失函数,负梯度的计算方式也会有所不同。例如,对于均方误差损失函数(回归问题),负梯度的计算公式为:

r_{ti} = -\left [ \frac{\partial L(y_{i},f(x_{i}))}{\partial f(x_{i})} \right ]_{f(x)=f_{t-1}(x)}=y_{i}-f_{t-1}(x_{i})

        其中,r_{ti}是第t轮迭代中第i个样本的残差,y_{i}是真实值,f_{t-1}(x_{i})是前一轮迭代的预测值。

3、更新模型

        在每次迭代中,新的决策树被添加到模型中,模型的预测值更新为:

\hat{y_{i}}^{(t+1)} = \hat{y_{i}}^{(t)}-\eta h_{t}(x_{i})

        其中h_{t}(x_{i})是第t轮迭代中训练的决策树对样本x_{i}的预测值。\eta是学习率。而对于每一颗决策树(即弱学习器),有:

h_{t}(x) = argmin_{h}\sum_{i=1}^{n}L(y_{i}, \hat{y}_{i}^{(t)}+h (x_{i}))

        这里,h为决策树模型,L是损失函数。

4、终止条件

        最终的GBDT模型是所有弱学习器的加权和:

\hat{y}(x) = \hat{y}^{(0)}+\sum_{t-1}^{T}\eta _{t}\times \alpha_{t} \times h_{t}(x)

        其中,T是迭代次数,\hat{y}^{(0)}是初始预测值,\alpha_{t}是第t轮迭代中决策树的权重(通常与当前决策树的复杂度成反比)。当达到预设的迭代次数或模型性能不再显著提升时,GBDT停止迭代。

三、梯度提升与梯度下降

        梯度下降是一种优化算法,用于最小化损失函数。它通过计算损失函数相对于模型参数的梯度(即损失函数在参数空间中的斜率),然后沿着梯度的反方向更新参数,以此减小损失函数的值。

        梯度提升是一种集成学习算法,“提升”指的是通过添加一个新的弱学习器迭代改进模型的方法),以此来最小化损失函数。在每一步迭代中,梯度提升算法计算当前模型的残差(即损失函数的负梯度),然后将一个新的弱学习器拟合到这些残差上。这个过程可以看作是在每一步迭代中,模型都在尝试修正前一步的预测误差。

        两者都是在每一轮迭代中,利用损失函数相对于模型的负梯度方向的信息来对当前模型进行更新,只不过在梯度下降中直接使用损失函数的负梯度来更新参数,而在梯度提升中使用损失函数的负梯度作为残差的近似值,而不是直接用于更新参数。

四、GBDT应用

1、分类

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import accuracy_score

# 加载数据集
iris = load_iris()
X, y = iris.data, iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建GBDT分类器
gbdt_clf = GradientBoostingClassifier(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)

# 训练模型
gbdt_clf.fit(X_train, y_train)

# 预测测试集
y_pred = gbdt_clf.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")

2、回归

from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error

# 加载波士顿房价数据集
boston = load_boston()
X, y = boston.data, boston.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建GBDT回归器
gbdt_reg = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)

# 训练模型
gbdt_reg.fit(X_train, y_train)

# 预测测试集
y_pred = gbdt_reg.predict(X_test)

# 计算均方误差
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse:.2f}")

# 计算R^2分数
r2_score = gbdt_reg.score(X_test, y_test)
print(f"R^2 Score: {r2_score:.2f}")

五、总结

        虽然实际业务工作中,考虑GBDT模型的时候我们基本都会使用其工程优化版本——XGBoost以及LightGBM,但它仍然是值得深入学习的经典机器学习模型之一。以下是GBDT的一些优缺点:

1、优点

        (1)高准确性:GBDT能够处理高维度、稀疏特征以及非线性关系等复杂问题,因此在训练集和测试集上都表现良好。

        (2)强大的泛化能力:通过组合多个弱分类器形成一个强分类器,减少了过拟合的风险。

        (3)对缺失值的鲁棒性:GBDT能够自动处理缺失值,无需额外的处理步骤。

        (4)可并行化的预测阶段:虽然训练过程是串行的,但预测时可以并行计算,提高了预测速度。

2、缺点

        (1)训练时间较长:GBDT是串行算法,需要按顺序构建每棵决策树,因此训练时间较长。

        (2)对异常值敏感:GBDT在训练过程中容易受到异常值的影响,可能导致模型性能下降。

        (3)无法并行化训练:GBDT的训练过程无法并行化,工程加速只能体现在单颗树构建过程中。


http://www.kler.cn/a/321411.html

相关文章:

  • 深度解析:Android APP集成与拉起微信小程序开发全攻略
  • HarmonyOS NEXT应用开发实战 ( 应用的签名、打包上架,各种证书详解)
  • SQL面试题——蚂蚁SQL面试题 会话分组问题
  • 网络延迟对Python爬虫速度的影响分析
  • Keil基于ARM Compiler 5的工程迁移为ARM Compiler 6的工程
  • 【教程】Ubuntu设置alacritty为默认终端
  • springboot实战学习(7)(JWT令牌的组成、JWT令牌的使用与验证)
  • 计算机毕业设计之:微信小程序的校园闲置物品交易平台(源码+文档+讲解)
  • 【ARM 嵌入式 编译系列 10.5 -- ARM toolchain naming convention】
  • 如何在CMakeList项目中集成GNU Autotools 构建模块
  • JavaSE——Arrays类、System类
  • 网格大师OSGB转OBJ,转换类型中的非拓扑、拓扑、重建有什么区别?
  • 【Docker】01-Docker常见指令
  • 【Linux实践】实验八:Shell程序的创建及变量
  • Scala第二天
  • 【C++笔试强训】如何成为算法糕手Day5
  • 解决TikTok无法注册或注册不了的问题
  • 手机使用技巧:如何修复变砖的 Android 手机
  • 策略模式
  • [笔记]某S厂减速箱部件参数表 - 技术问题海外联系方式
  • JavaScript typeof运算符
  • 实变函数精解【25】
  • Excel锁定单元格,使其不可再编辑
  • QT开发:详解 Qt 多线程编程核心类 QThread:基本概念与使用方法
  • 大语言模型量化方法GPTQ、GGUF、AWQ详细原理
  • 【算法】二叉树中的 DFS