当前位置: 首页 > article >正文

[机器学习] 决策树

决策树

决策树是一种常用的机器学习算法,用树形结构解决分类和回归问题。它是一种监督学习算法,通过学习简单的决策规则从数据特征中推断出目标变量。

  • 输出变量是离散的——分类问题
  • 输出变量是连续的——回归问题

核心思想

模仿人类决策过程,通过一系列的问题(通常是二元选择)来逐步缩小选择范围,最终达到一个结论。

主要特点

  • 树形结构:决策树由节点(Node)和边(Edge)组成,形似树状结构。每个内部节点代表一个特征上的测试,每条边代表测试的结果,每个叶节点(Leaf Node)代表一个决策结果。
  • 易于理解:决策树的树形结构直观,易于理解和解释,使得模型的预测过程透明。
  • 适用性广:既可以用于分类问题(如CART、ID3、C4.5),也可以用于回归问题(如回归树)。

构建过程

  1. 评估并选择最佳特征:在每个节点,算法会评估所有可能的特征和切分点,选择能够最好地区分数据的特征和切分点。

  2. 分裂数据集:根据选择的特征和切分点,将数据集分割成两个或多个子集。

  3. 递归构建:对每个子集重复上述过程,直到满足停止条件(如达到最大深度、节点中的样本数太少、或进一步分裂不能显著提高模型性能)。

  4. 剪枝:为了防止过拟合,决策树可能会进行剪枝处理,包括预剪枝和后剪枝。

  5. 输出结果:叶节点通常代表最终的分类结果或回归预测值。

分类

根据特征的重要性评判可以分成以下几种:

ID3算法——基于信息增益

基于信息增益,选择能够最大程度减少熵(不确定性)的特征作为分裂节点。

信息增益(Information Gain)

用来衡量通过一个特征将数据集分割后,能够获得多少关于目标变量的信息。

信息增益越高,说明该特征对于预测目标变量的值越有用。

信息增益的计算

信息增益是基于信息论中的熵(Entropy)概念计算的。

  1. 计算数据集的熵:
    熵是衡量数据集不确定性的指标,熵越高,数据集的不确定性越大。
    H ( D ) = − ∑ i = 1 n p i log ⁡ 2 p i H(D) = -\sum_{i=1}^{n} p_i \log_2 p_i H(D)=i=1npilog2pi
    其中, H ( D ) H(D) H(D)是数据集(D)的熵, p i p_i pi是数据集中第 i i i个类别的概率。

  2. 计算特征的条件熵:
    条件熵用于衡量在已知一个随机变量的条件下,另一个随机变量的不确定性。
    H ( D ∣ A ) = ∑ j = 1 m ∣ D j ∣ ∣ D ∣ H ( D j ) H(D|A) = \sum_{j=1}^{m} \frac{|D_j|}{|D|} H(D_j) H(DA)=j=1mDDjH(Dj)
    其中, H ( D ∣ A ) H(D|A) H(DA)是特征 A A A的条件熵, D j D_j Dj是根据特征 A A A的第 j j j个值分割后的子数据集, ∣ D j ∣ |D_j| Dj是子数据集 D j D_j Dj的大小, ∣ D ∣ |D| D是原始数据集的大小。

  3. 计算信息增益:
    Gain ( D , A ) = H ( D ) − H ( D ∣ A ) \text{Gain}(D, A) = H(D) - H(D|A) Gain(D,A)=H(D)H(DA)
    其中, Gain ( D , A ) \text{Gain}(D, A) Gain(D,A) 是特征 A A A 在数据集 D D D上的信息增益。

C4.5算法——基于信息增益比

基于信息增益比(Gain Ratio),在ID3的基础上进行了改进。与ID3算法不同,C4.5不仅能够处理离散属性,还能处理连续属性,同时在处理缺失值和剪枝(Pruning)方面也有较大的改进

核心思想:通过信息增益比选择最优的划分属性,构建决策树。

信息增益比计算

  1. 计算信息增益:
    Gain ( D , a ) = H ( D ) − H ( D ∣ a ) \text{Gain}(D, a) = H(D) - H(D|a) Gain(D,a)=H(D)H(Da)
    其中, H ( D ) H(D) H(D)是数据集 D D D的熵, H ( D ∣ a ) H(D|a) H(Da)是属性 a a a的条件熵。

  2. 计算固有值:
    固有值衡量了属性的离散性,属性取值越多,固有值越大。
    IV ( a ) = − ∑ v = 1 V ∣ D v ∣ ∣ D ∣ log ⁡ 2 ∣ D v ∣ ∣ D ∣ \text{IV}(a) = -\sum_{v=1}^{V} \frac{|D^v|}{|D|} \log_2 \frac{|D^v|}{|D|} IV(a)=v=1VDDvlog2DDv
    其中, V V V是属性 a a a可能的值的数目, D v D^v Dv是属性 a a a的值为 v v v时的子集, ∣ D v ∣ |D^v| Dv ∣ D ∣ |D| D分别是子集 D v D^v Dv和数据集 D D D的样本数目。

  3. 计算信息增益比:
    Gain Ratio ( D , a ) = Gain ( D , a ) IV ( a ) \text{Gain Ratio}(D, a) = \frac{\text{Gain}(D, a)}{\text{IV}(a)} Gain Ratio(D,a)=IV(a)Gain(D,a)
    其中, Gain ( D , a ) \text{Gain}(D, a) Gain(D,a)是特征 a a a在数据集 D D D上的信息增益, IV ( a ) \text{IV}(a) IV(a)是属性 a a a的固有值。

CART算法

分类问题

基尼不纯度 / 基尼指数作为特征选择的标准,它衡量了数据集的不纯度,选择能够最大程度减少不纯度的特征。
基尼不纯度的计算公式
Gini ( D ) = 1 − ∑ i = 1 n p i 2 \text{Gini}(D) = 1 - \sum_{i=1}^{n} p_i^2 Gini(D)=1i=1npi2
其中, p i p_i pi是数据集中第 i i i个类别的概率。
在这里插入图片描述

回归问题

均方误差作为特征选择的标准。

均方误差的计算公式
MSE ( D ) = 1 ∣ D ∣ ∑ i = 1 ∣ D ∣ ( y i − y ˉ ) 2 \text{MSE}(D) = \frac{1}{|D|} \sum_{i=1}^{|D|} (y_i - \bar{y})^2 MSE(D)=D1i=1D(yiyˉ)2
其中, y i y_i yi是数据集 D D D中第 i i i个样本的目标值, y ˉ \bar{y} yˉ是数据集 D D D中所有样本目标值的平均值, ∣ D ∣ |D| D是数据集 D D D的样本数目。

Random Forest算法

构建多个决策树,每棵树在构建时随机选择特征子集,通过集成多个树的结果来提高模型的稳定性和准确性。

回归树

数学表达式

回归树是一种决策树,用于解决回归问题。

它通过将数据集分割成多个子集,每个子集对应一个叶节点,叶节点包含该子集目标值的平均值。

回归树的数学表达式可以表示为:

f ( x ) = ∑ m = 1 M c m I ( x ∈ R m ) f(x) = \sum_{m=1}^{M} c_m I(x \in R_m) f(x)=m=1McmI(xRm)

其中, M M M是叶节点的数量, c m c_m cm是第 m m m个叶节点的目标值平均数, R m R_m Rm是第 m m m个叶节点对应的子集, I ( x ∈ R m ) I(x \in R_m) I(xRm)是指示函数,如果 x x x属于 R m R_m Rm则为1,否则为0。

这个表达式的意思是,对于任意输入 x x x,预测函数 f ( x ) f(x) f(x) 会根据 x x x 属于哪个叶节点的子集 R m R_m Rm,来返回一个预测值。这个预测值一般是该叶子节点的子集 R m R_m Rm所有输出的平均值 c m c_m cm。如果 x x x 属于第 m m m 个叶节点的子集 R m R_m Rm,则 I ( x ∈ R m ) = 1 I(x \in R_m) = 1 I(xRm)=1,其他所有 I ( x ∈ R k ) = 0 I(x \in R_k) = 0 I(xRk)=0(对于 k ≠ m k \neq m k=m)。因此, f ( x ) f(x) f(x) 的值就是 c m c_m cm

损失函数

回归树的损失函数通常使用均方误差(MSE):

MSE ( D ) = 1 ∣ D ∣ ∑ i = 1 ∣ D ∣ ( y i − y ˉ ) 2 \text{MSE}(D) = \frac{1}{|D|} \sum_{i=1}^{|D|} (y_i - \bar{y})^2 MSE(D)=D1i=1D(yiyˉ)2

其中, y i y_i yi是数据集 D D D中第 i i i个样本的目标值, y ˉ \bar{y} yˉ是数据集 D D D中所有样本目标值的平均值, ∣ D ∣ |D| D是数据集 D D D的样本数目。

树如何构建

树的构建通常要解决三个问题:

  1. 树的深度如何决定——决定训练什么时候停止
    树的深度可以通过多种方式决定,大多数情况下是自己定义

    • 直接指定叶子结点个数或树的深度(无法控制精度)
    • 子节点所包含的样本数小于k个时停止划分
    • 当增加深度不再显著提高模型精度时停止
    • ……
  2. 划分节点如何选取
    划分节点的选取通常基于损失函数的减少。对于回归树,这通常是均方误差MSE的减少。选择能够最大化损失函数减少的特征和切分点作为划分节点。

  3. 叶子节点代表的值 c m c_m cm如何定
    叶子节点的 c m c_m cm 值取该叶子节点中所有训练样本 y i y_i yi的平均值时,得到损失最小,。数学表达式为:
    c m = 1 ∣ S m ∣ ∑ x i ∈ S m y i c_m = \frac{1}{|S_m|} \sum_{x_i \in S_m} y_i cm=Sm1xiSmyi
    其中, S m S_m Sm 是第 m m m 个叶子节点的样本集合, y i y_i yi 是样本 x i x_i xi 的目标值, ∣ S m ∣ |S_m| Sm 是叶子节点 S m S_m Sm 中样本的数量。
    推导过程
    损失函数从按样本遍历的形式转化为按叶子节点进行两次遍历的形式,外部对节点遍历,内部遍历节点内所有的样本
    在这里插入图片描述
    由于y确定(数据集给定的),现在只剩下一个要优化的变量 c m c_m cm
    在这里插入图片描述
    直接对J求导,使导数为0,得到最优解
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

优化求解

回归树的优化求解过程是通过最小化均方误差来选择最佳的分裂点。具体步骤如下:

  1. 计算当前节点的MSE:在当前节点,计算所有样本的MSE。
  2. 尝试所有可能的分裂:对于每个特征,尝试所有可能的切分点,计算分裂后的MSE。
  3. 选择最佳分裂:选择能够最小化MSE的特征和切分点。
  4. 递归优化:对每个子集重复上述过程,直到满足停止条件。

代码实现

from sklearn.datasets import fetch_california_housing
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error


california_housing = fetch_california_housing(as_frame=True)
X, y = california_housing.data, california_housing.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)


regressor = DecisionTreeRegressor(max_depth=2, random_state=42)


regressor.fit(X_train, y_train)


y_pred = regressor.predict(X_test)


mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse}")

分类树

分类树的特征选择方式

分类树在构建过程中,特征选择是一个关键步骤,它决定了树的分裂方式。有两种常用的特征选择标准:

  1. 信息增益(Information Gain)criterion = ‘entropy’
    信息增益是基于熵的概念,用于衡量一个特征对于目标变量的不确定性减少的程度。信息增益越大,表示该特征对于分类越有帮助。
    Gain ( D , a ) = H ( D ) − ∑ v = 1 V ∣ D v ∣ ∣ D ∣ H ( D v ) \text{Gain}(D, a) = H(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|} H(D^v) Gain(D,a)=H(D)v=1VDDvH(Dv)
    其中, H ( D ) H(D) H(D) 是数据集 D D D 的熵, D v D^v Dv 是特征 a a a 取值为 v v v 时的子集。

  2. 基尼不纯度(Gini Impurity)criterion = ‘gini’
    基尼不纯度衡量了一个节点的不纯度,即样本类别分布的不均匀性。基尼不纯度越低,表示节点越纯。
    Gini ( D ) = 1 − ∑ i = 1 n p i 2 \text{Gini}(D) = 1 - \sum_{i=1}^{n} p_i^2 Gini(D)=1i=1npi2
    其中, p i p_i pi 是数据集中第 i i i 个类别的概率。

分类树的算法流程

  1. 选择最佳特征:在当前节点,根据信息增益或基尼不纯度选择最佳特征进行分裂。

  2. 分裂数据集:根据选择的特征和切分点,将数据集分裂成两个或多个子集。

  3. 递归构建:对每个子集重复步骤1和2,直到满足停止条件。

  4. 停止条件:当节点中的样本数太少、纯度已经很高、或达到预设的最大深度时,停止分裂。

  5. 叶节点分类:对于每个叶节点,根据该节点中的样本类别分布,确定最终的分类结果。

代码实现

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
import matplotlib.pyplot as plt

# 加载数据集
iris = load_iris()
X, y = iris.data, iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建决策树分类器实例
clf = DecisionTreeClassifier(max_depth=3, random_state=42)

# 训练模型
clf.fit(X_train, y_train)

# 预测测试集
y_pred = clf.predict(X_test)

# 评估模型
accuracy = clf.score(X_test, y_test)
print(f"Accuracy: {accuracy}")

# 可视化决策树
plt.figure(figsize=(12, 8))
tree.plot_tree(clf, filled=True, feature_names=iris.feature_names, class_names=iris.target_names)
plt.show()

在这里插入图片描述


http://www.kler.cn/a/446427.html

相关文章:

  • linux普通用户使用sudo不需要输密码
  • uniApp上传文件踩坑日记
  • 基于Arduino的平衡车机械臂
  • Java并发编程框架之综合案例——在线聊天室(二)
  • vscode不同的项目使用不同的环境变量或编译环境
  • MSSQL AlwaysOn 可用性组(Availability Group)中的所有副本均不健康排查步骤和解决方法
  • 关于VS项目中添加第三方库出现error C4430: 缺少类型说明符 - 假定为 int。注意: C++ 不支持默认 int 错误的解决方法
  • 【Visual Studio Code(VSCode)介绍】
  • 城市灾害应急管理集成系统——系统介绍
  • Centos7, 使用yum工具,出现 Could not resolve host: mirrorlist.centos.org
  • [react] <NavLink>自带激活属性
  • 项目29:简易谜语生成器 --- 《跟着小王学Python·新手》
  • 如何解决Elastic Job Lite任务分配到不健康实例问题?
  • Java 中 wait 和 sleep 的区别:从原理到实践全解析
  • lua dofile 传参数
  • GhostRace: Exploiting and Mitigating Speculative Race Conditions-记录
  • 基于 Python 将 PDF 转 Markdown 并拆解为 JSON,支持自定义标题处理
  • Odoo:免费开源ERP的AI技术赋能出海企业电子商务应用介绍
  • Python Turtle图形库基本命令详解
  • leetcode之hot100---160相交链表(C++)
  • MFC/C++学习系列之简单记录2——thread和Release
  • 【服务器】MyBatis是如何在java中使用并进行分页的?
  • 中阳科技的量化交易模型:从理论到实践的全面探索
  • 1688跨境代购代采:API赋能的自动化与信息化革新
  • 【NLP 18、新词发现和TF·IDF】
  • git中的tag标签远程管理