当前位置: 首页 > article >正文

GBDT算法Python代码实现

### GBDT定义
class GBDT(object):
    def __init__(self, n_estimators, learning_rate, min_samples_split,
                 min_gini_impurity, max_depth, regression):
        ### 常用超参数
        # 树的棵树
        self.n_estimators = n_estimators
        # 学习率
        self.learning_rate = learning_rate
        # 结点最小分裂样本数
        self.min_samples_split = min_samples_split
        # 结点最小基尼不纯度
        self.min_gini_impurity = min_gini_impurity
        # 最大深度
        self.max_depth = max_depth
        # 默认为回归树
        self.regression = regression
        # 损失为平方损失
        self.loss = SquareLoss()
        # 如果是分类树,需要定义分类树损失函数
        # 这里省略,如需使用,需自定义分类损失函数
        if not self.regression:
            self.loss = None
        # 多棵树叠加
        self.estimators = []
        for i in range(self.n_estimators):
            self.estimators.append(RegressionTree(min_samples_split=self.min_samples_split,
                                             min_gini_impurity=self.min_gini_impurity,
                                             max_depth=self.max_depth))
    # 拟合方法
    def fit(self, X, y):
        # 前向分步模型初始化,第一棵树
        self.estimators[0].fit(X, y)
        # 第一棵树的预测结果
        y_pred = self.estimators[0].predict(X)
        # 前向分步迭代训练
        for i in range(1, self.n_estimators):
            gradient = self.loss.gradient(y, y_pred)
            self.estimators[i].fit(X, gradient)
            y_pred -= np.multiply(self.learning_rate, self.estimators[i].predict(X))
            
    # 预测方法
    def predict(self, X):
        # 回归树预测
        y_pred = self.estimators[0].predict(X)
        for i in range(1, self.n_estimators):
            y_pred -= np.multiply(self.learning_rate, self.estimators[i].predict(X))
        # 分类树预测
        if not self.regression:
            # 将预测值转化为概率
            y_pred = np.exp(y_pred) / np.expand_dims(np.sum(np.exp(y_pred), axis=1), axis=1)
            # 转化为预测标签
            y_pred = np.argmax(y_pred, axis=1)
        return y_pred

我们来详细地解释这个 GBDT 类的实现,包括每个步骤的逻辑细节。

GBDT 类结构

这个 GBDT 类实现了梯度提升决策树(GBDT)的基本功能。GBDT 是一种集成学习方法,通过逐步增加树模型来提升整体模型的预测性能。每棵树都在前一棵树的基础上拟合残差,即当前模型预测值和真实值之间的误差。下面逐步解析代码。


1. __init__ 方法:初始化模型参数

def __init__(self, n_estimators, learning_rate, min_samples_split,
             min_gini_impurity, max_depth, regression):

作用:初始化 GBDT 模型的超参数,这些超参数决定了模型的结构和训练方式。

具体参数解析
  • n_estimators:要训练的树的数量,也称为“迭代次数”或“树的棵数”。每一轮(每一棵树)都会拟合前面所有树的预测误差,增加树的数量通常可以提升模型性能,但也增加了计算开销。

  • learning_rate:学习率,控制每一棵新树对整体模型的贡献。较低的学习率使模型更稳定,但需要更多树来达到良好的性能。可以理解为“步长”,每棵树更新预测结果的步长。

  • min_samples_split:每个节点进行分裂所需的最小样本数,防止生成过小的叶子节点。如果一个节点的样本数少于该值,则不再继续分裂。

  • min_gini_impurity:每个节点分裂的最小基尼不纯度,用于控制节点分裂的标准。在构建树的过程中,分裂节点时要求基尼不纯度达到一定标准,这样可以保证分裂质量。

  • max_depth:树的最大深度,防止树生长得太深而导致过拟合。设置一个较小的最大深度可以限制模型的复杂度,使模型具有更好的泛化能力。

  • regression:用于指示模型是回归任务还是分类任务。True 表示回归任务,False 表示分类任务。


定义损失函数
self.loss = SquareLoss()
if not self.regression:
    self.loss = None
  • 如果 regression 为 True,表示这是一个回归任务,默认使用平方损失(SquareLoss)。平方损失的梯度是预测值与真实值之差,即: gradient = y − y_pred \text{gradient} = y - \text{y\_pred} gradient=yy_pred
  • 如果是分类任务,需要一个适合分类的损失函数,不过代码中未定义分类损失函数,需要进一步扩展。
初始化树的集合
self.estimators = []
for i in range(self.n_estimators):
    self.estimators.append(RegressionTree(min_samples_split=self.min_samples_split,
                                          min_gini_impurity=self.min_gini_impurity,
                                          max_depth=self.max_depth))
  • self.estimators 是一个列表,用来存储每棵树(即 RegressionTree 的实例)。
  • 每棵树都初始化为一个 RegressionTree 对象,参数包括 min_samples_splitmin_gini_impuritymax_depth。这些参数确保每棵树在生成时有一致的控制条件。
  • 循环创建 n_estimators 棵树。每一轮迭代训练时,都会使用列表中的下一棵树拟合当前的残差。

2. fit 方法:训练模型

def fit(self, X, y):

作用:该方法用于训练模型。GBDT的训练过程是前向逐步添加树的过程,每棵树拟合的是当前模型的残差,从而逐步优化模型。

具体步骤解析
  1. 初始化模型:训练第一棵树,并生成初始预测值。

    self.estimators[0].fit(X, y)
    y_pred = self.estimators[0].predict(X)
    
    • 直接用第一棵树拟合 X X X y y y 数据,将 y 作为真实值直接拟合。
    • 生成的初始预测值 y_pred 是对数据的第一个估计。

    这样我们就有了初始模型的预测值 y_pred

  2. 逐步拟合残差:在接下来的每一轮中,通过计算当前预测的梯度(残差),让下一棵树来拟合这个残差,逐步优化模型。

    for i in range(1, self.n_estimators):
        gradient = self.loss.gradient(y, y_pred)
        self.estimators[i].fit(X, gradient)
        y_pred -= np.multiply(self.learning_rate, self.estimators[i].predict(X))
    
    • 计算梯度(残差)gradient = self.loss.gradient(y, y_pred),在这里,self.loss.gradient 计算的是损失函数(这里是平方损失)的梯度,也就是真实值 y y y 与当前预测值 y_pred 的差。

      • 对于平方损失,梯度就是残差 r i = y i − y ^ i r_i = y_i - \hat{y}_i ri=yiy^i
    • 拟合残差:用新树 self.estimators[i] 拟合当前梯度 gradient,使其尽量接近残差,这样下一棵树的预测就能够纠正前面模型的不足。

    • 更新预测值y_pred -= np.multiply(self.learning_rate, self.estimators[i].predict(X))

      • 用新的树的预测值 self.estimators[i].predict(X) 乘以学习率来更新 y_pred,使当前预测值 y_pred 更接近真实值。
      • 这里减去的是每棵树的预测值乘以学习率。

通过这个过程,每一轮都会让模型的预测值 y_pred 更加接近真实值。


3. predict 方法:模型预测

def predict(self, X):

作用:预测新数据 X X X 的结果,基于训练好的 GBDT 模型。

具体步骤解析
  1. 初始化预测值:使用第一棵树生成初始预测。

    y_pred = self.estimators[0].predict(X)
    

    这里初始化预测值 y_pred,它是由第一棵树生成的。

  2. 逐步加上后续树的预测:每一棵树的预测值都会乘以学习率,然后从 y_pred 中减去。

    for i in range(1, self.n_estimators):
        y_pred -= np.multiply(self.learning_rate, self.estimators[i].predict(X))
    

    逐步累加后续树的预测值来更新 y_pred,让模型的预测更加准确。

  3. 分类任务的处理:如果是分类任务,需要对 y_pred 进行后处理。

    if not self.regression:
        y_pred = np.exp(y_pred) / np.expand_dims(np.sum(np.exp(y_pred), axis=1), axis=1)
        y_pred = np.argmax(y_pred, axis=1)
    
    • 计算类别概率:使用 softmax 函数(通过 np.exp 和归一化)计算各类的概率分布。
    • 选择预测标签:选择概率最高的类别作为预测标签。
  4. 返回预测值:对于回归任务,直接返回 y_pred;对于分类任务,返回分类标签。


代码总结

这个 GBDT 类实现了梯度提升决策树的基础原理:

  • 每棵树都在拟合前一轮模型的残差。
  • 每次训练后更新预测值,使模型逐步逼近真实值。
  • 对于分类任务,最后还要将连续输出转化为类别标签。

http://www.kler.cn/a/378562.html

相关文章:

  • 【深度学习项目】语义分割-DeepLab网络(DeepLabV3介绍、基于Pytorch实现DeepLabV3网络)
  • 查看电脑或笔记本CPU的核心数方法及CPU详细信息
  • AutoGen入门——快速实现多角色、多用户、多智能体对话系统
  • Linux下MySQL的简单使用
  • 【环境搭建】Metersphere v2.x 容器部署教程踩坑总结
  • 深入理解GPT底层原理--从n-gram到RNN到LSTM/GRU到Transformer/GPT的进化
  • HTML5和CSS3 介绍
  • 加强版 第六节 图像轮廓几何属性分析
  • 无人机维修培训班开班课程技术详解
  • 「Mac畅玩鸿蒙与硬件17」鸿蒙UI组件篇7 - Animation 组件基础
  • npm入门教程17:准备发布的npm包
  • 家具制造的效率与美观并重,玛哈特矫平机让家具产品更具竞争力。
  • 2024前端面试训练计划-高频题-网络基础篇
  • QT中TextEdit或者QLineEdit以十六进制显示数组数据
  • uni-app 下拉刷新、 上拉触底(列表信息)、 上滑加载(短视频) 一键搞定
  • nginx配置转发到elk的kibana的服务器
  • 【开发工具——依赖管理工具——Maven】
  • unity c# Tcp网络通讯
  • C++ 函数调用时的参数传递方法
  • 线性数据结构之队列
  • 【读书笔记/深入理解K8S】集群控制器
  • 《GBDT 算法的原理推导》 11-15更新决策树的叶子节点值 公式解析
  • mac 系统下载 vscode
  • 如何设置使PPT的画的图片导出变清晰
  • 自动驾驶-端到端大模型
  • 三层交换实现不同VLAN之间设备的互通