关于决策树的一些介绍
在这篇文章中我将介绍机器学习中的决策树,我将介绍它的概念、如何构造、决策树的分类、应用以及如何用python实现。
一、 概念
关于决策树的概念我首先给出一棵决策树,而下图就是一棵决策树:
它展示了根据国籍与所属时代来区分四位我十分喜欢的作家:马尔克斯,博尔赫斯,纪德以及安妮埃尔诺。(当然,起初我想举中国作家鲁迅和余华,但在显示时中文会出现乱码,所以我索性就改为拉美的两位作家了)
在这张图里,树的每个节点都代表一个问题或是包含某一问题的终结点(也叫叶节点),树的边则将问题的答案与下一问题连接起来。
在我的代码中,我的第一个节点与二、三节点间的边上是“is”,而非说具体的某一答案,这其实也是可以的,因为决策树的设计可以具有一定的灵活性,只要能够展示你的逻辑就好。
以下是它的代码:
import graphviz
import matplotlib.pyplot as plt
from imageio import imread
ax = plt.gca()
mygraph = graphviz.Digraph(node_attr={'shape':'box'},
edge_attr={'labeldistance':'10.5'},
format="png")
mygraph.node("0","nationality?")
mygraph.node("1","Latin America?")
mygraph.node("2","France?")
mygraph.node("3", "Luis Borges")
mygraph.node("4", "García Márquez")
mygraph.node("5","André Gide")
mygraph.node("6","Annie Ernaux")
mygraph.edge("0","1",label="is")
mygraph.edge("0","2",label="is")
mygraph.edge("1","3",label="modern period")
mygraph.edge("2","5",label="modern period")
mygraph.edge("1","4",label="contemporary")
mygraph.edge("2","6",label="contemporary")
mygraph.render("writers")
ax.imshow(imread("writers.png"))
ax.set_axis_off()
plt.show()
在看了这个图后,对于决策树我们大致就了解了,可以说它由三个组成要素,分别是节点,边与分支。其中关于节点又分为三种,内部节点,根节点与叶子节点。另外,关于分支,它可以理解为根节点到叶节点的路径称为分支或路径,每个分支代表一个决策序列。
二、 构造
构造决策树就是学习一系列的if/else问题,使得我们可以以最快的速度得到正确答案。这些问题叫做测试,但它与测试集并非同一概念,测试集是用于测试模型的泛化性能的。
2.1 选择特征
在构造时,首先我们要选择特征,用一个最佳特征来进行分割,常见的选择方法包括信息增益(Information Gain)、增益率(Gain Ratio)、基尼指数(Gini Index)等。
2.2 划分数据集
在选择完特征后,我们就要划分数据集,将其划分为若干个子集。当然,在实际应用时,大概率不会像我举得例子那样用什么国籍,年代来划分作家,而是对数据进行划分,就像如下这样:
其中根节点表示整个数据集。
2.3 递归构造子树
接下来,对每个子集重复上述过程,递归地构建子树,直到满足停止条件(例如,所有样本属于同一类,或者达到预设的最大深度等)
如果树中某个叶节点所包含的数据点的目标值都相同,那么我们就说这个叶节点是纯的(pure)。
2.4 剪枝
通常的,如果我们要构造一棵决策树直到其所有的叶节点都是纯的叶节点,那么这样就会导致模型变得非常复杂,并且容易出现过拟合。所以我们对于防止过拟合会有两种常见的策略,一种是及早停止树的生长,叫预剪枝(pre-pruning);另一种是先构造树,但在随后删除或折叠那些信息量很少的节点,叫后剪枝(post-pruning)。不过,在sklearn函数库中,只实现了预剪枝而没有实现后剪枝。
三、 应用
关于决策树的应用,它有很多方面,首先可以用于进行分类任务,既可以是二分类亦可以是多分类,比如对于一些病症、欺诈问题上的一些分类,在我自习用的其中一本书上就是用乳腺癌来展示决策树的;此外,决策树还可以用来解决回归上的问题,比如房价预测、销售预测什么的;然后它在强化学习上也是有所应用,比如某些游戏的AI以及作为机器人的导航等。
四、决策树的类型
决策树主要可以分为两种类型,分类树以及回归树。
4.1 分类树
使用分类树时,目标变量是一个离散型的类型标签,而目标变量通常是名义型(categorical)或有序型(ordinal)变量,如“是”或“否”、“良性”或“恶性”等。
4.2 回归树
使用回归树时,目标变量是一个连续的值,而目标变量也通常是连续的数值,比如“温度”等。
五、python的实现
下面是决策树的具体实现,其中我使用了预剪枝来防止过拟合:
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
# 生成一个简单的二分类数据集
X, y = make_classification(n_samples=100, n_features=4, n_informative=2, n_redundant=0, random_state=42)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 创建决策树分类器,设置最大深度为3
clf_pre_pruned = DecisionTreeClassifier(max_depth=3, min_samples_leaf=5)
# 训练模型
clf_pre_pruned.fit(X_train, y_train)
# 预测
y_pred_pre_pruned = clf_pre_pruned.predict(X_test)
# 评估
accuracy_pre_pruned = accuracy_score(y_test, y_pred_pre_pruned)
print(f"Accuracy (Pre-pruned): {accuracy_pre_pruned}")
# 可视化决策树
plt.figure(figsize=(10, 5))
plot_tree(clf_pre_pruned, filled=True, feature_names=["Feature 1", "Feature 2", "Feature 3", "Feature 4"])
plt.show()
而那颗决策树展示出来为这样:
此上