当前位置: 首页 > article >正文

【机器学习】机器学习的基本分类-监督学习-梯度提升树(Gradient Boosting Decision Tree, GBDT)

梯度提升树是一种基于**梯度提升(Gradient Boosting)**框架的机器学习算法,通过构建多个决策树并利用每棵树拟合前一棵树的残差来逐步优化模型。


1. 核心思想

  • Boosting:通过逐步调整模型,使后续的模型重点学习前一阶段未能正确拟合的数据。
  • 梯度提升:将误差函数的负梯度作为残差,指导新一轮模型的训练。
与随机森林的区别
特性随机森林梯度提升树
基本思想BaggingBoosting
树的训练方式并行训练顺序训练
树的类型完全树通常是浅树(弱学习器)
应用场景抗过拟合、快速训练高精度、复杂任务

 

2. 算法流程

  1. 输入

    • 数据集 D = \{ (x_i, y_i) \}_{i=1}^{n}​。
    • 损失函数 L(y, \hat{y}),如平方误差、对数似然等。
    • 弱学习器个数 T 和学习率 η。
  2. 初始化模型

    f_0(x) = \arg\min_c \sum_{i=1}^n L(y_i, c)
    • f_0 是一个常数,通常为目标变量的均值(回归)或类别概率的对数(分类)。
  3. 迭代训练每棵弱学习器(树)

    • 第 t 次迭代:
      1. 计算第 t 轮的负梯度(残差):
        r_i^{(t)} = -\left[ \frac{\partial L(y_i, f(x_i))}{\partial f(x_i)} \right]_{f=f_{t-1}}
        残差反映当前模型未能拟合的部分。
      2. 构建决策树 h_t(x) 拟合残差 r_i^{(t)}
      3. 计算最佳步长(叶节点输出值): \gamma_t = \arg\min_\gamma \sum_{i=1}^n L\left(y_i, f_{t-1}(x_i) + \gamma h_t(x_i)\right)
      4. 更新模型: f_t(x) = f_{t-1}(x) + \eta \gamma_t h_t(x) 其中 η 是学习率,控制每棵树的贡献大小。
  4. 输出模型: 最终模型为:

    f_T(x) = \sum_{t=1}^T \eta \gamma_t h_t(x)

 

3. 损失函数

GBDT 可灵活选择损失函数,以下是常用的几种:

  1. 平方误差(MSE,回归问题)

    L(y, \hat{y}) = \frac{1}{2} (y - \hat{y})^2
    • 负梯度: r_i = y_i - f(x_i)
  2. 对数似然(Log-Loss,二分类问题)

    L(y, \hat{y}) = -\left[ y \log \sigma(\hat{y}) + (1-y) \log(1-\sigma(\hat{y})) \right]
    • 负梯度: r_i = y_i - \sigma(f(x_i))
  3. 指数损失(Adaboost)

    L(y, \hat{y}) = e^{-y\hat{y}}

 4. GBDT 的优缺点

优点
  1. 灵活性:支持回归和分类任务,且损失函数可定制。
  2. 高精度:由于采用 Boosting 框架,能取得非常好的预测效果。
  3. 特征选择:内置特征重要性评估,帮助筛选关键特征。
  4. 处理缺失值:部分实现(如 XGBoost)可以自动处理缺失值。
缺点
  1. 训练时间长:由于弱学习器依次构建,训练过程较慢。
  2. 对参数敏感:需要调整学习率、树的数量、最大深度等参数。
  3. 不擅长高维稀疏数据:相比线性模型和神经网络,GBDT 在处理高维数据(如文本数据)时表现一般。

 5. GBDT 的改进

  1. XGBoost

    • 增加正则化项,控制模型复杂度。
    • 支持并行化计算,加速训练。
    • 提供更高效的特征分裂方法。
  2. LightGBM

    • 提出叶子分裂(Leaf-Wise)策略。
    • 适合大规模数据和高维特征场景。
  3. CatBoost

    • 专门针对分类特征优化。
    • 避免目标泄露(Target Leakage)。

 6. GBDT 的代码实现

以下是 GBDT 的分类问题实现:

from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# 生成数据
X, y = make_classification(n_samples=1000, n_features=10, n_informative=5, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建 GBDT 模型
gbdt = GradientBoostingClassifier(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)
gbdt.fit(X_train, y_train)

# 预测
y_pred = gbdt.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print("分类准确率:", accuracy)

# 特征重要性
import matplotlib.pyplot as plt
import numpy as np

feature_importances = gbdt.feature_importances_
indices = np.argsort(feature_importances)[::-1]

plt.figure(figsize=(10, 6))
plt.title("Feature Importance")
plt.bar(range(X.shape[1]), feature_importances[indices], align="center")
plt.xticks(range(X.shape[1]), indices)
plt.show()

输出结果

分类准确率: 0.9366666666666666


7. 应用场景

  1. 回归问题:如预测房价、商品销量。
  2. 分类问题:如金融风险预测、垃圾邮件分类。
  3. 排序问题:如搜索引擎的结果排序。
  4. 时间序列问题:预测趋势或模式。

GBDT 是机器学习中的经典算法,尽管深度学习在许多领域占据主导地位,但在表格数据和中小规模数据集的应用中,GBDT 仍然是非常强大的工具。


http://www.kler.cn/a/430197.html

相关文章:

  • Linux 系统报打开的文件过多
  • 如何在小米平板5上运行 deepin 23 ?
  • 后端报错: message: “For input string: \“\““
  • 知识图谱8:深度学习各种小模型
  • 服务路由和服务发现区别是什么?
  • linx使用命令还原数据库(source还原方式)
  • HCIP——VRRP的实验配置
  • 汉明距离算法
  • 【Linux】系统安装内核后重启发现进不去系统
  • Python爬虫:爬取动漫网站的排行榜数据并进行可视化分析
  • docker-compose 部署 mysql redis nginx nacos seata sentinel
  • Halcon 轮廓检测常用算子、原理及应用场景
  • PHP和GD库如何将图片转换为黑白图
  • Unity类银河战士恶魔城学习总结(P167 Blackhole additional vfx 黑洞技能额外特效)
  • 2023年第十四届蓝桥杯Scratch02月stema选拔赛真题-王子与骑士
  • 第三十九篇——条件概率和贝叶斯公式:机器翻译是怎么工作的?
  • 执行“go mod tidy”遇到“misbehavior”错误
  • 2024年华中杯数学建模C题基于光纤传感器的平面曲线重建算法建模解题全过程文档及程序
  • 【算法笔记】前缀和算法原理深度剖析(超全详细版)
  • gozero项目迁移与新服务器环境配置,包含服务器安装包括go版本,Nginx,项目配置包括Mysql,redis,rabbit,域名