当前位置: 首页 > article >正文

KAGGLE竞赛实战2-捷信金融违约预测竞赛-part2-用lightgbm建立baseline

接着上一篇,用lightgbm建立baseline,发现模型效果得到了很大优化(模型分提升为0.73)

# In[211]:


from sklearn.model_selection import cross_val_score,KFold


# In[228]:


import lightgbm as lgb


# In[229]:


from lightgbm import LGBMClassifier 


# In[232]:


from lightgbm import early_stopping


# In[237]:


from sklearn.metrics import roc_auc_score


# In[215]:


bad_chars=[':','""','\\','']
for feature in application_train.columns:
    if any(bad_char in feature for bad_char in bad_chars):
        print(f"Feature '{feature}'包含非法字符。")


# In[216]:


#去掉特殊字符import pandas as pd
import re
def clean_column_names(df):
    """
    清理DataFrame列名,去除特殊字符,使其符合JSON格式要求。
    
    参数:
    df (pd.DataFrame): 输入的DataFrame。
    
    返回:
    pd.DataFrame: 列名已清理的DataFrame。
    """
    # 定义一个函数,用于替换单个列名中的特殊字符
    def replace_chars(col_name):
        # 替换掉所有非字母数字和下划线的字符
        return re.sub(r'\W+', '_', col_name)
    
    # 应用替换函数到所有列名
    df.columns = [replace_chars(col) for col in df.columns]
    return df


    


# In[217]:


application_train=clean_column_names(application_train)
application_test=clean_column_names(application_test)


# In[218]:


bad_chars=[':','""','\\','']
for feature in application_train.columns:
    if any(bad_char in feature for bad_char in bad_chars):
        print(f"Feature '{feature}'包含非法字符。")


# In[238]:


def fit(train=application_train, valid=application_test):
    """
    模型训练函数,
    参数:train训练集
    valid测试集
    返回值:
    valid_auc:验证集上AUC指标
    feature_importances:特征重要性
    test_results:测试集结果
    """
    test = valid.copy()
    x_train = train.drop(['SK_ID_CURR', 'TARGET'], axis=1)
    y_train = train['TARGET']
    # 五折交叉验证
    folds = KFold(n_splits=5, shuffle=True, random_state=1412)
    # 定义变量保存预测结果
    oof_preds = np.zeros(y_train.shape[0])
    test_preds = np.zeros(test.shape[0])
    # 提取特征名
    feature_names = list(x_train.columns)
    # 空数组用于存储特征重要性值
    feature_importance_values = np.zeros(len(feature_names))
    # 实例化模型
    lgb = LGBMClassifier(n_estimators=10000, early_stopping_round=200, random_state=24)
    for fold_idx, (train_idx, valid_idx) in enumerate(folds.split(x_train)):
        X = x_train.iloc[train_idx]
        y = y_train.iloc[train_idx]
        valid_X = x_train.iloc[valid_idx]
        valid_y = y_train.iloc[valid_idx]
        # 定义早停回调函数
        callbacks = [early_stopping(stopping_rounds=200)]
        # 拟合模型
        lgb.fit(X, y, eval_set=[(X, y), (valid_X, valid_y)], callbacks=callbacks)
        # 记录特征重要性
        feature_importance_values += lgb.feature_importances_ / folds.n_splits
        # 在验证集上进行预测
        proba = lgb.predict_proba(valid_X, num_iteration=lgb.best_iteration_)
        oof_preds[valid_idx] = proba[:, 1]  # 选择正类概率
        test_preds += lgb.predict_proba(test[feature_names], num_iteration=lgb.best_iteration_)[:, 1]
    valid_auc = roc_auc_score(y_train, oof_preds)
    feature_importances = pd.DataFrame({'feature': feature_names, 'importance': feature_importance_values})
    test['TARGET'] = test_preds
    return valid_auc, feature_importances, test[['SK_ID_CURR', 'TARGET']]


# In[246]:


valid_auc,feature_importance,submission=fit(application_train[:50000],application_test)
#发现报错了Do not support special JSON characters in feature name,原因是有些列名里有特殊的字符,这是get_dummies时产生的


# In[247]:


valid_auc #0.7421786519800682


# In[248]:


#看下特征重要性
feature_importance.sort_values(by='importance',ascending=False)


# In[250]:


submission.to_csv('baseline_model_lightgbm.csv',index=False)#提交后成绩0.73


# In[251]:


application_train.to_csv('original_application_train.csv')#保存下结果
application_test.to_csv('original_application_test.csv')


http://www.kler.cn/a/509896.html

相关文章:

  • ScratchLLMStepByStep:训练自己的Tokenizer
  • 适配器模式详解:解决接口不兼容问题的灵活设计模式
  • 【Linux系统编程】—— 深入理解Linux中的环境变量与程序地址空间
  • 【JVM中的三色标记法是什么?】
  • 记录 idea 启动 tomcat 控制台输出乱码问题解决
  • mysql 与Redis 数据强一致方案
  • pnpm介绍
  • Java进程内缓存介绍
  • 部署启动nacos报错No DataSource set 及master-db not found
  • 智能学习平台系统设计与实现(代码+数据库+LW)
  • 如何用AI优化自动化回归测试
  • 基于 Android 的个人健康管理 APP 设计与实现
  • Linux探秘坊-------3.开发工具详解(1)
  • 物联网网关Web服务器--Boa服务器移植与测试
  • 某国际大型超市电商销售数据分析和可视化
  • Vue进阶之旅:组件通信与高级用法深度剖析(组件通信进阶用法)
  • 大华C++开发面试题及参考答案
  • opencv对直方图的计算和绘制
  • 网络安全行业岗位职责
  • SSM旅游信息管理系统
  • ros 机器人地图转化为gis地图
  • DNS未响应服务问题的解决(电脑连着网但浏览器访问不了网页)
  • C#高级:通过 Assembly 类加载 DLL 和直接引用DLL的方法大全
  • Chromium 132 编译指南 Linux 篇 - 同步第三方库以及 Hooks(六)
  • Python:两数之和
  • 当使用 npm 时,出现 `certificate has expired` 错误通常意味着请求的证书已过期。