当前位置: 首页 > article >正文

随机森林(Random Forest)算法Python代码实现

class RandomForest():
    def __init__(self, n_estimators=100, min_samples_split=2, min_gain=0,
                 max_depth=float("inf"), max_features=None):
        # 树的棵树
        self.n_estimators = n_estimators
        # 树最小分裂样本数
        self.min_samples_split = min_samples_split
        # 最小增益
        self.min_gain = min_gain
        # 树最大深度
        self.max_depth = max_depth
        # 所使用最大特征数
        self.max_features = max_features

        self.trees = []
        # 基于决策树构建森林
        for _ in range(self.n_estimators):
            tree = ClassificationTree(min_samples_split=self.min_samples_split, min_gini_impurity=self.min_gain,
                                      max_depth=self.max_depth)
            self.trees.append(tree)
            
    # 自助抽样
    def bootstrap_sampling(self, X, y):
        X_y = np.concatenate([X, y.reshape(-1,1)], axis=1)
        np.random.shuffle(X_y)
        n_samples = X.shape[0]
        sampling_subsets = []

        for _ in range(self.n_estimators):
            # 第一个随机性,行抽样
            idx1 = np.random.choice(n_samples, n_samples, replace=True)
            bootstrap_Xy = X_y[idx1, :]
            bootstrap_X = bootstrap_Xy[:, :-1]
            bootstrap_y = bootstrap_Xy[:, -1]
            sampling_subsets.append([bootstrap_X, bootstrap_y])
        return sampling_subsets
            
    # 随机森林训练
    def fit(self, X, y):
        # 对森林中每棵树训练一个双随机抽样子集
        sub_sets = self.bootstrap_sampling(X, y)
        n_features = X.shape[1]
        # 设置max_feature
        if self.max_features == None:
            self.max_features = int(np.sqrt(n_features))
        
        for i in range(self.n_estimators):
            # 第二个随机性,列抽样
            sub_X, sub_y = sub_sets[i]
            idx2 = np.random.choice(n_features, self.max_features, replace=True)
            sub_X = sub_X[:, idx2]
            self.trees[i].fit(sub_X, sub_y)
            # 保存每次列抽样的列索引,方便预测时每棵树调用
            self.trees[i].feature_indices = idx2
            print('The {}th tree is trained done...'.format(i+1))
    
    # 随机森林预测
    def predict(self, X):
        y_preds = []
        for i in range(self.n_estimators):
            idx = self.trees[i].feature_indices
            sub_X = X[:, idx]
            y_pred = self.trees[i].predict(sub_X)
            y_preds.append(y_pred)
            
        y_preds = np.array(y_preds).T
        res = []
        for j in y_preds:
            res.append(np.bincount(j.astype('int')).argmax())
        return res

这段代码实现了一个简化版的 随机森林(Random Forest)模型。代码的主要功能是通过集成多棵决策树来进行分类。每棵树是通过自助采样(bootstrap sampling)和特征随机抽样(feature bagging)来训练的。以下是对每个部分的详细解析:

1. __init__ 方法

这个方法是 RandomForest 类的构造函数,它定义了随机森林的参数,并初始化了森林中的每棵决策树。

def __init__(self, n_estimators=100, min_samples_split=2, min_gain=0,
             max_depth=float("inf"), max_features=None):
  • n_estimators: 决策树的数量,也就是森林中的树的个数。
  • min_samples_split: 每棵树在进行分裂时,节点中最小的样本数。这个参数用于避免过拟合。
  • min_gain: 每次分裂的最小信息增益,当增益小于该值时停止分裂。
  • max_depth: 每棵树的最大深度,用于避免过拟合。
  • max_features: 每次划分时考虑的最大特征数。如果为 None,则默认使用特征总数的平方根。

然后,构造函数初始化一个空列表 self.trees 用来存储森林中的每棵树,并基于这些参数实例化决策树(假设 ClassificationTree 是一个已定义的决策树类)。

2. bootstrap_sampling 方法

这个方法实现了 自助抽样(Bootstrap Sampling),即每棵决策树会从训练数据中随机抽取样本来训练,且允许重复抽样。

def bootstrap_sampling(self, X, y):
    X_y = np.concatenate([X, y.reshape(-1,1)], axis=1)
    np.random.shuffle(X_y)
    n_samples = X.shape[0]
    sampling_subsets = []
  • X_y 是将特征矩阵 X 和标签数组 y 合并后的数据集。
  • np.random.shuffle(X_y) 用于将数据集打乱顺序。
  • 通过 np.random.choice 从数据集中随机抽取样本(带放回的抽样),生成每棵树的训练子集。
  • 结果是返回一个包含多个训练子集的列表 sampling_subsets,每个子集包含特征和标签。

3. fit 方法

fit 方法用于训练随机森林。它首先生成用于训练的子集,然后训练每棵树。

def fit(self, X, y):
    sub_sets = self.bootstrap_sampling(X, y)
    n_features = X.shape[1]
    if self.max_features == None:
        self.max_features = int(np.sqrt(n_features))
    
    for i in range(self.n_estimators):
        sub_X, sub_y = sub_sets[i]
        idx2 = np.random.choice(n_features, self.max_features, replace=True)
        sub_X = sub_X[:, idx2]
        self.trees[i].fit(sub_X, sub_y)
        self.trees[i].feature_indices = idx2
  • 通过 self.bootstrap_sampling(X, y) 获得自助抽样的子集。
  • self.max_features 默认为特征数量的平方根,表示每次分裂时考虑的最大特征数。
  • 对于每棵树,先随机选择一个特征子集(通过 np.random.choice),然后使用这个子集训练决策树。
  • 每棵树的 feature_indices 属性保存了它使用的特征的索引,方便在预测时使用。

4. predict 方法

predict 方法用于在训练完成后对新的数据进行预测。

def predict(self, X):
    y_preds = []
    for i in range(self.n_estimators):
        idx = self.trees[i].feature_indices
        sub_X = X[:, idx]
        y_pred = self.trees[i].predict(sub_X)
        y_preds.append(y_pred)
    
    y_preds = np.array(y_preds).T
    res = []
    for j in y_preds:
        res.append(np.bincount(j.astype('int')).argmax())
    return res
  • 对于每棵树,提取它的 feature_indices 来获取它所使用的特征,然后用它来预测。
  • 将每棵树的预测结果收集到 y_preds 中。
  • 通过 np.bincount 对所有树的预测结果进行投票,选出出现最多的类别作为最终的预测。

关键点解析

  • 自助抽样(Bootstrap Sampling)和特征随机选择(Feature Bagging)是随机森林的两个关键概念。通过自助抽样,每棵树可以看到不同的数据子集;通过特征随机选择,每棵树在每次划分时只考虑一部分特征。
  • 决策树:此代码中假设 ClassificationTree 是一个决策树类,应该具备训练(fit)和预测(predict)方法,能够进行基于数据的分裂。
  • 投票机制:在预测时,随机森林通过每棵树的投票结果来决定最终的分类结果,使用多数表决法。

改进建议

  1. 多线程/并行化:训练多个决策树的过程可以并行化,以提高计算效率。
  2. 优化 np.random.choice:可以改进特征选择和样本选择的实现方式,提高效率。
  3. 提高决策树的深度控制:可以在决策树内部加一些剪枝操作,以避免过拟合。

http://www.kler.cn/a/393388.html

相关文章:

  • ubuntu连接orangepi-zero-2w桌面的几种方法
  • influxDB 时序数据库安装 flux语法 restful接口 nodjsAPI
  • GISBox VS ArcGIS:分别适用于大型和小型项目的两款GIS软件
  • 040 线程池
  • Linux相关习题-gcc-gdb-冯诺依曼
  • 深入理解接口测试:实用指南与最佳实践5.0(二)
  • 数据量大Excel卡顿严重?选对报表工具提高10倍效率
  • 同三维T85HU HDMI+USB摄像机多路多机位手机直播采集卡
  • 浅析pytorch中的常见函数和方法
  • 128.WEB渗透测试-信息收集-ARL(19)
  • DDE(深度桌面环境) Qt 6.8 适配说明
  • 嵌入式开发套件(golang版本)
  • 昇思大模型平台打卡体验活动:项目6基于MindSpore通过GPT实现情感分类
  • 力扣662:二叉树的最大宽度
  • Java面向对象编程进阶之包装类
  • Python---re模块(正则表达式)
  • 快递鸟快递查询API接口参数代码
  • 字符设备 - The most important !
  • JavaScript 中实例化生成对象的相关探讨
  • JVM 中的完整 GC 流程
  • 电信网关配置管理后台 upload_channels.php 任意文件上传漏洞复现
  • IntelliJ IDEA设置打开文件tab窗口多行展示
  • 使用Cesium for Unreal与Cesium ion构建3D地理空间应用教程
  • PHP运算符
  • 使用React和Vite构建一个AirBnb Experiences克隆网站
  • 父子线程间传值问题以及在子线程或者异步情况下使用RequestContextHolder.getRequestAttributes()的注意事项和解决办法