当前位置: 首页 > article >正文

【python因果推断库15】使用 sci-kit learn 模型进行回归断点分析

目录

导入数据

线性模型和主效应模型

线性模型、主效应模型和交互作用模型

使用bandwidth


from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import ExpSineSquared, WhiteKernel
from sklearn.linear_model import LinearRegression

import causalpy as cp
%config InlineBackend.figure_format = 'retina'

导入数据

data = cp.load_data("rd")
data.head()

线性模型和主效应模型

result = cp.skl_experiments.RegressionDiscontinuity(
    data,
    formula="y ~ 1 + x + treated",
    model=LinearRegression(),
    treatment_threshold=0.5,
)
fig, ax = result.plot()

result.summary(round_to=3)
Difference in Differences experiment
Formula: y ~ 1 + x + treated
Running variable: x
Threshold on running variable: 0.5

Results:
Discontinuity at threshold = 0.19
Model coefficients:
  Intercept      	         0
  treated[T.True]	      0.19
  x              	      1.23

线性模型、主效应模型和交互作用模型

result = cp.skl_experiments.RegressionDiscontinuity(
    data,
    formula="y ~ 1 + x + treated + x:treated",
    model=LinearRegression(),
    treatment_threshold=0.5,
)
result.plot();

虽然我们可以看到这样做并不能很好地拟合数据,几乎肯定高估了阈值处的不连续性。 

result.summary(round_to=3)
Difference in Differences experiment
Formula: y ~ 1 + x + treated + x:treated
Running variable: x
Threshold on running variable: 0.5

Results:
Discontinuity at threshold = 0.92
Model coefficients:
  Intercept        	         0
  treated[T.True]  	      2.47
  x                	      1.32
  x:treated[T.True]	     -3.11

使用bandwidth

我们处理这个问题的一种方法是使用 `bandwidth` 参数。这将只对阈值附近的一定带宽内的数据进行拟合。如果 x 是连续变量,那么模型将只对满足 threshold-bandwidth\leq x\leq threshold+bandwidth 的数据进行拟合。

result = cp.skl_experiments.RegressionDiscontinuity(
    data,
    formula="y ~ 1 + x + treated + x:treated",
    model=LinearRegression(),
    treatment_threshold=0.5,
    bandwidth=0.3,
)

result.plot();

 我们甚至可以走得更远,只为接近阈值的数据拟合截距。但很明显,这将涉及更多的估计误差,因为我们使用的数据较少。

result = cp.skl_experiments.RegressionDiscontinuity(
    data,
    formula="y ~ 1 + treated",
    model=LinearRegression(),
    treatment_threshold=0.5,
    bandwidth=0.3,
)

result.plot();


http://www.kler.cn/a/305245.html

相关文章:

  • 腾讯云AI代码助手编程挑战赛——贪吃蛇小游戏
  • python学opencv|读取图像(三十二)使用cv2.getPerspectiveTransform()函数制作透视图-变形的喵喵
  • 【简博士统计学习方法】第2章:3. 感知机——学习算法之原始形式:算法解说
  • 多线程面试相关
  • 大数据智能选课系统
  • React 元素渲染
  • Linux基础-Makefile的编写、以及编写第一个Linux程序:进度条(模拟在 方便下载的同时,更新图形化界面)
  • ubuntu 22.04 ~24.04 如何修改登录背景
  • 【JavaScript】LeetCode:707设计链表
  • Python版《天天酷跑+源码》,详细讲解,手把手教学-python游戏开发
  • jmeter设置全局token
  • (180)时序收敛--->(30)时序收敛三十
  • 大模型教程:使用 Milvus、vLLM 和 Llama 3.1 搭建 RAG 应用
  • 怎么让手机ip地址变化?介绍几种实用方法
  • uniapp 微信小程序自定义tabbar层级低于canvas解决方案
  • 见刊丨“GPU池化”术语发布
  • 本地内存和分布式缓存(面试)
  • Python Web 开发中的性能优化策略(二)
  • git 命令---想要更改远程仓库
  • 指针与函数传递
  • C++速通LeetCode简单第12题-二叉树的直径
  • 深度学习-目标检测(四)-Faster R-CNN
  • C#实现串口中继
  • 不废话简单易懂的Selenium 页面操作与切换
  • Python实现一个简单的爬虫程序(爬取图片)
  • postgresql 导出CSV格式数据