构建ID3决策树的算法代码 核心部分详细讲解
# ID3 算法类
class ID3Tree:
# 定义决策树节点类
class TreeNode:
# 定义树节点
def __init__(self, name):
self.name = name
self.connections = {}
# 定义树的连接
def connect(self, label, node):
self.connections[label] = node
# 定义参数变量,包括数据集、特征集、标签和根结点
def __init__(self, df, label):
self.columns = df.columns
self.df = df
self.label = label
self.root = self.TreeNode("Root")
# 构建树的调用
def construct_tree(self):
self.construct(self.root, "", self.df, self.columns)
# 树的构建递归方法
def construct(self, parent_node, parent_label, sub_df, columns):
# 选择最优特征
max_value, best_feature, max_splited = choose_best_feature(sub_df[columns], self.label)
# 如果找不到最优划分特征,则构造单纯节点
if not best_feature:
node = self.TreeNode(sub_df[self.label].iloc[0])
parent_node.connect(parent_label, node)
return
# 根据最优特征值以及子集构建树
node = self.TreeNode(best_feature)
parent_node.connect(parent_label, node)
# 'A-B' 分的特征的特征集
new_columns = [col for col in columns if col != best_feature]
# 遍历构建连接树
for splited_value, splited_data in max_splited.items():
self.construct(node, splited_value, splited_data, new_columns)
ID3Tree
类 负责构建决策树。TreeNode
子类 定义了树的每个节点。construct_tree
方法 调用递归方法construct
开始构建树。construct
方法 是递归的核心,用于选择最优特征构建子树,并递归处理每一个子集,直到叶节点(分类结果)。
这段代码实现了一个基于 ID3 算法 的决策树构建过程。它定义了一个 ID3Tree
类,并使用递归方式构建决策树。下面我们来详细解释这段代码的每一部分:
1. TreeNode 类 - 决策树节点的定义
class TreeNode:
def __init__(self, name):
self.name = name
self.connections = {}
def connect(self, label, node):
self.connections[label] = node
TreeNode
类 用于定义决策树的节点。- 属性:
name
: 节点的名称或表示的特征名称。connections
: 存储与当前节点相连的子节点。connections
是一个字典,其中键是连接的分支(通常是某个特征的取值),值是子节点(TreeNode
的实例)。
__init__
方法:初始化节点时为其赋予名称,并创建一个空的连接字典。connect
方法:用于连接当前节点与子节点。这个方法将子节点与对应的分支(label
)相连。
2. ID3Tree 类 - ID3 决策树的实现
class ID3Tree:
def __init__(self, df, label):
self.columns = df.columns
self.df = df
self.label = label
self.root = self.TreeNode("Root")
-
ID3Tree
类 是 ID3 算法的主体,用于构建决策树。 -
属性:
columns
: 数据集中所有的特征列。df
: 输入的数据集(通常是一个Pandas DataFrame
)。label
: 数据集中的标签列,用于标注样本的分类。root
: 树的根节点,初始值为"Root"
节点。
-
__init__
方法:初始化决策树时,接收数据集df
和标签列label
。它将根节点设为"Root"
,并存储特征和数据集。
3. 构建树的入口方法
def construct_tree(self):
self.construct(self.root, "", self.df, self.columns)
construct_tree
方法 是构建树的入口方法。它调用递归方法construct
来开始构建决策树。- 初始时,传入的参数是根节点
self.root
、空的parent_label
(即根节点的父节点为空)、完整的数据集self.df
和所有特征self.columns
。
4. 核心构建方法 construct
def construct(self, parent_node, parent_label, sub_df, columns):
construct
方法 是递归构建决策树的核心方法。它通过选择最佳的划分特征,将数据集不断划分为子集,并为每个子集创建新的节点。- 参数:
parent_node
: 当前节点的父节点。parent_label
: 当前节点对应的父节点的分支标签(即从父节点连接过来的特征值)。sub_df
: 当前递归操作的数据子集。columns
: 当前子集使用的特征列。
5. 选择最佳划分特征
max_value, best_feature, max_splited = choose_best_feature(sub_df[columns], self.label)
choose_best_feature
是选择最佳划分特征的函数。根据输入的数据集和标签列,选择能够最大化信息增益的特征。max_value
: 该特征的最大信息增益值。best_feature
: 最佳划分特征的名称。max_splited
: 使用该特征划分后得到的子集(根据特征的取值划分的子集)。
6. 判断是否无法再继续划分
if not best_feature:
node = self.TreeNode(sub_df[self.label].iloc[0])
parent_node.connect(parent_label, node)
return
- 如果没有可以继续划分的特征,就创建一个叶节点,叶节点的名称为当前子集中的第一个标签值(
sub_df[self.label].iloc[0]
)。此时,这个叶节点表示最终的分类结果(即数据集已经被分得很纯了)。 - 然后将该叶节点与父节点连接,并结束当前递归。
7. 为最佳划分特征创建节点并连接到父节点
node = self.TreeNode(best_feature)
parent_node.connect(parent_label, node)
- 创建新的节点:根据最佳划分特征
best_feature
创建一个新的决策树节点。 - 连接节点:将新创建的节点连接到父节点,连接标签为
parent_label
(即父节点的分支)。
8. 更新特征列,排除已经使用的特征
new_columns = [col for col in columns if col != best_feature]
- 将已经使用的特征从特征列中排除。
new_columns
是剩下的未被使用的特征列。
9. 递归处理每个子集
for splited_value, splited_data in max_splited.items():
self.construct(node, splited_value, splited_data, new_columns)
- 对于
max_splited
中的每一个子集(即最佳特征划分后的子集),调用construct
方法递归处理:node
:当前创建的节点作为父节点。splited_value
:划分特征的取值(即这个分支的标签)。splited_data
:子集数据。new_columns
:剩余的特征列。
总结
- 这段代码实现了 ID3 算法构建决策树的逻辑。
- 核心思想是:通过递归地选择信息增益最大的特征,将数据集划分成更纯的子集,最终构建一棵决策树。
- TreeNode 类负责表示决策树的每一个节点,ID3Tree 类负责管理和构建决策树。
- 通过
construct
方法,程序递归地构建每一个节点,并通过连接节点构建整棵决策树。