Skip to content

Commit

Permalink
Merge pull request #118 from tidymodels/partykit-classification-prediction
Browse files Browse the repository at this point in the history

update `partykit_tree_info()` to handle classification outputs
  • Loading branch information
EmilHvitfeldt authored Dec 17, 2024
2 parents dde75bd + ae6e973 commit 81a15bc
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 3 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ S3method(tidypredict_test,party)
S3method(tidypredict_test,randomForest)
S3method(tidypredict_test,ranger)
S3method(tidypredict_test,xgb.Booster)
export(.extract_partykit_classprob)
export(.extract_xgb_trees)
export(acceptable_formula)
export(as_parsed_model)
Expand Down
80 changes: 77 additions & 3 deletions R/model-partykit.R
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
partykit_tree_info <- function(model) {
model_nodes <- map(seq_along(model), ~ model[[.x]])
is_split <- map_lgl(model_nodes, ~ class(.x$node[1]) == "partynode")
# non-cat model
mean_resp <- map_dbl(model_nodes, ~ mean(.x$fitted[, "(response)"]))
prediction <- ifelse(!is_split, mean_resp, NA)
if (is.numeric(model_nodes[[1]]$fitted[["(response)"]])) {
mean_resp <- map_dbl(model_nodes, ~ mean(.x$fitted[, "(response)"]))
prediction <- ifelse(!is_split, mean_resp, NA)
} else {
stat_mode <- function(x) {
counts <- rev(sort(table(x)))
if (counts[[1]] == counts[[2]]) {
ties <- counts[counts[1] == counts]
return(names(rev(ties))[1])
}
names(counts)[1]
}
mode_resp <- map_chr(model_nodes, ~ stat_mode(.x$fitted[, "(response)"]))
prediction <- ifelse(!is_split, mode_resp, NA)
}

party_nodes <- map(seq_along(model), ~ partykit::nodeapply(model, .x))

kids <- map(party_nodes, ~ {
Expand Down Expand Up @@ -88,3 +101,64 @@ tidypredict_fit.party <- function(model) {
parsedmodel <- parse_model(model)
build_fit_formula_rf(parsedmodel)[[1]]
}

# For {orbital}
#' Extract class probability trees for partykit models
#'
#' Builds one fit expression per response class. Every expression follows the
#' splits of the fitted tree, but predicts the weighted proportion of that
#' class among the observations stored in each node.
#'
#' @param model A fitted `party` model whose `"(response)"` column is a
#'   factor.
#'
#' @return An unnamed list with one element per level of the response, in
#'   `levels()` order. Each element is the first element of the result of
#'   `build_fit_formula_rf()` for that class.
#'
#' @keywords internal
#' @export
.extract_partykit_classprob <- function(model) {
  # Class probabilities are only defined for a factor response; without this
  # guard `levels()` is NULL and the per-node vectors have differing lengths.
  if (!is.factor(model$fitted[["(response)"]])) {
    cli::cli_abort(
      "{.arg model} must be a classification model with a factor response."
    )
  }

  # Weighted class proportions among the observations of a single node.
  extract_classprob <- function(node) {
    node_fitted <- node$fitted
    response <- node_fitted[["(response)"]]
    weights <- node_fitted[["(weights)"]]
    # NOTE(review): assumes a missing "(weights)" column means unit case
    # weights -- confirm against how the model was fitted.
    if (is.null(weights)) {
      weights <- rep(1, length(response))
    }

    weights_sum <- tapply(weights, response, sum)
    # `tapply()` returns NA for levels that do not occur in this node.
    weights_sum[is.na(weights_sum)] <- 0
    res <- weights_sum / sum(weights)
    names(res) <- levels(response)
    res
  }

  # One row per node (in `seq_along(model)` order), one column per class.
  preds <- map(seq_along(model), function(id) extract_classprob(model[[id]]))
  preds <- matrix(
    unlist(preds),
    nrow = length(preds),
    byrow = TRUE,
    dimnames = list(NULL, names(preds[[1]]))
  )

  # Turn a tree table, whose `prediction` column holds the probabilities of
  # one class, into a single fit expression.
  generate_one_tree <- function(tree_info) {
    terminal_ids <- tree_info$nodeID[tree_info[, "terminal"]]
    paths <- map(
      terminal_ids,
      function(node_id) {
        prediction <- tree_info$prediction[tree_info$nodeID == node_id]
        if (is.null(prediction)) cli::cli_abort("Prediction column not found")
        if (is.factor(prediction)) prediction <- as.character(prediction)
        list(
          prediction = prediction,
          path = get_ra_path(node_id, tree_info, FALSE)
        )
      }
    )

    pm <- list()
    pm$general$model <- "party"
    pm$general$type <- "tree"
    pm$general$version <- 2
    pm$trees <- list(paths)
    parsedmodel <- as_parsed_model(pm)

    build_fit_formula_rf(parsedmodel)[[1]]
  }

  tree_info <- partykit_tree_info(model)

  # Swap each class's probability column in as the prediction in turn.
  # `tree_info` is modified only inside the anonymous function, so every
  # class starts from the same unmodified table.
  map(
    seq_len(ncol(preds)),
    function(i) {
      tree_info$prediction <- preds[, i]
      generate_one_tree(tree_info)
    }
  )
}

0 comments on commit 81a15bc

Please sign in to comment.