当前位置: 首页 > article >正文

决策树:机器学习中的分类与回归利器

决策树(Decision Tree)是机器学习领域里极为常用的算法,在分类和回归问题中都有广泛应用。它借助树状结构来呈现决策流程,每个内部节点代表对一个特征或属性的测试,分支是测试结果,而叶节点则对应最终的类别或值 。

决策树的基本概念

节点(Node)

树中的每一个点都叫做节点。其中,根节点是整棵树的起始点,内部节点用于做出决策,叶节点则代表最终的决策成果 。

分支(Branch)

从一个节点延伸到另一个节点的路径,我们就称其为分支。

分裂(Split)

按照某个特征,把数据集划分成多个子集的过程,这就是分裂。

纯度(Purity)

纯度用于衡量一个子集中样本的类别是否一致。纯度越高,意味着子集中的样本越相似 。

决策树的工作原理

决策树通过递归方式将数据集分割成更小的子集,进而构建起树结构,具体步骤如下:

  1. 选择最佳特征:依据信息增益、基尼指数等标准,挑选出用于分割的最佳特征。
  2. 分割数据集:按照选定的特征,把数据集划分成多个子集。
  3. 递归构建子树:对每个子集重复上述操作,直到满足停止条件,比如所有样本都属于同一类别,或者达到了最大深度。
  4. 生成叶节点:一旦满足停止条件,就生成叶节点,并赋予其相应的类别或值。

决策树的构建标准

构建决策树时,挑选最佳特征进行分割至关重要,常用标准如下:

信息增益(Information Gain)

主要用于分类问题,用于衡量选择某一特征后数据集纯度的提升程度。计算公式为 [此处应补充具体公式,原文未给出] ,其中 Entropy 是数据集的熵,熵主要用于衡量数据的不确定性 。

基尼指数(Gini Index)

同样用于分类问题的分裂标准,计算公式为 [此处应补充具体公式,原文未给出] ,这里的 pi 是类别 i 的样本占比。基尼指数越小,表明数据集越纯净。

均方误差(MSE)

在回归问题中使用,用于衡量预测值与真实值之间的差异。MSE 越小,说明回归树的预测效果越好。

决策树的优缺点

优点

  1. 易于理解和解释:决策树的结构非常直观,普通人也能轻松理解和解释其决策过程。
  2. 处理多种数据类型:不管是数值型数据还是类别型数据,决策树都能很好地处理。
  3. 不需要数据标准化:决策树在运行时,无需对数据进行标准化或归一化处理。

缺点

  1. 容易过拟合:尤其是在数据集较小或者树的深度较大时,决策树很容易出现过拟合现象。
  2. 对噪声敏感:决策树对噪声数据较为敏感,一旦数据中有噪声,可能会导致模型性能下降。
  3. 不稳定:数据哪怕只有微小变化,都可能生成完全不同的决策树。

用 Python 实现决策树

安装必要的库

首先,要确保安装了 scikit - learn 库。若未安装,在命令行输入 pip install scikit - learn 即可完成安装。

导入库并加载数据集

我们以 scikit - learn 自带的鸢尾花(Iris)数据集为例,来演示决策树的使用方法 ,代码如下:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target

# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

训练决策树模型

接下来,使用 DecisionTreeClassifier 训练决策树模型,代码如下:

# 创建决策树分类器
clf = DecisionTreeClassifier()

# 训练模型
clf.fit(X_train, y_train)

预测与评估

用训练好的模型对测试集进行预测,并评估模型的准确率,代码如下:

# 对测试集进行预测
y_pred = clf.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy:.2f}")

运行后,输出结果为:模型准确率: 1.00 。

可视化决策树

为了更直观地理解决策树的结构,我们可以借助 graphviz 库来实现决策树的可视化。

  1. 安装 graphviz
    • Windows 平台:可从Download | Graphviz 下载适用于 Windows 的安装包(.msi 文件)。
    • Linux 平台:在命令行输入 apt install graphviz 即可安装。
    • macOS 平台:在命令行输入 brew install graphviz 进行安装。

也可进行源码安装,先下载最新的源码包(.tar.gz 文件),然后依次执行以下命令:

tar -zxvf graphviz - <version>.tar.gz
cd graphviz - <version>
./configure
make
sudo make install

安装完成后,在命令行输入 dot -V 验证是否安装成功。若安装成功,会输出类似 “dot - graphviz version 12.2.1 (20241206.2353)” 的内容。
2. 安装 graphviz 库:在命令行输入 pip install graphviz 。
3. 生成决策树可视化图

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn.tree import export_graphviz
import graphviz


# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target

# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建决策树分类器
clf = DecisionTreeClassifier()

# 训练模型
clf.fit(X_train, y_train)

# 对测试集进行预测
y_pred = clf.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy:.2f}")

# 导出决策树为dot文件
dot_data = export_graphviz(clf, out_file=None,
                           feature_names=iris.feature_names,  
                           class_names=iris.target_names,  
                           filled=True, rounded=True,  
                           special_characters=True)

# 使用graphviz渲染决策树
graph = graphviz.Source(dot_data)
graph.render("iris_decision_tree")  # 保存为PDF文件
graph.view()  # 在浏览器中查看

执行上述代码后,会生成一个 iris_decision_tree.pdf 文件,通过它能清晰看到决策树的结构 。决策树作为一种强大的机器学习算法,在诸多领域发挥着重要作用,希望通过本文能让大家对它有更深入的理解和掌握 。

 


http://www.kler.cn/a/566931.html

相关文章:

  • LabVIEW 无法播放 AVI 视频的编解码器解决方案
  • Unclutter for Mac v2.2.12 剪贴板/文件暂存/笔记三合一 支持M、Intel芯片
  • jenkins使用插件在Build History打印基本信息
  • DeepSeek 开源周五个开源项目,引领 AI 创新?
  • leetcode---LCR 123.图书整理1
  • LabVIEW中交叉关联算法
  • ‘ts-node‘ 不是内部或外部命令,也不是可运行的程序
  • vue3中展示markdown格式文章的三种形式
  • 阿里云oss文件上传springboot若依java
  • 25新闻研究生复试面试问题汇总 新闻专业知识问题很全! 新闻复试全流程攻略 新闻考研复试调剂真题总结
  • 深度解读 AMS1117:从电气参数到应用电路的全面剖析
  • day02_Java基础
  • 网络安全技术与应用
  • C++题解(31) 2025顺德一中少科院信息学创新班(四期)考核:U537296 青蛙的距离 题解
  • Tomcat的server.xml配置详解
  • Tomcat10下载安装教程
  • ssh配置 远程控制 远程协作 github本地配置
  • 量子计算 + 药物开发:打开分子模拟的新纪元
  • java面试笔记(二)
  • 版图自动化连接算法开发 00002 ------ 添加一个中间点实现 Manhattan 方式连接两个给定的坐标点