当前位置: 首页 > article >正文

XGB-3: 模型IO

在XGBoost 1.0.0中,引入了对使用JSON保存/加载XGBoost模型和相关超参数的支持,旨在用一个可以轻松重用的开放格式取代旧的二进制内部格式。后来在XGBoost 1.6.0中,还添加了对通用二进制JSON的额外支持,作为更高效的模型IO的优化。它们具有相同的文档结构,但具有不同的表示形式,但都统称为JSON格式。本教程旨在分享一些关于XGBoost中使用的JSON序列化方法的基本见解。除非明确说明,以下各节假定正在使用2个输出格式之一,可以通过在保存/加载模型时提供带有.json(或二进制JSON的.ubj)文件扩展名的文件名来启用这两种格式:booster.save_model('model.json')

在开始之前,需要说明的是,XGBoost是一个以树模型为重点的梯度提升库,这意味着在XGBoost内部有两个明显的部分:

  • 由树组成的模型
  • 用于构建模型的超参数和配置

如果是专注于深度学习领域,那么应该清楚由固定张量操作的权重组成的神经网络结构与用于训练它们的优化器(例如RMSprop)之间存在差异。

因此,当调用 booster.save_model(在R中是 xgb.save)时,XGBoost会保存树、一些模型参数(例如在训练树中的输入列数)以及目标函数,这些组合在一起代表了XGBoost中的“模型”概念。至于为什么将目标函数保存为模型的一部分,原因是目标函数控制全局偏差的转换(在XGBoost中称为base_score)。用户可以与他人共享此模型,用于预测、评估或使用不同的超参数集继续训练等

有些情况下,需要保存的不仅仅是模型本身。例如,在分布式训练中,XGBoost执行检查点操作。或者由于某些原因,分布式计算框架决定将模型从一个工作节点复制到另一个工作节点,并在那里继续训练。在这种情况下,序列化输出需要包含足够的信息,以便在不需要用户再次提供任何参数的情况下继续以前的训练。将这种情景视为内存快照( memory snapshot或基于内存的序列化方法),并将其与普通的模型IO操作区分开来。目前,内存快照用于以下情况:

  • Python:使用内置的pickle模块对Booster对象进行pickle
  • R:使用内置函数saveRDS或save对xgb.Booster对象进行持久化
  • JVM:使用内置函数saveModelBooster对象进行序列化

注意:

旧的二进制格式不能区分模型和原始内存序列化格式的差异,它是一切的混合体。JVM包有其自己的基于内存的序列化方法。

为了启用模型 IO 的 JSON 格式支持(仅保存树和目标),请在文件名中使用 .json.ubj 作为文件扩展名,后者是通用二进制 JSON 的扩展名。

  • Python
bst.save_model('model_file_name.json')
  • R
xgb.save(bst, 'model_file_name.json')
  • Scala
val format = "json"  // or val format = "ubj"
model.write.option("format", format).save("model_directory_path")

注意:

仅从由 XGBoost 生成的 JSON 文件加载模型。尝试加载由外部来源生成的 JSON 文件可能导致未定义的行为和崩溃。

关于模型和内存快照的向后兼容性说明

保证模型的向后兼容性,但不保证内存快照的向后兼容性。

模型(树和目标)使用稳定的表示,因此在较早版本的 XGBoost 中生成的模型可以在较新版本的 XGBoost 中访问。如果希望将模型存储或存档以供长期存储,请使用 save_model(Python)和 xgb.save(R)。

另一方面,内存快照(序列化)捕获了 XGBoost 内部的许多内容,其格式不稳定且可能经常更改。因此,内存快照仅适用于检查点,可以持久保存训练配置的完整快照,以便可以从可能的故障中强大地恢复并恢复训练过程。加载由较早版本的 XGBoost 生成的内存快照可能会导致错误或未定义的行为。如果使用 pickle.dump(Python)或 saveRDS(R)持久保存模型,则该模型可能无法在较新版本的 XGBoost 中访问。

自定义目标和度量标准

XGBoost支持用户提供的自定义目标和度量标准函数作为扩展。这些函数不会保存在模型文件中,因为它们是与语言相关的特性。在Python中,用户可以使用pickle将这些函数包含在保存的二进制文件中。其中一个缺点是,pickle输出不是稳定的序列化格式,在不同的Python版本和XGBoost版本上都无法使用,更不用说在不同的语言环境中了。解决此限制的另一种方法是在加载模型后再次提供这些函数。如果定制的函数很有用,请考虑创建一个PR(Pull Request)在XGBoost内部实现它,这样就可以在不同的语言绑定中使用定制的函数。

加载来自不同版本XGBoost的pickled文件

如前所述,pickle模型既不具备可移植性,也不稳定,但在某些情况下,pickled模型是有价值的。将其在将来恢复的一种方法是使用特定版本的Python和XGBoost将其加载回来,然后通过调用save_model导出模型。

可以使用类似的过程来恢复保存在旧RDS文件中的模型。在R中,可以使用remotes包安装旧版本的XGBoost:

library(remotes)
remotes::install_version("xgboost", "0.90.0.1")   # 安装版本0.90.0.1

安装所需的版本后,可以使用readRDS加载RDS文件并恢复xgb.Booster对象。然后,调用xgb.save以使用稳定表示导出模型,就能够在最新版本的XGBoost中使用该模型。

  • Python
import xgboost as xgb

bst = xgb.Booster({'nthread': 4}) 
bst.load_model('model_file_name.json')  # load xgb model

preds = bst.predict(xgb.DMatrix(X_test))  # predict if x_test is not DMatrix format
print(preds)

保存和加载内部参数配置

XGBoost的C APIPython APIR API支持直接将内部配置保存和加载为JSON字符串。在Python包中:

bst = xgboost.train(...)
config = bst.save_config()
print(config)

或在R中:

config <- xgb.config(bst)
print(config)

将打印出类似以下的内容(由于太长,以下内容不是实际输出,仅用于演示):

{
  "Learner": {
    "generic_parameter": {
      "device": "cuda:0",
      "gpu_page_size": "0",
      "n_jobs": "0",
      "random_state": "0",
      "seed": "0",
      "seed_per_iteration": "0"
    },
    "gradient_booster": {
      "gbtree_train_param": {
        "num_parallel_tree": "1",
        "process_type": "default",
        "tree_method": "hist",
        "updater": "grow_gpu_hist",
        "updater_seq": "grow_gpu_hist"
      },
      "name": "gbtree",
      "updater": {
        "grow_gpu_hist": {
          "gpu_hist_train_param": {
            "debug_synchronize": "0",
          },
          "train_param": {
            "alpha": "0",
            "cache_opt": "1",
            "colsample_bylevel": "1",
            "colsample_bynode": "1",
            "colsample_bytree": "1",
            "default_direction": "learn",

            ...

            "subsample": "1"
          }
        }
      }
    },
    "learner_train_param": {
      "booster": "gbtree",
      "disable_default_eval_metric": "0",
      "objective": "reg:squarederror"
    },
    "metrics": [],
    "objective": {
      "name": "reg:squarederror",
      "reg_loss_param": {
        "scale_pos_weight": "1"
      }
    }
  },
  "version": [1, 0, 0]
}

可以将其加载回由相同版本的XGBoost生成的模型,方法是:

bst.load_config(config)

保存模型和转储模型之间的区别

XGBoost在Booster对象中有一个名为dump_model的函数,它以可读的格式(如txtjsondot(graphviz))导出模型。它的主要用途是进行模型解释或可视化,不应该加载回XGBoost。JSON版本具有模式Schema 。

保存模型(Save Model): 通过save_model函数,XGBoost将整个模型以二进制格式保存到文件中。这包括模型的树结构、超参数和目标函数等。保存的模型文件可以用于在不同的XGBoost版本之间共享、加载和继续训练。

  • Python
booster.save_model('model.bin')
  • R
xgb.save(booster, 'model.bin')

转储模型(Dump Model): 通过dump_model函数,XGBoost将模型导出为可读的文本、JSON或Graphviz DOT格式,以便进行模型解释、可视化或分析。这是为了方便用户查看模型的结构和特性,而不是用于加载回XGBoost进行进一步的训练或预测。

  • Python
booster.dump_model('model.txt')
  • R
xgb.dump(booster, 'model.txt')

Json Schema

JSON格式的另一个重要特点是有一个详细记录的模式(schema),基于这个模式,用户可以轻松地重用XGBoost输出的模型。以下是输出模型的JSON模式(不是序列化,如上所述将不是稳定的)。有关解析XGBoost树模型的示例,请参见/demo/json-model。请注意“dart” booster 中使用的“weight_drop”字段。XGBoost不直接对树叶进行缩放,而是将权重保存为一个单独的数组

{
  "$schema": "http://json-schema.org/draft-07/schema#",
  "definitions": {
    "gbtree": {
      "type": "object",
      "properties": {
        "name": {
          "const": "gbtree"
        },
        "model": {
          "type": "object",
          "properties": {
            "gbtree_model_param": {
              "$ref": "#/definitions/gbtree_model_param"
            },
            "trees": {
              "type": "array",
              "items": {
                "type": "object",
                "properties": {
                  "tree_param": {
                    "$ref": "#/definitions/tree_param"
                  },
                  "id": {
                    "type": "integer"
                  },
                  "loss_changes": {
                    "type": "array",
                    "items": {
                      "type": "number"
                    }
                  },
                  "sum_hessian": {
                    "type": "array",
                    "items": {
                      "type": "number"
                    }
                  },
                  "base_weights": {
                    "type": "array",
                    "items": {
                      "type": "number"
                    }
                  },
                  "left_children": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  },
                  "right_children": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  },
                  "parents": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  },
                  "split_indices": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  },
                  "split_conditions": {
                    "type": "array",
                    "items": {
                      "type": "number"
                    }
                  },
                  "split_type": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  },
                  "default_left": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  },
                  "categories": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  },
                  "categories_nodes": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  },
                  "categories_segments": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  },
                  "categories_sizes": {
                    "type": "array",
                    "items": {
                      "type": "integer"
                    }
                  }
                },
                "required": [
                  "tree_param",
                  "loss_changes",
                  "sum_hessian",
                  "base_weights",
                  "left_children",
                  "right_children",
                  "parents",
                  "split_indices",
                  "split_conditions",
                  "default_left",
                  "categories",
                  "categories_nodes",
                  "categories_segments",
                  "categories_sizes"
                ]
              }
            },
            "tree_info": {
              "type": "array",
              "items": {
                "type": "integer"
              }
            }
          },
          "required": [
            "gbtree_model_param",
            "trees",
            "tree_info"
          ]
        }
      },
      "required": [
        "name",
        "model"
      ]
    },
    "gbtree_model_param": {
      "type": "object",
      "properties": {
        "num_trees": {
          "type": "string"
        },
        "num_parallel_tree": {
          "type": "string"
        }
      },
      "required": [
        "num_trees",
        "num_parallel_tree"
      ]
    },
    "tree_param": {
      "type": "object",
      "properties": {
        "num_nodes": {
          "type": "string"
        },
        "size_leaf_vector": {
          "type": "string"
        },
        "num_feature": {
          "type": "string"
        }
      },
      "required": [
        "num_nodes",
        "num_feature",
        "size_leaf_vector"
      ]
    },
    "reg_loss_param": {
      "type": "object",
      "properties": {
        "scale_pos_weight": {
          "type": "string"
        }
      }
    },
    "pseudo_huber_param": {
      "type": "object",
      "properties": {
        "huber_slope": {
          "type": "string"
        }
      }
    },
    "aft_loss_param": {
      "type": "object",
      "properties": {
        "aft_loss_distribution": {
          "type": "string"
        },
        "aft_loss_distribution_scale": {
          "type": "string"
        }
      }
    },
    "softmax_multiclass_param": {
      "type": "object",
      "properties": {
        "num_class": { "type": "string" }
      }
    },
    "lambda_rank_param": {
      "type": "object",
      "properties": {
        "num_pairsample": { "type": "string" },
        "fix_list_weight": { "type": "string" }
      }
    },
    "lambdarank_param": {
      "type": "object",
      "properties": {
        "lambdarank_num_pair_per_sample": { "type": "string" },
        "lambdarank_pair_method": { "type": "string" },
        "lambdarank_unbiased": {"type": "string" },
        "lambdarank_bias_norm": {"type": "string" },
        "ndcg_exp_gain": {"type": "string"}
      }
    }
  },
  "type": "object",
  "properties": {
    "version": {
      "type": "array",
      "items": [
        {
          "type": "number",
          "minimum": 1
        },
        {
          "type": "number",
          "minimum": 0
        },
        {
          "type": "number",
          "minimum": 0
        }
      ],
      "minItems": 3,
      "maxItems": 3
    },
    "learner": {
      "type": "object",
      "properties": {
        "feature_names": {
          "type": "array",
          "items": {
              "type": "string"
          }
        },
        "feature_types": {
          "type": "array",
          "items": {
              "type": "string"
          }
        },
        "gradient_booster": {
          "oneOf": [
            {
              "$ref": "#/definitions/gbtree"
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "gblinear" },
                "model": {
                  "type": "object",
                  "properties": {
                    "weights": {
                      "type": "array",
                      "items": {
                        "type": "number"
                      }
                    }
                  }
                }
              }
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "dart" },
                "gbtree": {
                  "$ref": "#/definitions/gbtree"
                },
                "weight_drop": {
                  "type": "array",
                  "items": {
                    "type": "number"
                  }
                }
              },
              "required": [
                "name",
                "gbtree",
                "weight_drop"
              ]
            }
          ]
        },

        "objective": {
          "oneOf": [
            {
              "type": "object",
              "properties": {
                "name": { "const": "reg:squarederror" },
                "reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
              },
              "required": [
                "name",
                "reg_loss_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "reg:pseudohubererror" },
                "reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
              },
              "required": [
                "name",
                "reg_loss_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "reg:squaredlogerror" },
                "reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
              },
              "required": [
                "name",
                "reg_loss_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "reg:linear" },
                "reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
              },
              "required": [
                "name",
                "reg_loss_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "reg:logistic" },
                "reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
              },
              "required": [
                "name",
                "reg_loss_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "binary:logistic" },
                "reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
              },
              "required": [
                "name",
                "reg_loss_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "binary:logitraw" },
                "reg_loss_param": { "$ref": "#/definitions/reg_loss_param"}
              },
              "required": [
                "name",
                "reg_loss_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "count:poisson" },
                "poisson_regression_param": {
                  "type": "object",
                  "properties": {
                    "max_delta_step": { "type": "string" }
                  }
                }
              },
              "required": [
                "name",
                "poisson_regression_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "reg:tweedie" },
                "tweedie_regression_param": {
                  "type": "object",
                  "properties": {
                    "tweedie_variance_power": { "type": "string" }
                  }
                }
              },
              "required": [
                "name",
                "tweedie_regression_param"
              ]
            },
            {
              "properties": {
                "name": {
                  "const": "reg:absoluteerror"
                }
              },
              "type": "object"
            },
            {
              "properties": {
                "name": {
                  "const": "reg:quantileerror"
                },
                "quantile_loss_param": {
                  "type": "object",
                  "properties": {
                    "quantle_alpha": {"type": "array"}
                  }
                }
              },
              "type": "object"
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "survival:cox" }
              },
              "required": [ "name" ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "reg:gamma" }
              },
              "required": [ "name" ]
            },

            {
              "type": "object",
              "properties": {
                "name": { "const": "multi:softprob" },
                "softmax_multiclass_param": { "$ref": "#/definitions/softmax_multiclass_param"}
              },
              "required": [
                "name",
                "softmax_multiclass_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "multi:softmax" },
                "softmax_multiclass_param": { "$ref": "#/definitions/softmax_multiclass_param"}
              },
              "required": [
                "name",
                "softmax_multiclass_param"
              ]
            },

            {
              "type": "object",
              "properties": {
                "name": { "const": "rank:pairwise" },
                "lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"}
              },
              "required": [
                "name",
                "lambdarank_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "rank:ndcg" },
                "lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"}
              },
              "required": [
                "name",
                "lambdarank_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": { "const": "rank:map" },
                "lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
              },
              "required": [
                "name",
                "lambda_rank_param"
              ]
            },
            {
              "type": "object",
              "properties": {
                "name": {"const": "survival:aft"},
                "aft_loss_param": { "$ref": "#/definitions/aft_loss_param"}
              }
            },
            {
              "type": "object",
              "properties": {
                "name": {"const": "binary:hinge"}
              }
            }
          ]
        },

        "learner_model_param": {
          "type": "object",
          "properties": {
            "base_score": { "type": "string" },
            "num_class": { "type": "string" },
            "num_feature": { "type": "string" },
            "num_target": { "type": "string" }
          }
        }
      },
      "required": [
        "gradient_booster",
        "objective"
      ]
    }
  },
  "required": [
    "version",
    "learner"
  ]
}

http://www.kler.cn/a/229585.html

相关文章:

  • 深度学习中的常见初始化方法:原理、应用与比较
  • 前端练习题
  • 【JVM-2.3】深入解析JVisualVM:Java性能监控与调优利器
  • 【RedisStack】Linux安装指南
  • 【微服务】面试 7、幂等性
  • 【数据结构-堆】【二分】力扣3296. 移山所需的最少秒数
  • [UI5 常用控件] 06.Splitter,ResponsiveSplitter
  • node环境打包js,webpack和rollup两个打包工具打包,能支持vue
  • SpringBoot中使用Spring自带线程池ThreadPoolTaskExecutor与Java8CompletableFuture实现异步任务示例
  • YOLOv8改进 | 检测头篇 | 独创RFAHead检测头超分辨率重构检测头(适用Pose、分割、目标检测)
  • 深度强化学习基础【1】-动态规划问题初探(leetcode算法的63题-不同路径II)
  • 题目:有1,2,3,4共四个数字,能组成多少个不相同而且无重复数字的三位数有多少个,都是多少?lua
  • 忘记 RAG:拥抱Agent设计,让 ChatGPT 更智能更贴近实际
  • 【数据结构和算法】--- 基于c语言排序算法的实现(1)
  • Elasticsearch:基本 CRUD 操作 - Python
  • PyTorch和TensorFlow的简介
  • 画出TCP三次握手和四次挥手的示意图,并且总结TCP和UDP的区别
  • 数字孪生网络攻防模拟与城市安全演练
  • 使用PDFBox实现pdf转其他图片格式
  • JDWP 简介
  • 勒索病毒最新变种.halo勒索病毒来袭,如何恢复受感染的数据?
  • 商汤科技「日日新4.0」正式发布,多维度升级大模型体系,能力比肩GPT-4!
  • CentOS 中文乱码
  • Google Chrome Close AutoUpdate
  • 小程序:类型三级分类
  • 【51单片机】LED的三个基本项目(LED点亮&LED闪烁&LED流水灯)(3)