From bd139b1b74b95d1064b44e0a21d99779c65f3ba6 Mon Sep 17 00:00:00 2001
From: Emil Hvitfeldt
Date: Mon, 16 Dec 2024 09:58:36 -0800
Subject: [PATCH] add .extract_xgb_trees

---
 NAMESPACE         |  1 +
 R/model-xgboost.R | 48 ++++++++++++++++++++++++++++++++++++++++++------
 2 files changed, 43 insertions(+), 6 deletions(-)

diff --git a/NAMESPACE b/NAMESPACE
index 31453ea..723c5fe 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -39,6 +39,7 @@ S3method(tidypredict_test,party)
 S3method(tidypredict_test,randomForest)
 S3method(tidypredict_test,ranger)
 S3method(tidypredict_test,xgb.Booster)
+export(.extract_xgb_trees)
 export(acceptable_formula)
 export(as_parsed_model)
 export(parse_model)
diff --git a/R/model-xgboost.R b/R/model-xgboost.R
index 5355689..bf9ca0f 100644
--- a/R/model-xgboost.R
+++ b/R/model-xgboost.R
@@ -48,17 +48,17 @@ get_xgb_tree <- function(tree) {
   x
 }
 
-get_xgb_trees <- function(model) {
+get_xgb_trees <- function(model, filter_trees = TRUE) {
   xd <- xgboost::xgb.dump(
     model = model,
     dump_format = "text",
     with_stats = TRUE
   )
   feature_names <- model$feature_names
-  get_xgb_trees_character(xd, feature_names)
+  get_xgb_trees_character(xd, feature_names, filter_trees)
 }
 
-get_xgb_trees_character <- function(xd, feature_names) {
+get_xgb_trees_character <- function(xd, feature_names, filter_trees) {
   feature_names_tbl <- data.frame(
     Feature = as.character(0:(length(feature_names) - 1)),
     feature_name = feature_names,
@@ -75,10 +75,12 @@ get_xgb_trees_character <- function(xd, feature_names) {
     lapply(trees[, c("Yes", "No", "Missing")], function(x) as.integer(x) + 1)
 
   trees_split <- split(trees, trees$Tree)
-  trees_rows <- purrr::map_dbl(trees_split, nrow)
-  trees_filtered <- trees_split[trees_rows > 1]
+  if (filter_trees) {
+    trees_rows <- purrr::map_dbl(trees_split, nrow)
+    trees_split <- trees_split[trees_rows > 1]
+  }
 
-  purrr::map(trees_filtered, get_xgb_tree)
+  purrr::map(trees_split, get_xgb_tree)
 }
 
 #' @export
@@ -181,3 +183,37 @@ tidypredict_fit.xgb.Booster <- function(model) {
   parsedmodel <- parse_model(model)
   build_fit_formula_xgb(parsedmodel)
 }
+
+# For {orbital}
+#' @keywords internal
+#' @export
+.extract_xgb_trees <- function(model) {
+  if (!inherits(model, "xgb.Booster")) {
+    cli::cli_abort(
+      "{.arg model} must be {.cls xgb.Booster}, not {.obj_type_friendly {model}}."
+    )
+  }
+
+  # Drop `silent` and re-append it so, when present, it is the last param.
+  params <- model$params
+  wosilent <- params[names(params) != "silent"]
+  wosilent$silent <- params$silent
+
+  pm <- list()
+  pm$general$model <- "xgb.Booster"
+  pm$general$type <- "xgb"
+  pm$general$niter <- model$niter
+  pm$general$params <- wosilent
+  pm$general$feature_names <- model$feature_names
+  pm$general$nfeatures <- model$nfeatures
+  pm$general$version <- 1
+  # filter_trees = FALSE keeps the single-row trees the default would drop.
+  pm$trees <- get_xgb_trees(model, filter_trees = FALSE)
+
+  parsedmodel <- as_parsed_model(pm)
+  # Build one unevaluated case_when() call per parsed tree.
+  map(
+    seq_len(length(parsedmodel$trees)),
+    ~ expr(case_when(!!!get_xgb_case_tree(.x, parsedmodel)))
+  )
+}