当前位置: 首页 > article >正文

机器学习——Bagging

Bagging

方法:集成n个base learner模型,每个模型都对原始数据集进行有放回的随机采样获得随机数据集,然后并行训练。

回归问题:n个base模型进行预测,将得到的预测值取平均得到最终结果。

分类问题:n个base模型进行预测,投票选择出n个分类结果中出现次数最对的结果作为最终分类结果

代表模型:随机森林是Bagging的一个代表。它基于自助采样法从原始数据集中抽取多个样本子集,

并在每个子集上训练一个决策树,最后通过投票或平均的方式得到最终的预测结果。

随机森林在鸢尾花数据集的分类实现,代码可直接运行,数据集在文章顶部免费下载

# 导入所需的库
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
import seaborn as sns

# 加载鸢尾花数据集
data = pd.read_excel('../data/鸢尾花分类数据集/Iris花分类.xlsx')
X = data.iloc[:, :4].values  # 选取前4列作为特征
y = data.iloc[:, 4:].values.ravel()  # 选取最后1列作为标签

# 特征缩放(标准化)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# 将数据集划分为训练集和测试集
# 通常我们使用80%的数据作为训练集,20%的数据作为测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=66)

# 创建随机森林分类器实例
# n_estimators表示森林中树的数量,可以调整以获得更好的性能
randomForest = RandomForestClassifier(n_estimators=100, random_state=42)

# 使用训练数据来拟合(训练)随机森林模型
randomForest.fit(X_train, y_train)

# 使用训练好的模型对测试集进行预测
y_pred = randomForest.predict(X_test)

# 计算预测结果的准确度
accuracy = accuracy_score(y_test, y_pred)

# 打印出准确度
print("随机森林分类精度为: {:.4f}%".format(accuracy * 100))

# 获取特征重要性
feature_importances = randomForest.feature_importances_
# 获取特征名称
feature_names = data.columns[:4].tolist()
# 打印特征重要性
print("特征重要性:")
for feature, importance in zip(feature_names, feature_importances):
    print(f"{feature}: {importance:.4f}")
# 可视化特征重要性
# 创建一个DataFrame来存储特征重要程度
importances_df = pd.DataFrame({'Feature': feature_names, 'Importance': feature_importances})

# 按重要程度降序排序
importances_df = importances_df.sort_values(by='Importance', ascending=False)

# 绘制条形图
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.figure(figsize=(10, 5))
plt.bar(importances_df['Feature'], importances_df['Importance'])
plt.title('Feature Importances')
plt.ylabel('Importance')
plt.xlabel('Feature')
plt.show()

# 计算混淆矩阵
cm = confusion_matrix(y_test, y_pred)

# 绘制混淆矩阵图
plt.figure(figsize=(7, 5))
sns.heatmap(cm, annot=True, fmt=".0f", linewidths=.5, square=True, cmap='Blues')
plt.ylabel('实际标签', fontproperties='SimHei', size=14)
plt.xlabel('预测标签', fontproperties='SimHei', size=14)
plt.title('随机森林分类器混淆矩阵', fontproperties='SimHei', size=15)
plt.show()

结果为:


http://www.kler.cn/news/316045.html

相关文章:

  • String类和String类常用方法
  • LinuxC高级作业1
  • css边框修饰
  • 代码随想录:打家劫舍||
  • 鸿蒙OpenHarmony【轻量系统内核扩展组件(CPU占用率)】子系统开发
  • 【C++】面向对象编程的三大特性:深入解析继承机制
  • Open3D(C++) 基于点云的曲率提取特征点(自定义阈值法)
  • Unity DOTS系列之IJobChunk来迭代处理数据
  • 速盾:高防cdn防御的时候会封ip吗?
  • GPTo1论文详解
  • ICML 2024 论文分享┆用于高分辨率图像合成的可扩展修正流Transformers
  • 深度学习与应用:行人跟踪
  • 使用Docker快速搭建Airflow+MySQL详细教程
  • 【Linux篇】常用命令及操作技巧(基础篇)
  • IM项目-----消息转发子服务
  • 开源模型应用落地-qwen模型小试-调用Qwen2-VL-7B-Instruct-更清晰地看世界-集成vLLM(二)
  • 运行在docker环境下的图片压缩小工具
  • Qt集成Direct2D绘制,实现离屏渲染
  • OpenHarmony(鸿蒙南向开发)——轻量系统内核(LiteOS-M)【SHELL】
  • ARM中的寄存器
  • Zabbix 6.4添加中文语言
  • IT 人转架构设计必备:项目学习资料+视频分享,涵盖运维管理全内容
  • C++ 构造函数最佳实践
  • Jmeter压力测试-ServerAgent-2.2.3闪退问题解决
  • 【编程基础知识】MySQL中什么叫做聚簇索引、非聚簇索引、回表、覆盖索引
  • Spring Boot文件上传
  • Spring Boot 入门面试五道题
  • 【图灵完备 Turing Complete】游戏经验攻略分享 Part.6 处理器架构2 函数
  • 从局部到全局:深入理解Java Web的作用域机制
  • 【SpinalHDL】Scala/SpinalHDL联合编程之实例化