当前位置: 首页 > article >正文

机器学习(三)——决策树(附核心思想、重要算法、概念(信息熵、基尼指数、剪枝处理)及Python源码)

目录

  • 关于
  • 1 基本流程
  • 2 划分属性的选择
    • 2.1 方法一:依据信息增益选择
    • 2.2 方法二:依据增益率选择
    • 2.3 方法三:依据基尼指数选择
  • 3 剪枝处理:防止过拟合
    • 3.1 预剪枝
    • 3.2 后剪枝
  • 4 连续与缺失值
    • 4.1 连续值处理
    • 4.2 缺失值处理
  • 5 多变量决策树
  • X 案例代码
    • X.1 分类任务
      • X.1.1 源码
      • X.1.2 数据集 (鸢尾花数据集)
      • X.1.3 模型效果
    • X.2 回归任务
      • X.2.1 源码
      • X.2.2 数据集(糖尿病数据集)
      • X.2.3 模型效果


关于

  • 本文是基于西瓜书(第四章)的学习记录。讲解决策树的重要概念(划分属性的选择、剪枝处理、连续值和缺失值的处理、多变量决策树等)、核心流程,附Python分类和回归实现代码。
  • 西瓜书电子版:百度网盘分享链接

1 基本流程

决策树是一种模仿人类决策过程的机器学习算法,它通过树状结构来进行决策。

  • 决策树结构
    • 根节点:每个结点包含的样本集合根据属性测试的结果被划分到子结点中,根结点包含样本全集。
    • 内部节点:代表属性测试。
    • 叶节点:对应决策结果。
  • 决策过程:从根节点开始,通过一系列的测试(属性值的判断)到达叶节点,完成决策。
  • 算法流程
    在这里插入图片描述
    决策树的生成是一个递归过程,有三种情形会导致递归返回:
    • 当前结点包含的样本全属于同一类别,无需划分;
    • 当前属性集为空,或是所有样本在所有属性上取值相同,无法划分;
    • 当前结点包含的样本集合为空,不能划分;

2 划分属性的选择

  • 在决策树学习中,选择最优的划分属性是关键,这有助于提高样本集合的纯度,即分支结点所包含的样本尽可能属于同一类别。
  • 下面是三种常见的划分属性选择方法:

2.1 方法一:依据信息增益选择

  • 简述

    • 首先计算数据集D的信息熵;
    • 然后对于每个属性a,考虑将其作为划分属性划分后得到的各个子集,计算各子集的信息熵;
    • 然后利用下面的信息增益公式计算依据属性a划分的信息增益;
    • 最后选择增益最大的属性进行划分然后进行下一轮选择(如果不结束)。
  • 重要概念

    • 信息熵:衡量样本集合D纯度的指标,计算公式为
      E n t ( D ) = − ∑ k = 1 ∣ D ∣ p k log ⁡ 2 p k Ent(D) = -\sum_{k=1}^{|D|} p_k \log_2 p_k Ent(D)=k=1Dpklog2pk其中 p k p_k pk是第k类样本所占的比例。 E n t ( D ) Ent(D) Ent(D)值越小,纯度越高
    • 信息增益:使用属性a进行划分所获得的纯度提升,计算公式为
      G a i n ( D , a ) = E n t ( D ) − ∑ v = 1 V ∣ D v ∣ ∣ D ∣ E n t ( D v ) Gain(D, a) = Ent(D) - \sum_{v=1}^{V} \frac{|D_v|}{|D|} Ent(D_v) Gain(D,a)=Ent(D)v=1VDDvEnt(Dv)其中 D v D_v Dv是在属性a上取值为v的样本子集。
    • ID3算法:以信息增益为准则选择划分属性,是决策树学习的经典算法之一。
  • 案例
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述

2.2 方法二:依据增益率选择

  • 简述:方法一的信息增益准则对可取值数目较多的属性有所偏好,所以考虑不直接使用信息增益,而是使用“增益率”来选择最优划分属性。其实就在每个属性的信息增益基础上除以 I V ( a ) IV(a) IV(a),得到的结果就是增益率,基于此选择(但是增益率又会偏好可取值数目少的属性,所以先从候选划分属性中找出信息增益高于平均水平的属性,再从中选择增益率最高的)。
  • 重要概念
    • 增益率:减少信息增益对可取值数目较多属性的偏好,计算公式为 G a i n R a t i o ( D , a ) = G a i n ( D , a ) I V ( a ) GainRatio(D, a) = \frac{Gain(D, a)}{IV(a)} GainRatio(D,a)=IV(a)Gain(D,a)其中 I V ( a ) IV(a) IV(a)是属性a的固有值,计算公式如下。
    • I V ( a ) IV(a) IV(a):这是属性a的固有值,a的可能取值越多,该值越大,起计算公式如下:
      I V ( a ) = − ∑ v = 1 V ∣ D v ∣ ∣ D ∣ log ⁡ 2 ∣ D v ∣ ∣ D ∣ \mathrm{IV}(a)=-\sum_{v=1}^V\frac{|D^v|}{|D|}\log_2\frac{|D^v|}{|D|} IV(a)=v=1VDDvlog2DDv
    • C4.5算法:使用增益率选择最优划分属性,是ID3算法的改进版。

2.3 方法三:依据基尼指数选择

  • 简述:流程和方法一一致,但是选择的指标是基尼指数
  • 重要概念
    • 基尼指数:度量数据集纯度的指标(反映了从数据集D 中随机抽取两个样本,其类别标记不一致的概率.因此,Gini(D)越小,则数据集D 的纯度越高),计算公式为 G i n i ( D ) = 1 − ∑ k = 1 ∣ D ∣ p k 2 Gini(D) = 1 - \sum_{k=1}^{|D|} p_k^2 Gini(D)=1k=1Dpk2
      属性a划分后的基尼指数计算公式为: G i n i I n d e x ( D , a ) = ∑ v = 1 V ∣ D v ∣ ∣ D ∣ G i n i ( D v ) . \mathrm{GiniIndex}(D,a)=\sum_{v=1}^V\frac{|D^v|}{|D|}\mathrm{Gini}(D^v) . GiniIndex(D,a)=v=1VDDvGini(Dv).
    • CART算法:使用基尼指数选择划分属性,适用于分类和回归任务

3 剪枝处理:防止过拟合

  • 剪枝是决策树学习中对付过拟合的主要手段,通过剪枝可以降低过拟合的风险。剪枝策略有预剪枝后剪枝两种

3.1 预剪枝

在这里插入图片描述

  • 预剪枝:在生成过程中,对每个结点在划分前先进行估计,若当前结点的划分不能带来决策树泛化性能提升,则停止划分
  • 优点:这不仅降低了过拟合的风险,还显著减少了决策树的训练时间开销和测试时间开销
  • 缺点:预剪枝基于“贪心”本质禁止这些分支展开,给预剪枝决策树带来了欠拟合的风险
  • 决策树桩:只有一层划分的决策树

3.2 后剪枝

在这里插入图片描述

  • 后剪枝:先从训练集生成一棵完整的决策树,然后自底向上地对非叶结点进行考察,若将该结点对应的子树替换为叶结点能带来决策树泛化性能提升,则将该子树替换为叶结点。
  • 后剪枝决策树通常比预剪枝决策树保留了更多的分支.
  • 优点:后剪枝决策树的欠拟合风险很小,泛化性能往往优于预剪枝决策树
  • 缺点:但后剪枝过程是在生成完全决策树之后进行的,并且要自底向上地对树中的所有非叶结点进行逐一考察,因此其训练时间开销比未剪枝决策树和预剪枝决策树都要大得多.

4 连续与缺失值

4.1 连续值处理

  • 必要性:现实学习任务中常会遇到连续属性,有必要讨论如何在决策树学习中使用连续属性.
  • 连续属性离散化:使用二分法处理连续属性,按照某个划分点的值将连续属性的值划分为两个子集,选择最优的划分点以最大化信息增益。
  • 划分点的选择:对于包含n个不同值的某个属性,可以得到n-1个划分点,根据划分点可以将取值范围划分为两个区间
  • 核心逻辑:得到某个属性的划分点集合之后就可以按照信息增益的大小选择出最合适的划分点构造一个划分节点。
  • 连续数据划分点和离散数据的区别:与离散属性不同,若当前结点划分属性为连续属性,该属性还可作为其后代结点的划分属性
  • 示例:
    在这里插入图片描述

4.2 缺失值处理

  • 必要性:一部分样本得属性缺少值,直接丢弃会造成信息丢失
  • 要解决的问题:
    • 如何在属性值缺失的情况下进行划分属性选择给定划分属性?——将前文的信息增益公式推广后执行类似流程。
    • 若样本在该属性上的值缺失,如何对样本进行划分?——同时划分到子节点,不过进行权重调整,直观地看就是将同一个样本以不同的概率划分到不同的子节点去
  • 权重调整:对缺失值样本进行权重调整后进行划分,确保模型能够从不完整样本中学习。

5 多变量决策树

  • 背景:

    • 分类意味着找到分类边界
    • 决策树分类边界的特点:由若干个与坐标轴平行的分段组成
      • 优点:可解释性强
      • 缺点:需要很多段,学习代价大
  • 多变量决策树

    • 优点:能够实现斜划分(使用属性的线性组合进行测试,允许分类边界不是轴平行的),简化模型复杂度。
    • 改变:在非叶节点上实现斜划分,不再是寻找一个最优划分属性,而是试图建立一个合适的线性分类器
      在这里插入图片描述

X 案例代码

X.1 分类任务

X.1.1 源码

import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
import seaborn as sns

# 1. 加载数据集
iris = load_iris()
X, y = iris.data, iris.target
print("此时X,y的数据类型为:", type(X), type(y), '\n')

# 2. 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print("此时X_train,y_train的数据类型为:", type(X_train), type(y_train), '\n')
print("X_train的前10条数据展示:")
print(pd.DataFrame(X_train).head(10).to_string(index=False, justify='left'), '\n')

# 3. 构建并训练决策树分类模型
model = DecisionTreeClassifier(random_state=42)
model.fit(X_train, y_train)

# 4. 预测测试集上的目标变量
y_pred = model.predict(X_test)

# 5. 评估模型性能
accuracy = accuracy_score(y_test, y_pred)
print("模型准确率:", accuracy)

print("分类报告:")
print(classification_report(y_test, y_pred))

# 6. 绘制混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=iris.target_names, yticklabels=iris.target_names)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix for Iris Dataset')
plt.tight_layout()
plt.show()

# 可选:将结果保存到DataFrame中以便进一步分析
results = pd.DataFrame({
    'Actual': y_test,
    'Predicted': y_pred
})
print("模型预测结果:")
print(results.head())

X.1.2 数据集 (鸢尾花数据集)

  • 鸢尾花数据集是机器学习领域中最著名的数据集之一,常被用于分类算法的测试和演示。

  • 概览

    • 样本数量:150个样本
    • 特征数量:4个特征
    • 标签种类数量:3个类别,每个类别有50个样本
  • 特征描述

    • 萼片长度 (sepal length):花萼的长度,单位为厘米。
    • 萼片宽度 (sepal width):花萼的宽度,单位为厘米。
    • 花瓣长度 (petal length):花瓣的长度,单位为厘米。
    • 花瓣宽度 (petal width):花瓣的宽度,单位为厘米。
  • 目标变量是鸢尾花的种类,共有三种:

    1. Iris setosa
    2. Iris versicolor
    3. Iris virginica
  • 使用

    • 可以使用 sklearn.datasets.load_iris() 函数来加载这个数据集,并查看其详细信息。

X.1.3 模型效果

在这里插入图片描述
在这里插入图片描述

X.2 回归任务

X.2.1 源码

import pandas as pd
from matplotlib import pyplot as plt
from sklearn import datasets
from sklearn.tree import DecisionTreeRegressor, export_text
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score

# 加载糖尿病数据集
diabetes = datasets.load_diabetes()

# 提取特征和标签
X = diabetes.data
y = diabetes.target
print("此时X,y的数据类型为:", type(X), type(y), '\n')

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print("此时X_train,y_train的数据类型为:", type(X_train), type(y_train), '\n')
print("X_train的前10条数据展示:")
print(pd.DataFrame(X_train).head(10).to_string(index=False, justify='left'), '\n')

# 创建决策树回归模型
regressor = DecisionTreeRegressor(random_state=42)

# 训练模型
regressor.fit(X_train, y_train)

# 进行预测
y_pred = regressor.predict(X_test)

# 计算均方误差
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print(f"均方误差(MSE): {mse}")
print(f"决定系数(R^2): {r2}", '\n')

# 查看决策树的结构
tree_rules = export_text(regressor, feature_names=diabetes.feature_names)
print(tree_rules)

# 绘制实际值和预测值的折线图
plt.figure(figsize=(12, 6))
plt.plot(y_test, label='Actual', marker='o', color='blue')
plt.plot(y_pred, label='Predicted', marker='x', color='red', linestyle='--')
plt.title('Actual vs Predicted Values')
plt.xlabel('Sample Index')
plt.ylabel('Target Value')
plt.legend()
plt.tight_layout()
plt.show()

X.2.2 数据集(糖尿病数据集)

  • 糖尿病数据集包含442名患者的10项生理特征,目标是预测一年后疾病水平的定量测量值。这些特征经过了标准化处理,使得每个特征的平均值为零,标准差为1。

  • 概览

    • 样本数量:442个样本
    • 特征数量:10个特征
    • 目标变量:1个目标变量(一年后疾病水平的定量测量值)
  • 特征描述

    1. 年龄 (age):患者年龄(已标准化)
    2. 性别 (sex):患者性别(已标准化)
    3. 体质指数 (bmi):身体质量指数(已标准化)
    4. 血压 (bp):平均动脉压(已标准化)
    5. S1:血清测量值1(已标准化)
    6. S2:血清测量值2(已标准化)
    7. S3:血清测量值3(已标准化)
    8. S4:血清测量值4(已标准化)
    9. S5:血清测量值5(已标准化)
    10. S6:血清测量值6(已标准化)
  • 目标变量

    • 一年后疾病水平的定量测量值:这是模型需要预测的目标变量。
  • 使用

    • 可以使用 sklearn.datasets.load_diabetes() 函数来加载这个数据集,并查看其详细信息。

X.2.3 模型效果

在这里插入图片描述

在这里插入图片描述


http://www.kler.cn/a/383061.html

相关文章:

  • 服务器开放了mongodb数据库的外网端口,但是用mongodbCompass还是无法连接。
  • 使用 Python 和 OpenCV 实现实时人脸识别
  • Java+Swing可视化图像处理软件
  • 【昇腾】从单机单卡到单机多卡训练
  • 【ONLYOFFICE文档】8.2版本测评
  • Kotlin的内置函数
  • Flutter UI构建渲染(4)
  • Windows10/11下python脚本自动连接WiFi热点
  • STM32启动文件分析
  • Axure是什么软件?全方位解读助力设计入门
  • 实践是认识的来源
  • GPU的内存是什么?
  • 继承——面向对象编程的基石
  • 【C++】lambda表达式的理解与运用(C++11新特性)
  • [C++ 核心编程]笔记 4.4.2 类做友元
  • 【Vue 2.x】之指令详解
  • Nat Med 病理AI系列|人工智能在肝病临床试验中的应用·顶刊精析·24-11-06
  • QT开发:掌握现代UI动画技术:深入解析QML和Qt Quick中的动画效果
  • 用PyQt 5 开发的雷达基数据可视化软件
  • 关于c指针的一些说明
  • 第2篇 使用Intel FPGA Monitor Program创建基于ARM处理器的汇编或C语言工程<二>
  • 【5.10】指针算法-快慢指针将有序链表转二叉搜索树
  • Java项目实战II基于Spring Boot的问卷调查系统的设计与实现(开发文档+数据库+源码)
  • Linux 文件基本属性
  • SQL Server 日志记录
  • linux arm板启动时间同步服务