当前位置: 首页 > article >正文

【python因果库实战13】因果生存分析2

加权生存分析

参见《因果推断》一书第17.4节(“边际结构模型的逆概率加权”)。

我们可以通过使用 causallib 的 WeightEstimator(如 IPW)生成一个加权的伪人群来调整混淆因子,使得戒烟者和非戒烟者变得可互换。然后,我们可以在这个加权人群中计算生存曲线以得到无混淆效应。

import warnings

from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LogisticRegression

from causallib.estimation import IPW
from causallib.evaluation import evaluate

warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=ConvergenceWarning)

# Fit an inverse propensity model
ipw = IPW(learner=LogisticRegression(max_iter=800))
ipw.fit(X, a)

# Evaluate
evaluation_results = evaluate(ipw, X, a, y, metrics_to_evaluate={})
f, ax = plt.subplots(figsize=(8, 8))
evaluation_results.plot_covariate_balance(kind="love", ax=ax, phase="train");

在这里插入图片描述

我们现在有了一个 IPW 模块,它实现了良好的特征平衡(加权后所有特征的标准均值差 SMD < 0.1)。让我们将 IPW 与生存分析结合起来,使用 causalib.survival.WeightedSurvival 模块。

from causallib.survival.weighted_survival import WeightedSurvival

# Compute adjusted survival curves
weighted_survival = WeightedSurvival(weight_model=ipw)
weighted_survival.fit(
    X, a
)  # fit weight model (we can actually skip this since it was already fitted above)
population_averaged_survival_curves = weighted_survival.estimate_population_outcome(X, a, t, y)

plot_survival_curves(
    population_averaged_survival_curves,
    labels=["non-quitters", "quitters"],
    title="IPW-adjusted survival of smoke quitters vs. non-quitters in a 10 years observation period",
)

在这里插入图片描述

调整后的10年生存率差异减小(可能减少到了不显著的程度。这一点可以通过自助抽样等方法确定)。
这些曲线是使用内部的、非参数默认 Kaplan-Meier 估计器生成的。它们可以通过使用参数化风险模型来进一步平滑。需要注意的是,加权风险模型只基于时间条件进行建模,例如,它不考虑协变量(不像标准化,见下文)。

# Compute adjusted survival curves with a parametric hazards model
weighted_survival = WeightedSurvival(
    weight_model=ipw, survival_model=LogisticRegression()
)  # note the survival_model param (use a parametric hazards model)
weighted_survival.fit(X, a)
population_averaged_survival_curves = weighted_survival.estimate_population_outcome(X, a, t, y)

plot_survival_curves(
    population_averaged_survival_curves,
    labels=["non-quitters", "quitters"],
    title="IPW-adjusted survival of smoke quitters vs. non-quitters in a 10 years observation period, parametric curves",
)

在这里插入图片描述

这些曲线可能过于平滑,因为它们是用线性风险模型拟合的。我们可以使用一些特征工程来构建更具表现力的替代方案:

# Compute adjusted survival curves with a parametric hazards model and feature engineering
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures

# Init sklearn pipeline with feature transformation and logistic regression
pipeline = Pipeline(
    [("transform", PolynomialFeatures(degree=2)), ("LR", LogisticRegression(max_iter=1000))]
)

weighted_survival = WeightedSurvival(
    weight_model=ipw, survival_model=pipeline
)  # note the survival_model param (use a parametric hazards model)
weighted_survival.fit(X, a)
population_averaged_survival_curves = weighted_survival.estimate_population_outcome(X, a, t, y)

plot_survival_curves(
    population_averaged_survival_curves,
    labels=["non-quitters", "quitters"],
    title="IPW-adjusted survival of smoke quitters vs. non-quitters in a 10 years observation period, parametric curves ver. 2",
)

在这里插入图片描述

或者,我们可以插入 lifelines 包中的任何 UnivariateFitter 类,例如 PiecewiseExponentialFitterWeibullFitter

# Compute adjusted survival curves with a lifelines PieacewiseExponentialFitter
weighted_survival = WeightedSurvival(
    weight_model=ipw,
    survival_model=lifelines.PiecewiseExponentialFitter(breakpoints=range(1, 120, 30)),
)
weighted_survival.fit(X, a)
population_averaged_survival_curves = weighted_survival.estimate_population_outcome(X, a, t, y)

plot_survival_curves(
    population_averaged_survival_curves,
    labels=["non-quitters", "quitters"],
    title="IPW-adjusted survival of smoke quitters vs. non-quitters in a 10 years observation period, lifelines PiecewiseExponentialFitter",
)

在这里插入图片描述


http://www.kler.cn/a/455681.html

相关文章:

  • Debian12 安装配置 ODBC for GaussDB
  • Excel无法插入新单元格怎么办?有解决方法吗?
  • Vue axios 异步请求,请求响应拦截器
  • Docker 安装mysql ,redis,nacos
  • Playwright爬虫xpath获取技巧
  • SpringBoot使用Validation校验参数
  • STM32-笔记11-手写带操作系统的延时函数
  • 笔记:使用python对飞书用户活跃度统计的一个尝试
  • Go Redis实现排行榜
  • 神经网络-ResNet
  • 【社区投稿】自动特征auto trait的扩散规则
  • Effective C++ 条款32:确定你的 public 继承塑模出 is-a 关系
  • Pytorch | 利用IE-FGSM针对CIFAR10上的ResNet分类器进行对抗攻击
  • 如何在 Spring Boot 微服务中设置和管理多个数据库
  • 对外发PDF设置打开次数
  • Python机器学习笔记(十五、聚类算法的对比和评估)
  • 【JavaEE进阶】@RequestMapping注解
  • 《一文读懂BP神经网络:从原理到应用》
  • redis中,msyql数据库读写分离搭建
  • Gitlab17.7+Jenkins2.4.91实现Fastapi/Django项目持续发布版本详细操作(亲测可用)
  • C语言-结构体嵌套
  • 服务器数据恢复—Lustre分布式文件系统下服务器节点进水的数据恢复案例
  • 【网络工程师教程】六、网络互联与互联网
  • github codespaces推送镜像时unauthorized: access token has insufficient scopes 解决方法
  • 云服务器yum无法解析mirrorlist.centos.org
  • 网安瞭望台第17期:Rockstar 2FA 故障催生 FlowerStorm 钓鱼即服务扩张现象剖析