【python因果库实战15】因果生存分析4
目录
- 加权标准化生存分析
- 总结
- 个体层面的生存曲线
加权标准化生存分析
我们还可以使用 `WeightedStandardizedSurvival` 模块将加权与标准化结合起来。在这里,我们将逆倾向得分加权(IPW)模型(根据基线协变量对人群重新加权)与加权回归以及标准化相结合:
from causallib.survival.weighted_standardized_survival import WeightedStandardizedSurvival
# Weight model: inverse propensity weighting with a logistic-regression propensity learner.
ipw = IPW(learner=LogisticRegression(max_iter=2000))

# Survival model: degree-2 polynomial expansion of the features feeding a logistic regression.
poly_transform_pipeline = Pipeline(
    steps=[
        ("transform", PolynomialFeatures(degree=2)),
        ("LR", LogisticRegression(max_iter=8000, C=1.5)),
    ]
)

# Combine both: the IPW weights are used when fitting the survival model, whose
# predictions are then standardized over the population.
weighted_standardized_survival = WeightedStandardizedSurvival(
    weight_model=ipw,
    survival_model=poly_transform_pipeline,
)
weighted_standardized_survival.fit(X, a, t, y)

population_averaged_survival_curves = (
    weighted_standardized_survival.estimate_population_outcome(X, a, t)
)

plot_survival_curves(
    population_averaged_survival_curves,
    labels=["non-quitters", "quitters"],
    title="Weighted standardized survival of smoke quitters vs. non-quitters in a 10 years observation period",
)
或者,我们也可以使用 lifelines 包中的 `RegressionFitter` 类(例如 Cox 比例风险拟合器 `CoxPHFitter`)作为生存模型。这就是一种加权的 Cox 分析。
# Weight model: IPW with a logistic-regression propensity learner.
ipw = IPW(learner=LogisticRegression(max_iter=1000))

# Survival model: a lifelines Cox proportional-hazards fitter, fitted with the IPW
# weights (a weighted Cox analysis), followed by standardization.
weighted_standardized_survival = WeightedStandardizedSurvival(
    survival_model=lifelines.CoxPHFitter(),
    weight_model=ipw,
)

# fit_kwargs are passed through to CoxPHFitter.fit(); robust=True is what the
# warning quoted below recommends when the weights are not integers.
weighted_standardized_survival.fit(X, a, t, y, fit_kwargs={'robust': True})
# Without setting 'robust=True', we'll get the following warning:
"""StatisticalWarning: It appears your weights are not integers, possibly propensity or sampling scores then?
It's important to know that the naive variance estimates of the coefficients are biased. Instead a) set `robust=True` in the call to `fit`, or b) use Monte Carlo to
estimate the variances."""

population_averaged_survival_curves = (
    weighted_standardized_survival.estimate_population_outcome(X, a, t)
)
plot_survival_curves(
    population_averaged_survival_curves,
    labels=['non-quitters', 'quitters'],
    title='Weighted standardized survival of smoke quitters vs. non-quitters in a 10 years observation period',
)
总结
不同模型的并列比较。
import itertools
def plot_multiple_models(models_dict):
    """Fit every model and plot its population-averaged survival curves in a grid.

    Each model is fitted on the notebook-level ``X, a, t, y`` and drawn in its
    own subplot (filled row by row in the dictionary's insertion order), titled
    with its dictionary key. Two curves are drawn per subplot: columns ``0``
    and ``1`` of the model's population outcome (presumably the untreated and
    treated arms -- confirm against the model's treatment coding).

    Args:
        models_dict: mapping from a display name to an unfitted causallib
            survival model exposing ``fit(X, a, t, y)`` and
            ``estimate_population_outcome(X, a, t, y)``.
    """
    n_models = len(models_dict)
    if n_models == 0:
        # Nothing to draw; plt.subplots(0, 0) would raise.
        return

    # Near-square grid; rows * cols is always >= n_models.
    n_rows = int(np.round(np.sqrt(n_models)))
    n_cols = int(np.ceil(np.sqrt(n_models)))
    # squeeze=False keeps `ax` a 2-D array even for 1x1 or 1xN grids, so the
    # (row, col) tuple indexing below works for any number of models.
    _, ax = plt.subplots(n_rows, n_cols, squeeze=False)
    grid_indices = itertools.product(range(n_rows), range(n_cols))

    for (model_name, model), plot_idx in zip(models_dict.items(), grid_indices):
        model.fit(X, a, t, y)
        curves = model.estimate_population_outcome(X, a, t, y)
        cell = ax[plot_idx]
        cell.plot(curves[0])
        cell.plot(curves[1])
        cell.set_title(model_name)
        cell.set_ylim(0.7, 1.02)
        cell.grid()

    # Hide grid cells left over when n_models does not fill the grid exactly
    # (product() and ravel() both walk the grid row-major).
    for unused_cell in ax.ravel()[n_models:]:
        unused_cell.set_visible(False)

    plt.tight_layout()
    plt.show()
def _ipw_weight_model():
    """Return a new IPW weight model with a logistic-regression propensity learner."""
    return IPW(LogisticRegression(max_iter=2000))


def _logistic_regression():
    """Return a new logistic regression (max_iter=2000) for use as a survival model."""
    return LogisticRegression(max_iter=2000)


# Display name -> unfitted survival model. Every entry gets its own fresh
# estimator instances; insertion order fixes the (row-major) subplot order.
# survival_model=None is used for the entries named "Kaplan-Meier".
MODELS_DICT = {
    "MarginalSurvival Kaplan-Meier": MarginalSurvival(survival_model=None),
    "MarginalSurvival LogisticRegression": MarginalSurvival(
        survival_model=_logistic_regression()
    ),
    "MarginalSurvival PiecewiseExponential": MarginalSurvival(
        survival_model=lifelines.PiecewiseExponentialFitter(breakpoints=range(1, 120, 10))
    ),
    "WeightedSurvival Kaplan-Meier": WeightedSurvival(
        weight_model=_ipw_weight_model(), survival_model=None
    ),
    "WeightedSurvival LogisticRegression": WeightedSurvival(
        weight_model=_ipw_weight_model(),
        survival_model=_logistic_regression(),
    ),
    "WeightedSurvival WeibullFitter": WeightedSurvival(
        weight_model=_ipw_weight_model(),
        survival_model=lifelines.WeibullFitter(),
    ),
    "StandardizedSurvival LogisticRegression": StandardizedSurvival(
        survival_model=_logistic_regression()
    ),
    "StandardizedSurvival Cox": StandardizedSurvival(survival_model=lifelines.CoxPHFitter()),
    "WeightedStandardizedSurvival": WeightedStandardizedSurvival(
        weight_model=_ipw_weight_model(),
        survival_model=_logistic_regression(),
    ),
}
plot_multiple_models(MODELS_DICT)
个体层面的生存曲线
在使用直接结果模型(`StandardizedSurvival` 和 `WeightedStandardizedSurvival`)时,可以在 causallib 中生成个体层面的效应估计和生存曲线。
%matplotlib inline
import matplotlib as mpl
import seaborn.objects as so
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from causallib.survival import StandardizedSurvival
from causallib.datasets import load_nhefs_survival
# Load the NHEFS survival data. Presumably augment=False / onehot=False keep the
# raw covariates (no derived features, no one-hot encoding), which the formula
# below encodes itself via C(...) and bs(...) -- TODO confirm.
data = load_nhefs_survival(augment=False, onehot=False)
# Rename the follow-up-time Series: its name is interpolated into the formula
# below as `bs({data.t.name}, degree=5)`, i.e. the term `bs(longevity, ...)`.
data.t = data.t.rename("longevity")
# Display covariates, treatment, follow-up time and outcome side by side.
data.X.join(data.a).join(data.t).join(data.y)
现在让我们创建一个基于公式的数据转换器,以便轻松指定以下两点:
- 使用样条灵活地建模连续变量,
- 创建与所有变量的治疗交互项,以允许效应修正。
from formulaic import Formula
from sklearn.base import BaseEstimator, TransformerMixin
class FormulaTransformer(BaseEstimator, TransformerMixin):
    """Scikit-learn transformer that builds a design matrix from a formulaic formula.

    ``fit`` learns the formula's data-dependent state (spline knots/boundaries of
    ``bs(...)``, category levels of ``C(...)``) as a formulaic ``ModelSpec``.
    ``transform`` re-applies that learned spec to new data. As a result, design
    matrices built at prediction time (e.g. counterfactual person-time data) use
    exactly the same columns and basis functions the model was trained on.

    Parameters
    ----------
    formula : str
        A one-sided formulaic formula, e.g. ``"~ 1 + x + bs(z)"``.
    """

    def __init__(self, formula):
        super().__init__()
        self.formula = formula

    def fit(self, X, y=None):
        """Learn the formula's model spec from ``X`` (``y`` is ignored)."""
        self.fit_transform(X)
        return self

    def fit_transform(self, X, y=None, **fit_params):
        """Learn the model spec from ``X`` and return its design matrix in one pass."""
        model_matrix = Formula(self.formula).get_model_matrix(X)
        # The spec stores the state of stateful transforms (e.g. spline knots),
        # which must be reused verbatim when transforming any other data.
        self.model_spec_ = model_matrix.model_spec
        return model_matrix

    def transform(self, X, y=None):
        """Return the design matrix of ``X`` built with the spec learned in ``fit``.

        Raises
        ------
        sklearn.exceptions.NotFittedError
            If called before ``fit``/``fit_transform``.
        """
        model_spec = getattr(self, "model_spec_", None)
        if model_spec is None:
            from sklearn.exceptions import NotFittedError

            raise NotFittedError(
                "FormulaTransformer is not fitted yet; call 'fit' before 'transform'."
            )
        # Re-deriving the matrix from the raw formula here would recompute the
        # spline knots / category levels from X, silently changing the basis.
        return model_spec.get_model_matrix(X)
# Design formula for the person-time model: an intercept plus the treatment
# interacted with every covariate (allowing effect modification). Categorical
# variables enter via C(...); continuous ones, including follow-up time, via
# B-splines bs(...).
formula = f"""
~ 1
+ {data.a.name}*(
C(exercise) + C(active) + C(education)
+ sex + race + bs(age, degree=5)
+ bs(smokeintensity) + bs(smokeyrs)
+ bs(wt71)
+ bs({data.t.name}, degree=5)
)
"""
# Unpenalized logistic regression on the formula's design matrix.
# scikit-learn deprecated the string penalty="none" in 1.2 and removed it in 1.4
# (it now raises InvalidParameterError); penalty=None is the supported spelling.
estimator = make_pipeline(
    FormulaTransformer(formula),
    LogisticRegression(penalty=None, max_iter=1000),
)
# A single model over both arms (stratify=False), presumably with the treatment
# column among the features, which the a*(...) interactions above require.
model = StandardizedSurvival(
    estimator,
    stratify=False,
)
model.fit(data.X, data.a, data.t, data.y)
# Individual-level potential-outcome survival curves.
po = model.estimate_individual_outcome(data.X, data.a, data.t)
po
遵循 lifelines 的惯例,结果以不同的时间点作为行、个体作为列。
列还进一步按照治疗分配进行索引,因为这些值是潜在结果。
这种结构使我们能够像在非生存分析中那样获得个体层面的效应(生存概率之差):
# Individual-level effect: treated minus untreated potential survival curves
# (rows are time points, columns are individuals).
effect = po[1].sub(po[0])
我们现在将结果从宽格式重塑(melt)为长格式,以便后续绘图:
# Wide (time x individual) -> long format with one row per (time, id) pair.
effect = (
    effect.rename_axis("time")
    .reset_index()
    .melt(id_vars="time", var_name="id", value_name="effect")
)
effect
f = mpl.figure.Figure()

# Plot individual lines: one thin, translucent grey curve per individual.
spaghetti = so.Plot(effect, x="time", y="effect", group="id")
spaghetti = spaghetti.add(so.Lines(linewidth=.5, alpha=0.1, color="#919090"))
spaghetti = spaghetti.label(title="Spaghetti plot of the effect difference")
p = spaghetti.on(f).plot()

# Plot average effect: the mean of the individual curves at each time point (the ATE curve).
avg_effect = effect.groupby("time", as_index=False)["effect"].mean()
ax = f.axes[0]
ax.plot(avg_effect["time"], avg_effect["effect"], color="#062f80")
ax.text(0, 0, "ATE", verticalalignment="bottom", color="#062f80")
f
一旦我们得到了个体级别的生存曲线,我们可以任意聚合它们来观察效应在不同的协变量分层中是如何变化的。
f = mpl.figure.Figure()

# Attach each individual's covariates to their effect curve so curves can be stratified.
effectX = effect.merge(data.X, left_on="id", right_index=True)
strata = "race"

# Individual effect curves, colored by stratum.
p_eff_strat = (
    so.Plot(effectX, x="time", y="effect", color=strata, group="id")
    .add(so.Lines(linewidth=.5, alpha=0.1))
    .scale(color=so.Nominal(["#1f77b4", "#ff7f0e"]))
    .label(title="Spaghetti plot for stratified effects")
    .on(f)
    .plot()
)
p_eff_strat

# Overlay each stratum's average effect curve, labeled at its last time point.
avg_effect = effectX.groupby(["time", strata], as_index=False)["effect"].mean()
ax = f.axes[0]
for stratum_value, stratum_curve in avg_effect.groupby(strata):
    times = stratum_curve["time"]
    effects = stratum_curve["effect"]
    ax.plot(times, effects, color="black", linestyle="--")
    ax.text(
        times.iloc[-1],
        effects.iloc[-1],
        f"{strata}:{stratum_value}",
        verticalalignment="center",
    )
f