当前位置: 首页 > article >正文

因果推断14--DRNet论文和代码学习

目录

论文介绍

代码实现

DRNet

ReadMe

因果森林


论文介绍

因果推断3--DRNet(个人笔记)_万三豹的博客-CSDN博客

摘要:估计个体在不同程度的治疗暴露下的潜在反应,对于医疗保健、经济学和公共政策等几个重要领域具有很高的实际意义。然而,现有的从观察数据中估计反事实结果的学习方法要么专注于估计平均剂量-反应曲线,要么局限于只有两种没有相关剂量参数的治疗方法。在这里,我们提出了一种新的机器学习方法,用于学习反事实表示,用于使用神经网络估计具有连续剂量参数的任意数量治疗的单个剂量-反应曲线。在已建立的潜在结果框架的基础上,我们引入了性能指标、模型选择标准、模型架构和用于估计单个剂量反应曲线的开放基准。我们的实验表明,在这项工作中开发的方法在估计个体剂量反应方面设置了一个新的最先进的方法。

代码实现

GitHub - d909b/drnet: 💉📈 Dose response networks (DRNets) are a method for learning to estimate individual dose-response curves for multiple parametric treatments from observational data using neural networks.

DRNet

def get_method_name_map():
        return {
            'knn': KNearestNeighbours,
            'ols1': OrdinaryLeastSquares1,
            'ols2': OrdinaryLeastSquares2,
            'cf': CausalForest,
            'rf': RandomForest,
            'bart': BayesianAdditiveRegressionTrees,
            'nn': TFNeuralNetwork,
            'nn+': NeuralNetwork,
            'xgb': GradientBoostedTrees,
            'gp': GaussianProcess,
            'psm': PSM,
            'psmpbm': PSM_PBM,
            'ganite': GANITE,
            'gps': GPS,
        }

    def _build_graph(self, input_dim, num_units,
                     num_representation_layers, num_regression_layers, weight_initialisation_std,
                     reweight_sample=False, loss_function="l2",
                     imbalance_penalty_function="wass", rbf_sigma=0.1,
                     wass_lambda=10.0, wass_iterations=10, wass_bpt=True):
        """
        Constructs a TensorFlow subgraph for counterfactual regression.
        Sets the following member variables (to TF nodes):
        self.output         The output prediction "y"
        self.tot_loss       The total objective to minimize
        self.imb_loss       The imbalance term of the objective
        self.pred_loss      The prediction term of the objective
        self.weights_in     The input/representation layer weights
        self.weights_out    The output/post-representation layer weights
        self.weights_pred   The (linear) prediction layer weights
        self.h_rep          The layer of the penalized representation
        """
        ''' Initialize input placeholders '''
        self.x = tf.placeholder("float", shape=[None, input_dim], name='x')
        self.t = tf.placeholder("float", shape=[None, 1], name='t')
        self.y_ = tf.placeholder("float", shape=[None, 1], name='y_')

        ''' Parameter placeholders '''
        self.imbalance_loss_weight = tf.placeholder("float", name='r_alpha')
        self.l2_weight = tf.placeholder("float", name='r_lambda')
        self.dropout_representation = tf.placeholder("float", name='dropout_in')
        self.dropout_regression = tf.placeholder("float", name='dropout_out')
        self.p_t = tf.placeholder("float", name='p_treated')

        dim_input = input_dim
        dim_in = num_units
        dim_out = num_units

        weights_in, biases_in = [], []

        if num_representation_layers == 0:
            dim_in = dim_input
        if num_regression_layers == 0:
            dim_out = dim_in

        ''' Construct input/representation layers '''
        h_rep, weights_in, biases_in = build_mlp(self.x, num_representation_layers, dim_in,
                          self.dropout_representation, self.nonlinearity,
                          weight_initialisation_std=weight_initialisation_std)

        # Normalize representation.
        h_rep_norm = h_rep / safe_sqrt(tf.reduce_sum(tf.square(h_rep), axis=1, keep_dims=True))

        ''' Construct ouput layers '''
        y, y_concat, weights_out, weights_pred = self._build_output_graph(h_rep_norm, self.t, dim_in, dim_out,
                                                                          self.dropout_regression,
                                                                          num_regression_layers,
                                                                          weight_initialisation_std)

        ''' Compute sample reweighting '''
        if reweight_sample:
            w_t = self.t/(2*self.p_t)
            w_c = (1-self.t)/(2*(1-self.p_t))
            sample_weight = w_t + w_c
        else:
            sample_weight = 1.0

        self.sample_weight = sample_weight

        ''' Construct factual loss function '''
        if self.with_pehe_loss:
            risk = pred_error = tf.reduce_mean(sample_weight*tf.square(self.y_ - y)) + \
                                pehe_loss(self.y_, y_concat, self.t, self.x, self.num_treatments) / 10.
        elif loss_function == 'log':
            y = 0.995/(1.0+tf.exp(-y)) + 0.0025
            res = self.y_*tf.log(y) + (1.0-self.y_)*tf.log(1.0-y)

            risk = -tf.reduce_mean(sample_weight*res)
            pred_error = -tf.reduce_mean(res)
        else:
            risk = tf.reduce_mean(sample_weight*tf.square(self.y_ - y))
            pred_error = tf.sqrt(tf.reduce_mean(tf.square(self.y_ - y)))

        ''' Regularization '''
        for i in range(0, num_representation_layers):
            self.weight_decay_loss += tf.nn.l2_loss(weights_in[i])

        p_ipm = 0.5

        if self.imbalance_loss_weight_param == 0.0:
            imb_dist = tf.reduce_mean(self.t)
            imb_error = 0
        elif imbalance_penalty_function == 'mmd2_rbf':
            imb_dist = mmd2_rbf(h_rep_norm, self.t, p_ipm, rbf_sigma)
            imb_error = self.imbalance_loss_weight * imb_dist
        elif imbalance_penalty_function == 'mmd2_lin':
            imb_dist = mmd2_lin(h_rep_norm, self.t, p_ipm)
            imb_error = self.imbalance_loss_weight * mmd2_lin(h_rep_norm, self.t, p_ipm)
        elif imbalance_penalty_function == 'mmd_rbf':
            imb_dist = tf.abs(mmd2_rbf(h_rep_norm, self.t, p_ipm, rbf_sigma))
            imb_error = safe_sqrt(tf.square(self.imbalance_loss_weight) * imb_dist)
        elif imbalance_penalty_function == 'mmd_lin':
            imb_dist = mmd2_lin(h_rep_norm, self.t, p_ipm)
            imb_error = safe_sqrt(tf.square(self.imbalance_loss_weight) * imb_dist)
        elif imbalance_penalty_function == 'wass':
            imb_dist, imb_mat = wasserstein(h_rep_norm, self.t, p_ipm, sq=True,
                                            its=wass_iterations, lam=wass_lambda, backpropT=wass_bpt)
            imb_error = self.imbalance_loss_weight * imb_dist
            self.imb_mat = imb_mat  # FOR DEBUG
        elif imbalance_penalty_function == 'wass2':
            imb_dist, imb_mat = wasserstein(h_rep_norm, self.t, p_ipm, sq=True,
                                            its=wass_iterations, lam=wass_lambda, backpropT=wass_bpt)
            imb_error = self.imbalance_loss_weight * imb_dist
            self.imb_mat = imb_mat  # FOR DEBUG
        else:
            imb_dist = lindisc(h_rep_norm, p_ipm, self.t)
            imb_error = self.imbalance_loss_weight * imb_dist

        ''' Total error '''
        tot_error = risk
        if self.imbalance_loss_weight_param != 0.0:
            tot_error = tot_error + imb_error
        tot_error = tot_error + self.l2_weight*self.weight_decay_loss

        self.output = y
        self.tot_loss = tot_error
        self.imb_loss = imb_error
        self.imb_dist = imb_dist
        self.pred_loss = pred_error
        self.weights_in = weights_in
        self.weights_out = weights_out
        self.weights_pred = weights_pred
        self.h_rep = h_rep
        self.h_rep_norm = h_rep_norm

ReadMe

剂量反应网络(DRNets)是一种用于学习使用神经网络从观测数据估计多参数治疗的个体剂量反应曲线的方法。该存储库包含用于评估DRNets的源代码以及用于评估个体治疗效果的最相关的现有最先进方法(有关结果,请参阅我们的手稿)。为了便于将来的研究,源代码被设计为易于使用(1)新方法和(2)新的基准数据集进行扩展。

作者:Patrick Schwab,苏黎世ETHpatrick.schwab@hest.ethz.ch苏黎世联邦理工学院Lorenz Linhardtllorenz@student.ethz.ch,Stefan Bauer,MPI for Intelligent Systemsstefan.bauer@tuebingen.mpg.de,苏黎世联邦理工学院Joachim M.Buhmannjbuhmann@inf.ethz.ch苏黎世联邦理工学院Walter Karlenwalter.karlen@hest.ethz.ch

许可证:MIT,请参阅License.txt

引用

如果您在工作中引用或使用我们的方法、代码或结果,请考虑引用:

@在过程中{schwab2020-剂量反应,

title={{学习用于估计个体剂量响应曲线的反事实表示}},

作者={施瓦布、帕特里克和林哈特、洛伦兹和鲍尔、斯特凡和布曼、约阿希姆·M和卡伦、沃尔特},

booktitle={{AAAI人工智能会议}},

年={2020}

}

用法:

可运行的脚本位于drnet/apps/子目录中。

drnet/apps/main.py是运行实验的主要可运行脚本。

drnet/apps/parameters.py中描述了可运行脚本的可用命令行参数

您可以通过将drnet/models/baseline/baseline.py子类化,将新的基线方法添加到评估中

有关如何实现自己的基线方法的示例,请参见drnet/models/baselines/neural_network.py。

通过向drnet/apps/main.py中的get_method_name_map方法添加新条目,可以从命令行注册新方法以供使用

您可以通过实现基准接口来添加新的基准,有关如何将自己的基准添加到基准套件的示例,请参见drnet/models/benchmarks。

通过向drnet/apps/evaluate.py中的get_benchmark_name_map方法添加新条目,可以从命令行注册新的基准测试以供使用

要求和相关性

该项目设计用于Python 2.7。我们不能保证,也没有测试过与Python 3的兼容性。

要运行TCGA和News基准,需要下载包含这些基准的原始数据样本的SQLite数据库(News.db和TCGA.db)。

您可以使用以下链接下载原始数据:tcga.db和news.db。

请注意,您需要大约10GB的可用磁盘空间来存储数据库。

将数据库文件保存到/数据目录,以便与下面的分步指南兼容或相应地调整命令。

要运行MVICU基准测试,您需要访问MIMIC-III数据库,由于数据集的敏感性,这需要经过审批过程。

注意,您需要大约75GB的可用磁盘空间来存储带有索引的MIMIC-III数据库。

访问数据集并将MIMIC-III数据加载到SQLite数据库(保存为例如/your/path/to/mimic3.db)后,可以使用drnet/apps/load_db_icu.py脚本将MVICU基准数据从MIMIC-IIII数据库提取到中的单独数据库中/数据文件夹,通过运行:

python drnet/apps/load_db_icu.py/your/path/to/mimic3.db./data

一旦建立,基准数据库将使用大约43MB的磁盘空间。

要运行BART、因果森林和GPS,并再现需要安装R的数字。看见https://www.r-project.org/安装说明。

要运行BART,需要安装R包rJava和bartMachine。看见https://github.com/kapelner/bartMachine安装说明。注意,rJava也需要一个工作的Java安装。

要运行因果森林,需要安装R包grf。看见https://github.com/grf-labs/grf安装说明。

要运行GPS,您需要安装R包causaldrf,例如在R-shell中运行install.packages(“causaldrv”)。

要复制论文的数字,您需要安装R-packagelatex2exp。看见https://cran.r-project.org/web/packages/latex2exp/vignettes/using-latex2exp.html安装说明。

有关python依赖关系,请参阅setup.py。您可以使用pipinstall。安装drnet包及其python依赖项。请注意,如果您的系统上没有正常的R安装,rpy2的安装将失败(请参见上文)。

再现实验

确保您具备上面列出的必要要求,包括/与此文件相关的数据目录以及所需的数据库(参见上文)。

您可以使用脚本drnet/apps/run_all_experiments.py获取main.py使用的精确参数,以重现论文中的实验结果。

drnet/apps/run_all_experiments.py脚本打印。

因果森林

https://grf-labs.github.io/grf/articles/grf.h

return self.grf.causal_forest(x,
                                      FloatVector([float(yy) for yy in y]),
                                      FloatVector([float(tt) for tt in t]), seed=909)

W
The treatment assignment (must be a binary or real numeric vector with no NAs).

W

治疗分配(必须是没有NA的二进制或实数矢量)。

class CausalForest(PickleableMixin, Baseline):
    def __init__(self):
        super(CausalForest, self).__init__()
        self.bart = None

    def install_grf(self):
        from rpy2.robjects.packages import importr
        import rpy2.robjects.packages as rpackages
        from rpy2.robjects.vectors import StrVector
        import rpy2.robjects as robjects

        # robjects.r.options(download_file_method='curl')

        # package_names = ["grf"]
        # utils = rpackages.importr('utils')
        # utils.chooseCRANmirror(ind=0)
        # utils.chooseCRANmirror(ind=0)
        #
        # names_to_install = [x for x in package_names if not rpackages.isinstalled(x)]
        # if len(names_to_install) > 0:
        #     utils.install_packages(StrVector(names_to_install))

        return importr("grf")

    def _build(self, **kwargs):
        from rpy2.robjects import numpy2ri
        from sklearn import linear_model
        grf = self.install_grf()

        self.grf = grf
        numpy2ri.activate()
        num_treatments = kwargs["num_treatments"]
        self.with_exposure = kwargs["with_exposure"]

        return [linear_model.Ridge(alpha=.5)] +\
               [None for _ in range(num_treatments)]

    def predict_for_model(self, model, x):
        base_y = Baseline.predict_for_model(self, self.model[0], x)
        if model == self.model[0]:
            return base_y
        else:
            import rpy2.robjects as robjects
            r = robjects.r
            result = r.predict(model, self.preprocess(x))
            y = np.array(result[0])
            return y[:, -1] + base_y

    def fit_grf_model(self, x, t, y):
        from rpy2.robjects.vectors import StrVector, FactorVector, FloatVector, IntVector
        return self.grf.causal_forest(x,
                                      FloatVector([float(yy) for yy in y]),
                                      FloatVector([float(tt) for tt in t]), seed=909)

参考:

  1. 《因果学习周刊》第8期:因果反事实预测 - 知乎
  2. 因果效应估计:用数据和模型指导决策 - 知乎

http://www.kler.cn/a/10239.html

相关文章:

  • Qt初识简单使用Qt
  • 国标GB28181视频平台EasyCVR私有化部署视频平台对接监控录像机NVR时,录像机“资源不足”是什么原因?
  • 在Linux上部署(MySQL Redis Elasticsearch等)各类软件
  • 【非关系型数据库】【IOT设备】InfluxDB、TimescaleDB、Cassandra和MongoDB
  • FPGA学习笔记#5 Vitis HLS For循环的优化(1)
  • pytorch实现深度神经网络DNN与卷积神经网络CNN
  • 如果让你做技术负责人,你会怎么设计后端架构?
  • 查看 Elasticsearch 分析器
  • selenium库有哪些功能呢?都是如何实现的呢?
  • ( “树” 之 DFS) 543. 二叉树的直径 ——【Leetcode每日一题】
  • Git的安装与基本使用
  • 2021蓝桥杯真题大写 C语言/C++
  • 计算机网络笔记(横向)
  • 代码随想录算法训练营第三十四天-贪心算法3| 1005.K次取反后最大化的数组和 134. 加油站 135. 分发糖果
  • 微服务+springcloud+springcloud alibaba学习笔记【Eureka服务注册中心】(3/9)
  • C++标准库--IO库(Primer C++ 第五版 · 阅读笔记)
  • 离散数学_第二章:基本结构:集合、函数、序列、求和和矩阵(1)
  • 探索树形数据结构,通识树、森林与二叉树的基础知识(专有名词),进一步利用顺序表和链表表示、遍历和线索树形结构
  • 梯度的看法
  • MyBatis配置文件 —— 相关标签详解
  • 干翻Hadoop系列之:Hadoop前瞻之分布式知识
  • Leetcode.1992 找到所有的农场组
  • NumPy 秘籍中文第二版:十、Scikits 的乐趣
  • vue3+TS+Pinia+Vite项目实战之一
  • 程序员的日常瞎想,个人规划,和企业把控之间的微妙关系。职场人你懂!!
  • WPF MVVM模式构建项目