当前位置: 首页 > article >正文

机器学习实战:泰坦尼克号乘客生存率预测(数据处理+特征工程+建模预测)

项目描述

任务:根据训练集数据中的数据预测泰坦尼克号上哪些乘客能生存下来

数据源:csv文件(train.csv)

目标变量:Survived(0-1变量)

数据集预览:

1、英文描述:

2、译文描述:

初步分析

注:代码后紧跟运行结果截图

1、查看数据

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# 导入数据
train = pd.read_csv('train.csv')
print(train.head(10))

2、查看数据量与缺失值 

# 查看数据基本信息
print(train.info())

# 结果:共891条数据,共12列,部分列含有缺失值

# 方便查看哪些特征含缺失值
print(train.isnull().sum())

3、分析不同性别的生存率

# 计算不同性别的生存率
survived_sex = train.groupby('Sex')['Survived'].mean()
print(survived_sex)
# 可视化
sns.barplot(x='Sex', y='Survived', data=train, errorbar=None, hue='Sex', palette='viridis')

# 结论:女性生存率明显高于男性

3、分析不同年龄段的生存率

# 计算不同年龄段的生存率
train['Age_group'] = pd.cut(train["Age"], bins=[0, 5, 14, 18, 30, 60, 110], 
                            labels=['baby', 'child', 'teenager', 'adult', 'old_adult', 'old'], right=False)
survived_age = train.groupby('Age_group', observed=False)['Survived'].mean()
print(survived_age)
# 可视化
order = ['baby', 'child', 'teenager', 'adult', 'old_adult', 'old']
sns.barplot(x='Age_group', y='Survived', data=train, order=order, errorbar=None, hue='Age_group', palette='viridis')
plt.show()

# 结论:婴儿生存率明显更高,老人生还率明显更低,随年龄的增长,生存率有下降的趋势

 4、分析不同 Plass 的乘客的生存率

# 计算不同 Plass 的乘客的生存率
survival_pclass = train.groupby('Pclass')['Survived'].mean()
print(survival_pclass)
# 可视化
sns.barplot(x='Pclass', y='Survived', data=train, errorbar=None)
plt.show()

# 结论:1、2、3号的生存率由高到低排列

5、 分析不同票价的乘客的生存率

# 计算不同票价的乘客的生存率
train['Fare_group'] = pd.cut(train['Fare'], bins=[0, 5, 15, 25, 50, 100, 200, float('inf')], 
                             labels=['very low', 'low', 'normal', 'medium', 'high', 'very high', 'Luxury'], right=False)
print(train.groupby('Fare_group', observed=False)['Survived'].mean())
# 可视化
sns.barplot(x='Fare_group', y='Survived', data=train, errorbar=None, hue='Fare_group', palette='viridis')
plt.show()

# 结论:票价越高,生存率越高,高价位生存率在 60% 以上

数据处理与特征工程

1、查看文本特征的取值

# 查看训练集中,名字特征有哪些头衔(代表乘客所属阶级)
print(train['Name'].str.extract(' ([A-Za-z]+\.)')[0].unique())
# 查看训练集 Cabin 有哪些取值
print('Cabin 仓号首字母有:', train['Cabin'].unique())  

 2、数据处理与特征工程

# 数据处理与特征工程

from sklearn.impute import KNNImputer
from sklearn.preprocessing import StandardScaler

def pre(df):
    # 家庭人数
    df['family'] = df["Parch"] + df["SibSp"]
    # 提取名字头衔,表示阶级信息
    df['Title'] = df['Name'].str.extract(' ([A-Za-z]+\.)')
    df['Title'] = df['Title'].str.replace('.', '')
    df['Title'] = df['Title'].replace(['Don', 'Lady', 'Capt', 'Col', 'Dr', 'Major', 'Rev', 'Sir', 'Jonkheer', 'Dona'], 'Rare')
    df['Title'] = df['Title'].replace(['Mlle', 'Ms'], 'Miss')
    df['Title'] = df['Title'].replace('Mme', 'Mrs')
    # 缺失值处理
    df['Embarked'].fillna(df['Embarked'].mode()[0])   # 众数填充
    df['Fare'].fillna(df['Fare'].median())   # 中位数填充
    # 手动对 Fare 进行独热编码
    df['very low'] = (df['Fare']<5)
    df['low'] = (df['Fare']<15) & (df['Fare']>=5)
    df['normal'] = (df['Fare']<25) & (df['Fare']>=15)
    df['medium'] = (df['Fare']<50) & (df['Fare']>=25)
    df['high'] = (df['Fare']<100) & (df['Fare']>=50)
    df['very high'] = (df['Fare']<300) & (df['Fare']>=100)
    df['Luxury'] = (df['Fare']>=300)
    # 填充 Cabin 缺失值,并用仓号第一个字母替代原来的仓号
    df['Cabin'] = df['Cabin'].fillna('N').map(lambda x: x[0])
    # KNN 算法填充 Age 的缺失值
    imputer = KNNImputer(n_neighbors=5)
    df['Age'] = imputer.fit_transform(df[['Age']])
    # 对 Age 分组(加上填补的值进分组里),并手动进行标签编码
    df['Age_group'] = pd.cut(df["Age"], bins=[0, 5, 14, 18, 30, 60, 110], 
                            labels=[1, 2, 3, 4, 5, 6], right=False)
    # 独热编码
    df = pd.get_dummies(df, columns=['Embarked', 'Title', 'Cabin'], drop_first=True)
    df['Sex'] = df['Sex'].map({'male': 0, 'female': 1})
    # 处理特征出现在测试集而不在训练集的可能情况
    for p in ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'T']:
        col = f'Cabin_{p}'
        if col not in df.columns:   # 若存在训练集没出现过的字母,则新建一列
            df[col] = pd.Series([False]*df.shape[0])
    # 选择对目标变量可能有影响的特征(挑选特征)
    cols = ['Pclass', 
            'Title_Miss', 'Title_Mr', 'Title_Mrs','Title_Rare',
            'Sex',
            'Age', 'Age_group',
            'SibSp', 'Parch',
            'Fare', 'very low', 'low', 'normal', 'medium', 'high', 'very high', 'Luxury',
            'Cabin_A', 'Cabin_B', 'Cabin_C', 'Cabin_D', 'Cabin_E', 'Cabin_F','Cabin_G', 'Cabin_T'  # 忽略缺失值 N 
           ]
    df2 = df[cols].copy()
    # 标准化数值特征
    num_features = ['Age', 'Fare', 'Parch', 'SibSp', 'Age_group']
    scaler = StandardScaler()
    df2[num_features] = scaler.fit_transform(df2[num_features])

    return df2

建模过程

# 建模、模型评估

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import accuracy_score

train = pd.read_csv('train.csv')
# 特征工程,调用前面定义的函数
train_ = pre(train)   
X = train_  #特征变量
y = train['Survived']   #目标变量
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

models = {
    '逻辑回归': LogisticRegression(),
    '支持向量机SVM': SVC(),
    'KNN': KNeighborsClassifier(),
    '随机森林': RandomForestClassifier(),
    '梯度提升树': GradientBoostingClassifier()
}
accuary_cores = {
    '逻辑回归': 0,
    '支持向量机SVM': 0,
    'KNN': 0,
    '随机森林': 0,
    '梯度提升树': 0
}

# 计算模型在各自默认参数下反复训练 n 次的平均准确率
n = 20
for i in range(n):
    for model_name, model in models.items():
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
        accuracy = accuracy_score(y_test, y_pred)
        accuary_cores[model_name] += accuracy
print('默认参数下各模型的平均准确率:')
for model_name, accuracy in accuary_cores.items():
    accuracy /= n
    print(f'{model_name}:  {accuracy:.4f}')

注:关于上面多个模型,各有优缺点,模型选择取决于具体任务需求(比如模型解释性需求、模型准确度需求、模型精确度需求等)来选择合适模型,以上模型模型参数还待优化。项目还有很多地方可以完善,我们下期再见叭!

# 若对大噶有帮助的话,希望点个赞支持一下叭!

# 文章若有错误,欢迎大家不吝赐教!


http://www.kler.cn/a/422107.html

相关文章:

  • flink-connector-mysql-cdc:03 mysql-cdc常见问题汇总
  • vue2+svg+elementui实现花瓣图自定义el-select回显色卡图片
  • Spire.PDF for .NET【页面设置】演示:旋放大 PDF 边距而不改变页面大小
  • 一款支持80+语言,包括:拉丁文、中文、阿拉伯文、梵文等开源OCR库
  • Flutter-Web打包后上线白屏
  • blender 视频背景
  • hhdb数据库介绍(10-22)
  • 【Python】一、最新版Python基础知识总结、综合案例实战
  • 【软考网工笔记】网络基础理论——传输层
  • Subprocess 和 Multiprocessing 的区别与使用要点及进程关闭方法
  • ElasticSearch7.x入门教程之全文搜索聚合分析(十)
  • MongoDB复制(副本)集实战及原理分析
  • 1.Git安装与常用命令
  • tcpdump抓包wireshark分析
  • RNN模型介绍
  • 揭秘MySQL:探索那些鲜为人知的数据类型宝藏
  • 基于 MVC 架构的 SpringBoot 高校行政事务管理系统:设计优化与实现验证
  • postgresql与pgvector安装与使用
  • PHP语法学习(第一天)
  • 【Pytorch】torch.reshape与torch.Tensor.reshape区别
  • 洛谷P1241 括号序列(c嘎嘎)
  • 无人机倾斜摄影测绘三维建模技术详解
  • Linux-虚拟环境
  • Qt Qtablewidget 标题 QHeaderView 增加可选框 QcheckBox
  • 【Vue3】【Naive UI】<NAutoComplete>标签
  • 【Flink】Flink Checkpoint 流程解析