From bb6ecaba6b9eed93c583aedea047a4f93be28c0a Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Wed, 15 Jan 2025 14:56:58 -0800 Subject: [PATCH 1/2] add `toggle_sparsity()` --- R/fit.R | 2 ++ R/sparsevctrs.R | 76 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/R/fit.R b/R/fit.R index 13d1315..8c42940 100644 --- a/R/fit.R +++ b/R/fit.R @@ -71,6 +71,8 @@ fit.workflow <- function(object, data, ..., calibration = NULL, control = contro ) } + object <- toggle_sparsity(object, data) + workflow <- object workflow <- .fit_pre(workflow, data) workflow <- .fit_model(workflow, control) diff --git a/R/sparsevctrs.R b/R/sparsevctrs.R index bbb865a..b272d42 100644 --- a/R/sparsevctrs.R +++ b/R/sparsevctrs.R @@ -1,3 +1,79 @@ is_sparse_matrix <- function(x) { methods::is(x, "sparseMatrix") } + +toggle_sparsity <- function(object, data) { + toggle_sparse <- "no" + + if (allow_sparse(object$fit$actions$model$spec)) { + if ("recipe" %in% names(object$pre$actions)) { + est_sparsity <- recipes::.recipes_estimate_sparsity( + object$pre$actions$recipe$recipe + ) + } else { + est_sparsity <- sparsevctrs::sparsity(data, 1000) + } + + pred_log_fold <- pred_log_fold( + est_sparsity, + object$fit$actions$model$spec$engine, + nrow(data) + ) + if (pred_log_fold > 0) { + toggle_sparse <- "yes" + } + } + + object$pre$actions$recipe$recipe <- recipes::.recipes_toggle_sparse_args( + object$pre$actions$recipe$recipe, + choice = toggle_sparse + ) + object +} + +allow_sparse <- function(x) { + if (inherits(x, "model_fit")) { + x <- x$spec + } + res <- parsnip::get_from_env(paste0(class(x)[1], "_encoding")) + all(res$allow_sparse_x[res$engine == x$engine]) +} + +pred_log_fold <- function(sparsity, model, n_rows) { + if (is.null(model) || model == "ranger") { + return("no") + } + + log_fold <- -0.599333138645995 + + ifelse(sparsity < 0.836601307189543, 0.836601307189543 - sparsity, 0) * + -0.541581853008009 + + ifelse(n_rows < 16000, 16000 - n_rows, 0) * 3.23980908942813e-05 + + ifelse(n_rows > 16000, n_rows - 16000, 0) * -2.81001152147355e-06 + + ifelse(sparsity > 0.836601307189543, sparsity - 0.836601307189543, 0) * + 9.82444255114058 + + ifelse(sparsity > 0.836601307189543, sparsity - 0.836601307189543, 0) * + ifelse(n_rows > 8000, n_rows - 8000, 0) * + 7.27456967763306e-05 + + ifelse(sparsity > 0.836601307189543, sparsity - 0.836601307189543, 0) * + ifelse(n_rows < 8000, 8000 - n_rows, 0) * + -0.000798307404212627 + + if (model == "xgboost") { + log_fold <- log_fold + + ifelse(sparsity < 0.984615384615385, 0.984615384615385 - sparsity, 0) * + 0.113098025073806 + + ifelse(n_rows < 8000, 8000 - n_rows, 0) * -9.77914237255269e-05 + + ifelse(n_rows > 8000, n_rows - 8000, 0) * 3.22657666511869e-06 + + ifelse(sparsity > 0.984615384615385, sparsity - 0.984615384615385, 0) * + 41.5180348086939 + + 0.913457808326756 + } + + if (model == "LiblineaR") { + log_fold <- log_fold + + ifelse(sparsity > 0.836601307189543, sparsity - 0.836601307189543, 0) * + -5.39592564852111 + } + + log_fold +} From f68285e6acdceb8b70228a5740c0b8dbfb5f245b Mon Sep 17 00:00:00 2001 From: Emil Hvitfeldt Date: Thu, 16 Jan 2025 12:14:57 -0800 Subject: [PATCH 2/2] move recipes to imports --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 181d507..af08bff 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -25,6 +25,7 @@ Imports: lifecycle (>= 1.0.3), modelenv (>= 0.1.0), parsnip (>= 1.2.1.9000), + recipes (>= 1.0.10.9000), rlang (>= 1.1.0), tidyselect (>= 1.2.0), sparsevctrs (>= 0.1.0.9002), @@ -42,7 +43,6 @@ Suggests: methods, modeldata (>= 1.0.0), probably, - recipes (>= 1.0.10.9000), rmarkdown, testthat (>= 3.0.0) VignetteBuilder: