当前位置: 首页 > article >正文

ExplaineR:集成K-means聚类算法的SHAP可解释性分析 | 可视化混淆矩阵、决策曲线、模型评估与各类SHAP图

集成K-means聚类算法的SHAP可解释性分析

  • 加载数据集并训练机器学习模型

    • SHAP 分析以提取特征对预测的影响

    • 通过混淆矩阵可视化模型性能

    • 决策曲线分析

    • 模型评估(多指标和ROC曲线的目视检查)

    • 带注释阈值的 ROC 曲线

    • 加载 SHAP 结果以进行下游分析

    • 与特征值关联的 SHAP 值

  • 特征的部分依赖图

  • 提取特征值和预测概率作为模型的输出进行分析

  • 运行一个 Shiny 应用程序来可视化 2-way 部分依赖图

  • 由 SHAP 集群确定的患者亚组

  • 模型公平性(敏感性分析)

  • 模型参数

引言

SHAP 可解释分析乃是一种被广泛运用的模型解释方法。然而,因缺乏专门的软件工具,其全部潜力常常未能得到充分开发。对此,在此介绍ExplaineR,这是一个 R 包,其作用在于促进基于用于 SHAP 分析的聚类功能的二分类和回归模型的解释。该工具还在可视化中提供用户交互元素,可用于评估模型性能、进行公平性分析、决策曲线分析以及生成各种SHAP图。它有助于对模型进行深入的预测后分析,使用户能够找出 SHAP 图中潜在的重要模式,进而通过 SHAP 聚类将其追溯至实例。

该软件包的核心功能在于将 k-means 聚类集成到 SHAP 分析中,并提供对重要但被忽视的模型评估方法(如公平性分析和决策曲线分析)的统一访问。

SHAP 聚类在基于临床数据的疾病结果预测中很有价值。例如,在二元分类任务中,根据特征的 SHAP 值将测试集上的样本聚类为三个聚类后,这些聚类应表征三个患者亚组。一个集群将包括样本,其中大多数被预测为具有患病的高风险(反映在较高的 SHAP 值上)。同样,患病风险较低的患者亚组将以第二个集群为特征。在实践中,模型有一些样本的预测较弱或不确定(即预测概率趋向于机会水平)。来自不确定预测的患者亚组预计将由第三个集群表征。这既可用于模型和数据诊断,也可用于模型的泛化,在模型推广中,人们可以根据其特征向量与通过 SHAP 聚类方法识别的已知患者亚组的接近程度来预测模型是否对新患者表现良好。

除了特征重要性的大小外,SHAP 分析还指示特征影响的方向。此方向信息对于了解特定特征值的增加或减少是否会导致模型预测的增加或减少至关重要。SHAP 值还可以反映非线性和特征交互,但受加性假设的约束,这意味着每个特征对预测的贡献是独立考虑的。因此,在模型解释中应考虑这些优点和局限性。

ExplaineR工作流程图

使用 ExplaineR 包进行机器学习分析的工作流程概述,该软件包在基于 Wisconsin Breast Cancer 数据集的随机森林模型(二元分类)上展示。在左侧面板上,显示了训练集和测试集的混淆矩阵,以及用于比较模型净收益与备选方案的图。在右侧面板上,显示了由 SHAP 聚类生成的三个样本聚类的 SHAP 汇总图及其相应的混淆矩阵。

ExplaineR

它可以通过 Shapley 分析(包括个体亚组的数据驱动特征)对复杂的分类和回归模型进行详细解释。此外,它还有助于多度量模型评估、模型公平性和决策曲线分析。此外,它还通过交互式元素提供增强的可视化效果。通过查看包教程了解如何使用该包。

你可以从CRAN安装该软件包的已发布版本。

install.packages("explainer")

也可以安装来自 GitHub 的开发版本:

install.packages("devtools")
devtools::install_github("PERSIMUNE/explainer")

使用ExplaineR解释分类模型

加载数据集并训练机器学习模型

第一个代码块加载数据集并创建二元分类任务,并使用 mlr3 包训练“随机森林”模型。

Sys.setenv(LANG = "en") # 将 R 语言更改为英语!
RNGkind("L'Ecuyer-CMRG") # 更改为 L'Ecuyer-CMRG,以防它使用默认的“Mersenne-Twister”

library("explainer")
# 设置种子以实现可重复性
seed <- 246
set.seed(seed)

# 从 mlbench 包加载 BreastCancer 数据
data("BreastCancer", package = "mlbench")

# 将目标列保留为“Class”
target_col <- "Class"

# 将阳性类别更改为“恶性”
positive_class <- "malignant"
    
# 仅保留预测变量和结果
mydata <- BreastCancer[, -1] # 1 is ID

# 删除有缺失值的行
mydata <- na.omit(mydata)

# 创建性别类别向量
sex <- sample(c("Male", "Female"), size = nrow(mydata), replace = TRUE)

# 创建性别类别向量
mydata$age <- as.numeric(sample(seq(18,60), size = nrow(mydata), replace = TRUE))

# 在 mydata 数据框中添加性别列(用于公平性分析)
mydata$sex <- factor(sex, levels = c("Male", "Female"), labels = c(1, 0))

# 创建分类任务
maintask <- mlr3::TaskClassif$new(id = "my_classification_task",
                                  backend = mydata,
                                  target = target_col,
                                  positive = positive_class)

# 创建一个训练-测试分割
set.seed(seed)
splits <- mlr3::partition(maintask)

# 添加机器学习(机器学习模型库)
# library("mlr3learners")
# library("mlr3extralearners")

# mlr_learners$get("classif.randomForest")
# 这里我们使用随机森林为例(您可以使用任何其他可用的模型)
# mylrn <- mlr3::lrn("classif.randomForest", predict_type = "prob")
library("mlr3learners")
mylrn <- mlr3::lrn("classif.ranger", predict_type = "prob")

# 训练模型
mylrn$train(maintask, splits$train)

# 对新数据进行预测
mylrn$predict(maintask, splits$test)
## <PredictionClassif> for 225 observations:
##  row_ids     truth  response prob.malignant prob.benign
##        2    benign malignant     0.86798175  0.13201825
##        5    benign    benign     0.00922619  0.99077381
##        7    benign    benign     0.35852381  0.64147619
##      ---       ---       ---            ---         ---
##      671    benign    benign     0.00000000  1.00000000
##      675    benign    benign     0.00230000  0.99770000
##      681 malignant malignant     0.91511905  0.08488095

SHAP 分析以提取特征对预测的影响

以下代码块使用 eSHAP_plot 函数来估计测试集的 SHAP 值并创建交互式 SHAP 图。这是增强的 SHAP 图,意味着它提供了附加信息,例如预测是否正确(TP 或 TN)。颜色映射提供了对 SHAP 图的增强视觉检查。

library("magrittr")
library("plotly")
# 增强型SHAP图
SHAP_output <- eSHAP_plot(task = maintask,
           trained_model = mylrn,
           splits = splits,
           sample.size = 30,
           seed = seed,
           subset = .8)

# 显示SHAP图
myplot <- SHAP_output[[1]]
myplot

下图显示了与预测概率相关的 SHAP 值。

SHAP_output[[5]]

通过混淆矩阵可视化模型性能

以下代码块使用 eCM_plot 函数可视化混淆矩阵,以评估训练集和测试集的模型性能

# 增强型混淆矩阵
confusionmatrix_plot <- eCM_plot(task = maintask,
         trained_model = mylrn,
         splits = splits)
print(confusionmatrix_plot)

测试集

训练集

决策曲线分析

提供的代码块使用 eDecisionCurve 函数对模型内的测试集进行“决策曲线分析”。

决策曲线分析是一种通过变动概率阈值,以净收益对应阈值概率绘制图形,来评估事件预测变量的方法,它与传统统计法的区别在于能够考量预测因子的临床价值,借此确定使用该预测因子做临床决策,相较于其他决策标准,是否在给定阈值概率下更具效益。

# 增强型决策曲线图
eDecisionCurve(task = maintask,
         trained_model = mylrn,
         splits = splits,
         seed = seed)

模型评估(多指标和ROC曲线的目视检查)

通过运行下一个代码块,您将获得以下模型评估指标和可视化效果:

AUC(曲线下面积):AUC 通过评估 ROC 曲线下的面积来量化二元分类模型的性能,该曲线绘制了跨不同阈值的灵敏度与 1-特异性的关系。值为 0.5 表示随机机会性能,而 1 表示完美分类。

BACC(平衡精度):BACC 通过平均敏感性和特异性来解决类别不平衡问题。范围从 0 到 1,0 分表示机会表现,1 表示完美分类。

MCC(马修斯相关系数):MCC 评估二元分类模型的质量,考虑真阳性、真阴性、假阳性和假阴性。范围从-1到1,-1表示完全不同意,0表示机会表现,1表示完美分类。

BBRIER(布赖尔分数):BBRIER 通过测量预测概率与真实二元结果之间的均方差来衡量概率预测的准确性。值范围从 0 到 1,其中 0 表示校准完美,1 表示校准不良。

PPV(阳性预测值):PPV(即精度)衡量模型做出的所有阳性预测中真正阳性预测的比例。

NPV(负预测值):NPV 量化模型做出的所有负面预测中真正负面预测的比例。

Specificity:特异性计算二元分类问题中真实阴性预测占所有实际阴性案例的比例。

Sensitivity:灵敏度,也称为召回率或真阳性率,衡量二元分类问题中真阳性预测占所有实际阳性案例的比例。

PRAUC(曲线下精确召回面积):PRAUC 根据精度和召回率评估二元分类模型的性能,量化精度-召回率曲线下的面积。 PRAUC 值为 1 表示分类性能完美。

此外,分析还涉及开发和测试集的 ROC 和精确召回曲线的可视化。

eROC_plot(task = maintask,
         trained_model = mylrn,
         splits = splits)
## [[2]]
##             pred_results$score(measures = mlr3::msrs(meas))
## auc                                                    1.00
## bacc                                                   0.99
## mcc                                                    0.98
## bbrier                                                 0.01
## ppv                                                    0.98
## npv                                                    1.00
## specificity                                            0.99
## sensitivity                                            0.99
## prauc                                                  1.00
## 
## [[3]]
##             pred_results_test$score(measures = mlr3::msrs(meas))
## auc                                                         0.99
## bacc                                                        0.97
## mcc                                                         0.92
## bbrier                                                      0.04
## ppv                                                         0.93
## npv                                                         0.99
## specificity                                                 0.96
## sensitivity                                                 0.97
## prauc                                                       0.97

带注释阈值的 ROC 曲线

通过运行下一个代码块,您将获得这次开发和测试集的 ROC 和 Precision Recall 曲线以及概率阈值信息。

ePerformance(task = maintask,
         trained_model = mylrn,
         splits = splits)

加载 SHAP 结果以进行下游分析

现在我们可以从 eSHAP_plot 函数获取输出,以对 SHAP 值应用聚类

shap_Mean_wide <- SHAP_output[[2]]

shap_Mean_long <- SHAP_output[[3]]

shap <- SHAP_output[[4]]

与特征值关联的 SHAP 值

ShapFeaturePlot(shap_Mean_long)

特征的部分依赖图

部分依赖图Partial dependence plots(PDP):PDP可用于可视化单个特征对模型预测的边际效应。

ShapPartialPlot(shap_Mean_long = shap_Mean_long)

提取特征值和预测概率作为模型的输出进行分析

fval_predprob <- reshape2::dcast(shap, sample_num + pred_prob + predcorrectness ~ feature, value.var = "feature.value")

运行一个 Shiny 应用程序来可视化 2-way 部分依赖图

library(shiny)
library(ggplot2)
# 假设你的数据表格名为`fval_predprob`。
ui <- fluidPage(
  titlePanel("Feature Plot"),
  sidebarLayout(
    sidebarPanel(
      selectInput("x_feature", "X-axis Feature:",
                  choices = names(fval_predprob)),
      selectInput("y_feature", "Y-axis Feature:",
                  choices = names(fval_predprob)),
      width = 2,
      actionButton("stop_button", "Stop")
    ),
    mainPanel(
      plotOutput("feature_plot")
    )
  )
)

server <- function(input, output) {
  # 为应用程序状态创建一个响应式值。
  app_state <- reactiveValues(running = TRUE)
  
  output$feature_plot <- renderPlot({
    ggplot(fval_predprob, aes(x = .data[[input$x_feature]], y = .data[[input$y_feature]], fill = pred_prob)) + 
      geom_tile() +
      scale_fill_gradient(low = "white", high = "steelblue", limits = c(0, 1)) +
      xlab(input$x_feature) +
      ylab(input$y_feature) +
      labs(shape = "correct prediction") +
      egg::theme_article()
  })
  
  # observe the stop button
  observeEvent(input$stop_button, {
    app_state$running <- FALSE
  })
  
# 如果运行状态为假,则停止应用程序。
  observe({
    if(!app_state$running) {
      stopApp()
    }
  })
}

shinyApp(ui = ui, server = server)
str(shap_Mean_long)
## Classes 'data.table' and 'data.frame':   1980 obs. of  9 variables:
##  $ feature           : chr  "Bare.nuclei" "Bare.nuclei" "Bare.nuclei" "Bare.nuclei" ...
##  $ mean_phi          : num  0.0934 0.0934 0.0934 0.0934 0.0934 ...
##  $ Phi               : num  -0.0859 -0.0821 -0.0317 0.0997 0.0817 ...
##  $ f_val             : num  0 0.111 0 0.444 0.556 ...
##  $ unscaled_f_val    : num  1 2 1 5 6 1 1 3 1 1 ...
##  $ sample_num        : int  1 2 3 4 5 6 7 8 9 10 ...
##  $ correct_prediction: Factor w/ 2 levels "Incorrect","Correct": 2 2 2 1 2 2 2 2 2 2 ...
##  $ pred_prob         : num  0.00255 0.00182 0 0.92977 0.9283 ...
##  $ pred_class        : Factor w/ 2 levels "malignant","benign": 2 2 2 1 1 2 2 1 2 2 ...
##  - attr(*, ".internal.selfref")=<externalptr>
# 这是一个以宽格式包含 SHAP 值的数据表。每一行包含 sample_num 作为样本 ID 以及每个特征的 SHAP 值。
shap_Mean_wide
## Key: <sample_num>
##      sample_num Bare.nuclei  Bl.cromatin  Cell.shape   Cell.size Cl.thickness
##           <int>       <num>        <num>       <num>       <num>        <num>
##   1:          1 -0.08589487 -0.069985767 -0.04592299 -0.05557913 -0.016991720
##   2:          2 -0.08207815 -0.043686534 -0.06930323 -0.07851767 -0.040950132
##   3:          3 -0.03167399 -0.041501508 -0.05516873 -0.07578722 -0.026012011
##   4:          4  0.09965638  0.100056508  0.16292108  0.13855823 -0.003343413
##   5:          5  0.08170963  0.061670635  0.09991013  0.11153302 -0.024341667
##  ---                                                                         
## 176:        176 -0.10101344 -0.022493651 -0.09723053 -0.11001489 -0.036337381
## 177:        177 -0.07544257 -0.050335159 -0.07331122 -0.07960852 -0.052942063
## 178:        178 -0.07287780 -0.023547804 -0.08396833 -0.07580437 -0.065089868
## 179:        179 -0.12122407  0.054544259 -0.06030249 -0.03421050 -0.088212037
## 180:        180  0.17650148  0.008577434  0.08656048  0.09603124  0.118961111
##      Epith.c.size Marg.adhesion       Mitoses Normal.nucleoli           age
##             <num>         <num>         <num>           <num>         <num>
##   1: -0.027082751   0.002213201 -0.0022784127    -0.028862540 -0.0037774339
##   2: -0.010480000  -0.018519815 -0.0023586243    -0.016425476 -0.0006488360
##   3: -0.024506111  -0.020070503 -0.0010957672    -0.015432884 -0.0008956349
##   4:  0.076623836   0.041688201 -0.0006521693     0.017386984  0.0067640741
##   5:  0.038122116   0.030285079  0.0018288624     0.029264021 -0.0230379101
##  ---                                                                       
## 176: -0.029844048  -0.031585979  0.0005687831    -0.027441481 -0.0015749206
## 177: -0.024546667  -0.022700053 -0.0008221164    -0.020056270 -0.0021401058
## 178: -0.020292487  -0.024337857 -0.0018688889    -0.019011217  0.0006503175
## 179: -0.030457143  -0.004646905 -0.0040282540    -0.008570794  0.0014429101
## 180: -0.006486825   0.047554444  0.0024593386     0.034816376 -0.0157181481
##                sex
##              <num>
##   1:  0.0003026190
##   2: -0.0004190476
##   3:  0.0015755556
##   4:  0.0008892857
##   5:  0.0014779101
##  ---              
## 176:  0.0011694974
## 177: -0.0014722487
## 178:  0.0008101587
## 179: -0.0010852910
## 180: -0.0008616667
# 这是 SHAP 值的长格式数据表,包括特征名称、mean_phi_test(样本中特征的平均 SHAP 值)、缩放后的特征值、每个特征的 SHAP 值、样本 ID 以及预测是否正确(即预测类别等于实际类别)。
shap_Mean_long
##           feature    mean_phi           Phi     f_val unscaled_f_val sample_num
##            <char>       <num>         <num>     <num>          <num>      <int>
##    1: Bare.nuclei 0.093444244 -0.0858948677 0.0000000              1          1
##    2: Bare.nuclei 0.093444244 -0.0820781481 0.1111111              2          2
##    3: Bare.nuclei 0.093444244 -0.0316739947 0.0000000              1          3
##    4: Bare.nuclei 0.093444244  0.0996563757 0.4444444              5          4
##    5: Bare.nuclei 0.093444244  0.0817096296 0.5555556              6          5
##   ---                                                                          
## 1976:         sex 0.001183048  0.0011694974 1.0000000              2        176
## 1977:         sex 0.001183048 -0.0014722487 0.0000000              1        177
## 1978:         sex 0.001183048  0.0008101587 1.0000000              2        178
## 1979:         sex 0.001183048 -0.0010852910 0.0000000              1        179
## 1980:         sex 0.001183048 -0.0008616667 0.0000000              1        180
##       correct_prediction   pred_prob pred_class
##                   <fctr>       <num>     <fctr>
##    1:            Correct 0.002554762     benign
##    2:            Correct 0.001822222     benign
##    3:            Correct 0.000000000     benign
##    4:          Incorrect 0.929772222  malignant
##    5:            Correct 0.928304762  malignant
##   ---                                          
## 1976:            Correct 0.000000000     benign
## 1977:            Correct 0.000000000     benign
## 1978:            Correct 0.000000000     benign
## 1979:            Correct 0.120114286     benign
## 1980:            Correct 0.926616667  malignant

由 SHAP 集群确定的患者亚组

SHAP 聚类是一种更好地理解为什么模型对于某些患者可能比其他患者表现更好的方法。例如,在这里,您可以识别具有与其他子组不同的特定模式的患者子组,这解释了为什么您开发的模型对这些患者的性能可能比整个数据集的平均性能更好或更差。您可以看到 SHAP 图的差异,如果将所有图分组在一起,则提供总体 SHAP 摘要图。这里的边缘再次反映了每个单独样本(实例)中特征如何相互作用。

# 聚类的数量可以改变。
SHAP_plot_clusters <- SHAPclust(task = maintask,
         trained_model = mylrn,
         splits = splits,
         shap_Mean_wide = shap_Mean_wide,
         shap_Mean_long = shap_Mean_long,
         num_of_clusters = 3,
         seed = seed,
         subset = .8)

# 请注意,子集必须与先前进行的 SHAP 分析具有相同的值。
# 显示 SHAP 聚类图。
SHAP_plot_clusters[[1]]
## Key: <sample_num, feature, Phi>
##       sample_num         feature           Phi cluster    mean_phi     f_val
##            <int>          <char>         <num>   <int>       <num>     <num>
##    1:          1     Bare.nuclei -0.0858948677       3 0.093444244 0.0000000
##    2:          1     Bl.cromatin -0.0699857672       3 0.044548987 0.0000000
##    3:          1      Cell.shape -0.0459229894       3 0.076720237 0.0000000
##    4:          1       Cell.size -0.0555791270       3 0.091198097 0.0000000
##    5:          1    Cl.thickness -0.0169917196       3 0.069835393 0.5000000
##   ---                                                                       
## 1976:        180   Marg.adhesion  0.0475544444       2 0.026453762 1.0000000
## 1977:        180         Mitoses  0.0024593386       2 0.002165296 0.1250000
## 1978:        180 Normal.nucleoli  0.0348163757       2 0.021391343 0.7777778
## 1979:        180             age -0.0157181481       2 0.003230698 0.0000000
## 1980:        180             sex -0.0008616667       2 0.001183048 0.0000000
##       unscaled_f_val correct_prediction   pred_prob pred_class
##                <num>             <fctr>       <num>     <fctr>
##    1:              1            Correct 0.002554762     benign
##    2:              1            Correct 0.002554762     benign
##    3:              1            Correct 0.002554762     benign
##    4:              1            Correct 0.002554762     benign
##    5:              5            Correct 0.002554762     benign
##   ---                                                         
## 1976:             10            Correct 0.926616667  malignant
## 1977:              2            Correct 0.926616667  malignant
## 1978:              8            Correct 0.926616667  malignant
## 1979:             18            Correct 0.926616667  malignant
## 1980:              1            Correct 0.926616667  malignant

# 显示与 SHAP 聚类相对应的混淆矩阵(由 SHAP 聚类确定的患者子集)
SHAP_plot_clusters[[2]]

模型公平性(敏感性分析)

有时我们想调查我们的模型对于基于性别等变量类别的不同子组是否表现良好。

# 你应该决定使用哪些变量进行测试
# 这里我们从数据集中现有的变量中选择性别
Fairness_results <- eFairness(task = maintask,
         trained_model = mylrn,
         splits = splits,
         target_variable = "sex",
         var_levels = c("Male", "Female"))

# 开发(左)和测试(右)集子组的 ROC 曲线
Fairness_results[[1]]

# 开发集子组中的表现
Fairness_results[[2]]
##             Male Female
## auc         1.00   1.00
## bacc        0.99   0.99
## mcc         0.97   0.98
## bbrier      0.01   0.01
## ppv         0.96   0.99
## npv         1.00   0.99
## specificity 0.98   0.99
## sensitivity 1.00   0.99
## prauc       1.00   1.00
# 测试集子组中的表现
Fairness_results[[3]]
##             Male Female
## auc         0.98   0.99
## bacc        0.98   0.96
## mcc         0.94   0.91
## bbrier      0.04   0.03
## ppv         0.92   0.93
## npv         1.00   0.97
## specificity 0.96   0.96
## sensitivity 1.00   0.95
## prauc       0.96   0.98

模型参数

# 获取模型参数
model_params <- mylrn$param_set
print(data.table::as.data.table(model_params))

参考资料:Zargari Marandi, Ramtin. “ExplaineR: an R package to explain machine learning models.” Bioinformatics Advances 4, no. 1 (2024): vbae049.doi:10.1093/bioadv/vbae049


http://www.kler.cn/a/503391.html

相关文章:

  • Python在Excel工作表中创建数据透视表
  • 学英语学Elasticsearch:04 Elastic integrations 工具箱实现对第三方数据源的采集、存储、可视化,开箱即用
  • 理解机器学习中的参数和超参数
  • Vue Diff 算法完全解析
  • C++并发编程之跨应用程序与驱动程序的单生产者单消费者队列
  • 在 Ubuntu 上安装和配置 Redis
  • 【机器学习】实战:天池工业蒸汽量项目(一)数据探索
  • 【案例81】NMC调用导致数据库的效率问题
  • 深度学习中的学习率调度器(scheduler)分析并作图查看各方法差异
  • 【CI/CD构建】关于不小心将springMVC注解写在service层
  • 利用Python爬虫按图搜索1688商品(拍立淘):开启智能购物新体验
  • RIP协议在简单网络架构的使用
  • 工具推荐:PDFgear——免费且强大的PDF编辑工具 v2.1.12
  • QT中,在子线程中更新UI,会出现哪些问题,如何避免这种情况发生。
  • 全面掌握AI提示词的艺术:从基础到高级的深度探索
  • 使用Selenium进行网页自动化测试
  • jupyter ai 结合local llm 实现思路
  • 复健第一天之[SWPUCTF 2022 新生赛]奇妙的MD5
  • 【Vue3 入门到实战】1. 创建Vue3工程
  • 信创改造-龙蜥操作系统搭载MySql、Tomcat等服务
  • 微信小程序获取当前页面路径,登录成功后重定向回原页面
  • 使用Flink-JDBC将数据同步到Doris
  • 【华为路由/交换机的telnet远程设置】
  • 重邮+数字信号处理实验七:用 MATLAB 设计 IIR 数字滤波器
  • DATACOM-防火墙-复习-实验
  • Swift语言的软件工程