当前位置: 首页 > article >正文

机器学习作业 | 泰坦尼克号生存的预测任务

泰坦尼克号生存的预测任务

学校作业,我来水一水

环境:pycharm+anaconda虚拟环境

文章目录

  • 泰坦尼克号生存的预测任务
    • 0.环境搭建参考:
    • 1 目的与要求
    • 2 任务背景
    • 3 任务简介
    • 4 模型介绍
      • 1.决策树(Decision Tree)
      • 2.朴素贝叶斯(Naive Bayes)
      • 3.支持向量机(Support Vector Machine, SVM)
    • 5 结论
    • 6.训练集测试集代码链接

0.环境搭建参考:

学校机器学习_为了前进而后退,为了走直路而走弯路的博客-CSDN博客

1 目的与要求

(1)目的: 本任务旨在使用机器学习算法预测泰坦尼克号乘客的生存情况。根据乘客的个人信息(如年龄、性别、船舱等级等),使用合适的模型来判断该乘客是否在事故中幸存。模型将对给定的测试集进行预测,并生成预测结果。
(2)采用不同的机器学习算法进行建模(如:决策树、朴素贝叶斯、支持向量机)。
对模型的性能进行评估,计算准确率。
对比不同模型的准确率,通过可视化手段展示结果(如准确率折线图、混淆矩阵等)。
输出每个模型的可视化结果并保存(如决策树的图示、特征重要性等)。

2 任务背景

泰坦尼克号(RMS Titanic)是世界历史上最著名的沉船之一,1912年4月15日沉没。在事故中,约有1500多人失去了生命。根据该事件的相关数据集(包含乘客的个人信息及生死状态),我们可以构建模型预测乘客的生存概率。数据集包含的特征有:乘客的年龄、性别、船舱等级、票价、家庭成员数量等。

3 任务简介

数据集介绍:
训练集(mytrain.csv):包含泰坦尼克号乘客的个人信息及其是否生还的标签(Survived)。这个数据集将用于训练模型。
测试集(mytest.csv):包含泰坦尼克号乘客的个人信息。这个数据集将用于模型预测,预测结果与 mygender.csv 中的标签进行比较。

**目标:**构建机器学习模型,预测乘客在事故中是否生还(分类任务)。
数据预处理:

**处理缺失值:**填充年龄的缺失值为平均值,填充 Embarked 的缺失值为众数,填充 Fare 的缺失值为均值。

特征工程:将性别特征(Sex)从字符型转化为数值型(男为0,女为1)。

任务目标:通过不同的机器学习算法(如决策树、朴素贝叶斯、支持向量机),对乘客生还与否进行预测,并比较其准确率。

可视化图表

准确率对比折线图:
比较三种模型(决策树、朴素贝叶斯、SVM)的准确率。
保存为 model_comparison_accuracy.png。
决策树的图示:
生成决策树的结构图,并保存为 decision_tree.png。
混淆矩阵图:
为每个模型生成并保存混淆矩阵图,分别为
decision_tree_confusion_matrix.png,naive_bayes_confusion_matrix.png, svm_confusion_matrix.png。

4 模型介绍

image-20241230195502739

1.决策树(Decision Tree)

简介:决策树是一种树形结构的模型,适用于分类和回归任务。它通过一系列的决策规则将数据划分为不同的类别,直到满足某种停止条件。决策树模型易于理解和解释。

关键代码:
from sklearn.tree import DecisionTreeClassifier, plot_tree

决策树模型训练

dt_model = DecisionTreeClassifier(random_state=42)
dt_model.fit(X_train, y_train)

预测

dt_y_pred = dt_model.predict(X_test)

绘制决策树图

plt.figure(figsize=(12, 8))
plot_tree(dt_model, feature_names=features, class_names=['Not Survived', 'Survived'], filled=True, rounded=True)
plt.title("Decision Tree Visualization")
plt.savefig('decision_tree.png')

**模型评估:**通过 accuracy_score 计算准确率,比较实际与预测结果。

结果:
准确率:0.8014
Decision Tree_confusion_matrix.png

img

decision_tree.pngimg

feature_importance.png

img

2.朴素贝叶斯(Naive Bayes)

简介:朴素贝叶斯是一种基于贝叶斯定理的分类算法,适用于特征之间独立性假设成立的场景。它通过计算各类别的条件概率,选择概率最大的类别作为预测结果。

关键代码:
from sklearn.naive_bayes import GaussianNB

朴素贝叶斯模型训练

nb_model = GaussianNB()
nb_model.fit(X_train, y_train)

预测

nb_y_pred = nb_model.predict(X_test)
模型评估:同样使用 accuracy_score 来评估朴素贝叶斯模型的准确率。

结果:
准确率:0.9306
Naïve Bayes_confusion_matrix.png

img

3.支持向量机(Support Vector Machine, SVM)

简介:支持向量机是一种基于最大间隔原则的分类模型,主要用于二分类问题。SVM通过构造一个超平面来实现分类,使得不同类别的数据点距离超平面尽可能远。

关键代码:
from sklearn.svm import SVC

支持向量机模型训练

svm_model = SVC(random_state=42)
svm_model.fit(X_train, y_train)

预测

svm_y_pred = svm_model.predict(X_test)

模型评估:使用 accuracy_score 来评估SVM模型的准确率。
混淆矩阵: 对于每个模型,生成混淆矩阵,评估模型的精确度、召回率等。
准确率对比折线图: 通过 seaborn 和 matplotlib 绘制不同模型的准确率对比折线图,方便比较各算法的性能。

代码示例:

绘制模型准确率折线图

accuracies = [dt_accuracy, nb_accuracy, svm_accuracy]
models = ['Decision Tree', 'Naive Bayes', 'SVM']
sns.lineplot(x=models, y=accuracies, marker='o')
plt.title('Model Comparison: Accuracy of Decision Tree, Naive Bayes, and SVM')
plt.ylabel('Accuracy')
plt.savefig('model_comparison_accuracy.png')
plt.show()

结果:
准确率:0.6531
SVM_confusion_matrix.png

img

完整代码:

#导入需要的库

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.preprocessing import LabelEncoder

#加载数据

train_df = pd.read_csv('mytrain.csv')
test_df = pd.read_csv('mytest.csv')
gender_df = pd.read_csv('mygender.csv')

#数据预处理

#填充缺失值,删除无关列

train_df.fillna({'Age': train_df['Age'].mean(), 'Embarked': train_df['Embarked'].mode()[0]}, inplace=True)
test_df.fillna({'Age': test_df['Age'].mean(), 'Fare': test_df['Fare'].mean()}, inplace=True)

#转换性别列为数字(男=0,女=1)

label_encoder = LabelEncoder()
train_df['Sex'] = label_encoder.fit_transform(train_df['Sex'])
test_df['Sex'] = label_encoder.transform(test_df['Sex'])

#特征选择 - 这里只选择了一些简单的特征

features = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare']
X_train = train_df[features]
y_train = train_df['Survived']
X_test = test_df[features]

#创建保存结果的文件夹

def create_folder(folder_name):
    if not os.path.exists(folder_name):
        os.makedirs(folder_name)

#初始化准确率对比列表

accuracies = []

#决策树模型

dt_model = DecisionTreeClassifier(random_state=42)
dt_model.fit(X_train, y_train)
dt_y_pred = dt_model.predict(X_test)
dt_accuracy = accuracy_score(gender_df['Survived'], dt_y_pred)
accuracies.append(dt_accuracy)

#朴素贝叶斯模型

nb_model = GaussianNB()
nb_model.fit(X_train, y_train)
nb_y_pred = nb_model.predict(X_test)
nb_accuracy = accuracy_score(gender_df['Survived'], nb_y_pred)
accuracies.append(nb_accuracy)

#支持向量机模型

svm_model = SVC(random_state=42)
svm_model.fit(X_train, y_train)
svm_y_pred = svm_model.predict(X_test)
svm_accuracy = accuracy_score(gender_df['Survived'], svm_y_pred)
accuracies.append(svm_accuracy)

#可视化决策树

def visualize_decision_tree(model, folder_name):
    plt.figure(figsize=(12, 8))
    plot_tree(model, feature_names=features, class_names=['Not Survived', 'Survived'], filled=True, rounded=True)
    plt.title("Decision Tree Visualization")
    plt.savefig(os.path.join(folder_name, "decision_tree.png"))
    plt.close()

#可视化混淆矩阵

def visualize_confusion_matrix(y_true, y_pred, folder_name, model_name):
    conf_matrix = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(6, 4))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=['Not Survived', 'Survived'], yticklabels=['Not Survived', 'Survived'])
    plt.title(f'{model_name} Confusion Matrix')
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.savefig(os.path.join(folder_name, f"{model_name}_confusion_matrix.png"))
    plt.close()

#可视化特征重要性(仅限决策树)

def visualize_feature_importance(model, folder_name):
    feature_importances = model.feature_importances_
    sns.barplot(x=features, y=feature_importances)
    plt.title('Feature Importances - Decision Tree')
    plt.ylabel('Importance')
    plt.savefig(os.path.join(folder_name, "feature_importance.png"))
    plt.close()

#保存决策树结果

dt_folder = 'decision_tree_results'
create_folder(dt_folder)
visualize_decision_tree(dt_model, dt_folder)
visualize_confusion_matrix(gender_df['Survived'], dt_y_pred, dt_folder, 'Decision Tree')
visualize_feature_importance(dt_model, dt_folder)

#保存朴素贝叶斯结果

nb_folder = 'naive_bayes_results'
create_folder(nb_folder)
visualize_confusion_matrix(gender_df['Survived'], nb_y_pred, nb_folder, 'Naive Bayes')

#保存支持向量机结果

svm_folder = 'svm_results'
create_folder(svm_folder)
visualize_confusion_matrix(gender_df['Survived'], svm_y_pred, svm_folder, 'SVM')

#绘制三种模型准确率的折线图

plt.figure(figsize=(8, 6))
models = ['Decision Tree', 'Naive Bayes', 'SVM']
sns.lineplot(x=models, y=accuracies, marker='o')
plt.title('Model Comparison: Accuracy of Decision Tree, Naive Bayes, and SVM')
plt.ylabel('Accuracy')
plt.savefig('model_comparison_accuracy.png')
plt.show()

#输出各模型准确率

print(f"Decision Tree Accuracy: {dt_accuracy:.4f}")
print(f"Naive Bayes Accuracy: {nb_accuracy:.4f}")
print(f"SVM Accuracy: {svm_accuracy:.4f}")

三种算法对比折线图:
Model_comparison_accuract.png

img

5 结论

总结: 通过本次任务,我使用了三种不同的机器学习算法:决策树、朴素贝叶斯和支持向量机,来预测泰坦尼克号乘客的生还与否。每个算法都进行了训练并生成了相应的预测结果。通过对比准确率,我发现不同模型在本任务中的表现有所不同。
收获:
学会了如何进行数据预处理,如处理缺失值、转换类别特征等。
掌握了多种常用分类算法的使用,包括决策树、朴素贝叶斯和支持向量机,并能够对其进行评估。
学会了如何使用混淆矩阵和准确率等评估指标来评价模型的表现。
通过可视化,能够更清晰地理解模型的预测结果以及不同模型之间的差异。

6.训练集测试集代码链接

我就直接把整个项目的都放在这了

通过网盘分享的文件:PythonProject
链接: https://pan.baidu.com/s/1VXbp32N29owHKvdf7F0eoA 提取码: zp5u


http://www.kler.cn/a/464568.html

相关文章:

  • 密钥管理系统在数据安全解决方案中的重要性
  • AF3 AtomAttentionEncoder类解读
  • asp.net core框架搭建4-部署IIS/Nginx/Docker
  • 【MySQL基础篇】三、表结构的操作
  • pandas-栗子
  • java实验4 反射机制
  • ruoyi开发学习
  • 点击取消按钮,console出来数据更改了,页面视图没有更新
  • 初学STM32 ---高级定时器互补输出带死区控制
  • antd-vue - - - - - a-date-picker限制选择范围
  • 【SOC 芯片设计 DFT 学习专栏 -- DFT 为何需要在综合之后插入】
  • 如何通过API接入电竞数据
  • 检测碳化硅外延晶片表面痕量金属的方法
  • 大模型系列17-RAGFlow搭建本地知识库
  • Linux-Redis哨兵搭建
  • 34.键盘1 C#例子 WPF例子
  • strapi中使用Documentation插件
  • [XCTF/网络安全] Python之Django模块+curl 攻防世界 Cat 解题详析
  • 2011-2020年各省粗离婚率数据
  • 谷粒商城项目125-spring整合high-level-client
  • C++简明教程(14)动态库和静态库的内存共享机制
  • 在基于Centos7的服务器上启用【Gateway】的【Clion Nova】(即 ReSharper C++ 引擎)
  • 文件查找工具locate和find
  • 【ShuQiHere】 集成学习:提升模型性能的有效策略
  • [Qt] 常用控件 | QWidget | “表白程序2.0”
  • 按字段拆分多个工作表到独立的工作簿并增加合计-Excel易用宝