当前位置: 首页 > article >正文

[机器学习]决策树

1 决策树简介

2 信息熵

 3 ID3决策树

3.1 决策树构建流程

3.2 决策树案例

4 C4.5决策树

5 CART决策树(分类&回归)

6 泰坦尼克号生存预测案例

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier,plot_tree
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score,classification_report
# 获取数据
data=pd.read_csv('titanic/train.csv')
# data.info()
# 数据处理
x=data[['Sex','Age','Pclass']]
y=data['Survived']
# x.head()
# 热编码
x=pd.get_dummies(x)
# 缺失值填充
x['Age']=x['Age'].fillna(x['Age'].mean())
# x.head()
x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.2,random_state=22)
# 模型训练
tree=DecisionTreeClassifier(criterion='gini',max_depth=6)
tree.fit(x_train,y_train)
# 模型预测
y_predict=tree.predict(x_test)
# print(y_predict)
# 模型评估
print('accuracy_score',accuracy_score(y_test,y_predict))
print('precision_score',precision_score(y_test,y_predict))
print('recall_score',recall_score(y_test,y_predict))
print('f1_score',f1_score(y_test,y_predict))
print(classification_report(y_test,y_predict))
# 绘制树
plt.figure(figsize=(30,20))
plot_tree(tree,filled=True,feature_names=['Age','Pclass','Sex_female','Sex_male'],class_names=['died','survived'])
plt.show()

7 CART回归树

import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor,plot_tree
import matplotlib.pyplot as plt

# 构建数据
x=np.array(list(range(1,11))).reshape(-1,1)
print(x.shape)
y=np.array([5.56,5.7,5.91,6.4,6.8,7.05,8.9,8.7,9,9.05])
# print(x)

# 模型训练
model1=LinearRegression()
model2=DecisionTreeRegressor(max_depth=1)
model3=DecisionTreeRegressor(max_depth=3)

model1.fit(x,y)
model2.fit(x,y)
model3.fit(x,y)
# 模型预测
x_test=np.arange(0.0,10.0,0.01).reshape(-1,1)
print(x_test.shape)
y1=model1.predict(x_test)
y2=model2.predict(x_test)
y3=model3.predict(x_test)

plt.scatter(x,y)
plt.plot(x_test,y1)
plt.plot(x_test,y2)
plt.plot(x_test,y3)
plt.grid()
plt.show()

plt.figure(figsize=(30,20))
plot_tree(model3,filled=True)
plt.show()

8 决策树剪枝


http://www.kler.cn/news/305527.html

相关文章:

  • Parallels Desktop 20 for Mac中文版发布了?会哪些新功能
  • 学习笔记-Golang中的Context
  • 基础算法(3)——二分
  • Java邮件:如何配置以实现自动化邮件通知?
  • 平安养老险阜阳中心支公司开展金融教育宣传专项活动
  • ElementUI 快速入门:使用 Vue 脚手架搭建项目
  • SQL 代表什么?SQL 的全称是什么?
  • 二叉树算法 JAVA
  • 微信小程序中的模块化、组件化开发:完整指南
  • 资源管理新视角:利用 FastAPI Lifespan 事件优化你的应用
  • Android Greendao的数据库复制到设备指定位置
  • PhpStudy下载安装使用学习
  • 外国车牌字符识别与分类系统源码分享
  • PPT幻灯片的添加与编辑:全面技术指南
  • 【30天玩转python】高级数据结构
  • 2024年增强现实(AR)的现状
  • 用牛只面部图像实现牛只身份识别(与人脸识别不同的牛脸识别)
  • 发展绿色新质生产力,创维汽车亮相2024国际数字能源展
  • SSHamble:一款针对SSH技术安全的研究与分析工具
  • 华宇TAS应用中间件斩获2024鲲鹏应用创新大赛北京赛区总决赛二等奖!
  • SAP B1 Web Client MS Teams App集成连载二:安装Install/升级Upgrade/卸载Uninstall
  • 【LeetCode 算法笔记】155. 最小栈
  • Element-Ui el-table 序号使用问题
  • ESP32-S3百度文心一言大模型AI语音聊天助手(支持自定义唤醒词训练)【手把手非常详细】【万字教程】
  • [区间dp]添加括号
  • LEAN 类型系统属性 之 算法式相等的非传递性(Algorithm equality is not transitive)注解
  • Vue3+TypeScript二次封装axios
  • C++多态讲解
  • 进阶岛 任务3: LMDeploy 量化部署进阶实践
  • 在云服务器上安装配置 MySQL 并开启远程访问(详细教程)