当前位置: 首页 > article >正文

【漫话机器学习系列】033.决策树回归(Decision Tree Regression)

决策树回归(Decision Tree Regression)

决策树回归是一种基于树状结构进行回归分析的监督学习方法。它将输入空间递归地划分为多个区域,并在每个区域内拟合一个简单的常数值,从而对目标变量进行预测。


决策树回归的原理

  1. 树的构建

    • 决策树以树的形式对数据进行划分。
    • 每次划分选择一个特征及其阈值,将数据集分为两个子集。
    • 目标是找到最佳划分,使得子集内的目标变量尽可能一致(即减少误差)。
  2. 划分准则
    通常采用均方误差(MSE, Mean Squared Error)作为划分的评价指标:

                                        MSE = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y})^2

    其中,y_i 是真实值,\hat{y}​ 是预测值。

  3. 停止条件

    • 达到最大树深度。
    • 叶节点的样本数少于预设值。
    • 划分后误差改善不足。
  4. 预测
    对于新输入数据,沿着决策树从根节点到叶节点,根据划分规则找到其对应的叶节点,返回叶节点中目标变量的均值作为预测值。


构建过程

  1. 根节点的初始化
    将所有数据视为一个整体,计算均值作为预测值,计算当前数据集的均方误差。

  2. 递归划分

    • 遍历每个特征及其所有可能的划分点,计算划分后的均方误差。
    • 选择能最大程度减少误差的特征及阈值进行划分。
  3. 停止划分

    • 当树的深度达到预设值。
    • 当叶节点的样本数小于预设阈值。
    • 当划分后误差改善不足。

优点

  1. 可解释性强
    决策树的结构直观清晰,易于可视化和理解。

  2. 非线性建模能力
    决策树能有效捕获数据中的非线性关系。

  3. 无需特征缩放
    决策树对特征的数值范围不敏感,不需要标准化或归一化。


缺点

  1. 易过拟合
    决策树在深度较大时可能会过拟合,导致泛化能力差。

  2. 对数据分布敏感
    对于小的样本噪声或异常值,可能会导致不稳定的划分。

  3. 无法捕获连续目标变量的平滑关系
    决策树只能在区域内拟合常数值,难以捕获目标变量的连续变化。


改进方法

  1. 剪枝(Pruning)

    • 预剪枝:设置树的最大深度、叶节点的最小样本数等参数,限制树的规模。
    • 后剪枝:先构建一棵完整的树,然后通过去掉不重要的分支来减少过拟合。
  2. 集成学习

    • 随机森林:构建多棵决策树并取平均值。
    • 梯度提升树(GBDT):通过串联多个决策树逐步减小误差。
    • 极端随机树(Extra Trees):进一步随机化特征和划分点选择,降低过拟合风险。

评价指标

  • 均方误差(MSE)

    MSE = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2
  • 平均绝对误差(MAE, Mean Absolute Error)

    MAE = \frac{1}{n} \sum_{i=1}^n |y_i - \hat{y}_i|
  • 决定系数(R2R^2R2)
    衡量模型对目标变量的解释程度:

    R^2 = 1 - \frac{\sum_{i=1}^n (y_i - \hat{y}_i)^2}{\sum_{i=1}^n (y_i - \bar{y})^2}

    其中 \bar{y} 是目标变量的均值。


代码实现

以下是使用 Python 中的 Scikit-learn 实现决策树回归的代码示例:

from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
import numpy as np

# 生成模拟数据
np.random.seed(0)
X = np.sort(np.random.rand(100, 1), axis=0)
y = np.sin(2 * np.pi * X).ravel() + np.random.randn(100) * 0.1

# 数据划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 构建决策树回归模型
model = DecisionTreeRegressor(max_depth=4)
model.fit(X_train, y_train)

# 预测
y_pred = model.predict(X_test)

# 评价模型
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)
print(f"Mean Squared Error: {mse:.3f}")
print(f"R^2 Score: {r2:.3f}")

# 可视化结果
import matplotlib.pyplot as plt
plt.scatter(X_test, y_test, color="blue", label="True Values")
plt.scatter(X_test, y_pred, color="red", label="Predicted Values")
plt.legend()
plt.title("Decision Tree Regression")
plt.xlabel("X")
plt.ylabel("y")
plt.show()

输出结果

Mean Squared Error: 0.038
R^2 Score: 0.939

 


应用场景

  1. 房地产价格预测
    根据特征(面积、位置、房龄等)预测房价。

  2. 市场营销分析
    根据用户行为数据预测用户对产品的需求。

  3. 时间序列分析
    使用历史数据预测未来值。


总结

决策树回归是简单易用的回归模型,特别适合处理非线性和非参数问题。然而,单独使用决策树可能会过拟合或欠拟合,因此需要通过剪枝或集成方法进一步提升模型的鲁棒性和泛化能力。


http://www.kler.cn/a/466372.html

相关文章:

  • springboot远程链接spark
  • 【机器学习:一、机器学习简介】
  • 【python因果库实战15】因果生存分析4
  • GESP真题 | 2024年12月1级-编程题4《美丽数字》及答案(C++版)
  • 《Rust权威指南》学习笔记(三)
  • 数学建模入门——建模流程
  • 移动构造函数详解
  • MySQL使用通用二进制文件安装到Unix/Linux
  • 32单片机从入门到精通之开发环境——调试工具(七)
  • nodeJS下npm和yarn的关系和区别详解
  • 嵌入式应用软件开发中C语言方向面试题
  • ClickHouse副本搭建
  • 关于AI面试系统2025年趋势评估!
  • 【Multisim用74ls92和90做六十进制】2022-6-12
  • dns网址和ip是一一对应的吗?
  • AMP 混合精度训练中的动态缩放机制: grad_scaler.py函数解析( torch._amp_update_scale_)
  • Android 网络判断
  • Couchbase 的 OLAP 能力现状以及提升 OLAP 能力的方法
  • Android:动态去掉RecyclerView动画导致时长累加问题解决
  • 【蓝桥杯比赛-C++组-经典题目汇总】
  • cka考试-03-k8s版本升级
  • SpringBootWeb案例-2
  • 图形 3.5 Early-z和Z-prepass
  • Mysql监视器搭建
  • FPGA、STM32、ESP32、RP2040等5大板卡,结合AI,更突出模拟+数字+控制+算法
  • 仓储机器人底盘的研究