diff --git a/R/grid_code_paths.R b/R/grid_code_paths.R
index 505d6ccc..a418605a 100644
--- a/R/grid_code_paths.R
+++ b/R/grid_code_paths.R
@@ -375,10 +375,12 @@ tune_grid_loop_iter <- function(split,

   model_params <- vctrs::vec_slice(params, params$source == "model_spec")
   preprocessor_params <- vctrs::vec_slice(params, params$source == "recipe")
+  postprocessor_params <- vctrs::vec_slice(params, params$source == "tailor")

   param_names <- params$id
   model_param_names <- model_params$id
   preprocessor_param_names <- preprocessor_params$id
+  postprocessor_param_names <- postprocessor_params$id

   # inline rsample::assessment so that we can pass indices to `predict_model()`
   assessment_rows <- as.integer(split, data = "assessment")
@@ -542,34 +544,62 @@ tune_grid_loop_iter <- function(split,
       # if the postprocessor does not require training, then `calibration` will
       # be NULL and nothing other than the column names is learned from
       # `assessment`.
-      workflow_with_post <- .fit_post(workflow, calibration %||% assessment)
-      workflow_with_post <- .fit_finalize(workflow_with_post)
+      # --------------------------------------------------------------------------
+      # Postprocessor loop
+      iter_postprocessors <- iter_grid_info_model[[".iter_postprocessor"]]

-      # run extract function on workflow with trained postprocessor
-      elt_extract <- .catch_and_log(
-        extract_details(workflow_with_post, control$extract),
-        control,
-        split_labels,
-        paste(iter_msg_model, "(extracts)"),
-        bad_only = TRUE,
-        notes = out_notes
-      )
-      elt_extract <- make_extracts(elt_extract, iter_grid, split_labels, .config = iter_config)
-      out_extracts <- append_extracts(out_extracts, elt_extract)
+      workflow_pre_and_model <- workflow

-      # generate predictions on the assessment set from the model and apply the
-      # post-processor to those predictions to generate updated predictions
-      iter_predictions <- .catch_and_log(
-        predict_model(assessment, assessment_rows, workflow_with_post, iter_grid,
-                      metrics, iter_submodels, metrics_info = metrics_info,
-                      eval_time = eval_time),
-        control,
-        split_labels,
-        paste(iter_msg_model, "(predictions with post-processor)"),
-        bad_only = TRUE,
-        notes = out_notes
-      )
+      for (iter_postprocessor in iter_postprocessors) {
+        workflow <- workflow_pre_and_model
+
+        iter_grid_info_postprocessor <- vctrs::vec_slice(
+          iter_grid_info_model,
+          iter_grid_info_model$.iter_postprocessor == iter_postprocessor
+        )
+
+        iter_grid_postprocessor <- iter_grid_info_postprocessor[, postprocessor_param_names]
+
+        iter_msg_postprocessor <- iter_grid_info_postprocessor[[".msg_postprocessor"]]
+        iter_config <- iter_grid_info_postprocessor[[".iter_config_post"]][[1L]]
+
+        workflow <- finalize_workflow_postprocessor(workflow, iter_grid_postprocessor)
+
+        workflow_with_post <- .fit_post(workflow, calibration %||% assessment)
+
+        workflow_with_post <- .fit_finalize(workflow_with_post)
+
+        iter_grid <- dplyr::bind_cols(
+          iter_grid_preprocessor,
+          iter_grid_model,
+          iter_grid_postprocessor
+        )
+
+        # run extract function on workflow with trained postprocessor
+        elt_extract <- .catch_and_log(
+          extract_details(workflow_with_post, control$extract),
+          control,
+          split_labels,
+          paste(iter_msg_model, "(extracts)"),
+          bad_only = TRUE,
+          notes = out_notes
+        )
+        elt_extract <- make_extracts(elt_extract, iter_grid, split_labels, .config = iter_config)
+        out_extracts <- append_extracts(out_extracts, elt_extract)
+
+        # generate predictions on the assessment set from the model and apply the
+        # post-processor to those predictions to generate updated predictions
+        iter_predictions <- .catch_and_log(
+          predict_model(assessment, assessment_rows, workflow_with_post, iter_grid,
+                        metrics, iter_submodels, metrics_info = metrics_info,
+                        eval_time = eval_time),
+          control,
+          split_labels,
+          paste(iter_msg_postprocessor, "(predictions with post-processor)"),
+          bad_only = TRUE,
+          notes = out_notes
+        )

       # now, assess those predictions with performance metrics
     }
@@ -595,6 +625,7 @@ tune_grid_loop_iter <- function(split,
         control = control,
         .config = iter_config_metrics
       )
+      } # postprocessor loop
     } # model loop
   } # preprocessor loop
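
The restructuring above replaces the straight-line "fit postprocessor, extract, predict" sequence with a third nested loop, so every postprocessor candidate reuses the workflow whose preprocessor and model are already fitted; the `workflow_pre_and_model` reset is what prevents one candidate's finalized tailor from leaking into the next. A runnable toy that mirrors the nesting (plain strings stand in for fitted workflows; none of tune's real objects are used):

    configs <- list()
    for (pre in 1) {                               # preprocessor candidates
      for (mod in 1:2) {                           # model candidates
        fitted <- paste0("pre", pre, "_mod", mod)  # stand-in for a fitted workflow
        for (post in 1:3) {                        # postprocessor candidates
          # only the cheap postprocessing step is redone per candidate
          configs <- c(configs, paste0(fitted, "_post", post))
        }
      }
    }
    unlist(configs)  # six configurations from only two model fits
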
diff --git a/R/grid_helpers.R b/R/grid_helpers.R
index 365386d5..7d48a67b 100644
--- a/R/grid_helpers.R
+++ b/R/grid_helpers.R
@@ -1,7 +1,6 @@
 predict_model <- function(new_data, orig_rows, workflow, grid, metrics,
                           submodels = NULL, metrics_info, eval_time = NULL) {
-
   model <- extract_fit_parsnip(workflow)

   forged <- forge_from_workflow(new_data, workflow)

@@ -260,6 +259,22 @@ finalize_workflow_preprocessor <- function(workflow, grid_preprocessor) {
   workflow
 }

+#' @export
+#' @rdname tune-internal-functions
+finalize_workflow_postprocessor <- function(workflow, grid_postprocessor) {
+  # Already finalized, nothing to tune
+  if (ncol(grid_postprocessor) == 0L) {
+    return(workflow)
+  }
+
+  postprocessor <- workflows::extract_postprocessor(workflow)
+  postprocessor <- merge(postprocessor, grid_postprocessor)$x[[1]]
+
+  workflow <- set_workflow_tailor(workflow, postprocessor)
+
+  workflow
+}
+
 # ------------------------------------------------------------------------------
 # For any type of tuning, and for fit-resamples, we generate a unified
@@ -310,16 +325,23 @@ compute_grid_info <- function(workflow, grid) {
   grid <- tibble::as_tibble(grid)

   parameters <- hardhat::extract_parameter_set_dials(workflow)
-  parameters_model <- dplyr::filter(parameters, source == "model_spec")
+  parameters_preprocessor <- dplyr::filter(parameters, source == "recipe")
+  parameters_model <- dplyr::filter(parameters, source == "model_spec")
+  parameters_postprocessor <- dplyr::filter(parameters, source == "tailor")

-  any_parameters_model <- nrow(parameters_model) > 0
   any_parameters_preprocessor <- nrow(parameters_preprocessor) > 0
-
-  res <- min_grid(extract_spec_parsnip(workflow), grid)
+  any_parameters_model <- nrow(parameters_model) > 0
+  any_parameters_postprocessor <- nrow(parameters_postprocessor) > 0

   syms_pre <- rlang::syms(parameters_preprocessor$id)
   syms_mod <- rlang::syms(parameters_model$id)
+  syms_post <- rlang::syms(parameters_postprocessor$id)
+
+  res <- min_grid(extract_spec_parsnip(workflow), grid)
+  if (any_parameters_postprocessor) {
+    res <- nest_min_grid(res, parameters_postprocessor$id)
+  }

   # ----------------------------------------------------------------------------
   # Create an order of execution to train the preprocessor (if any). This will
@@ -340,7 +362,7 @@
     res$.lab_pre <- "Preprocessor1"
   }

-  # Make the label shown in the grid and in loggining
+  # Make the label shown in the grid and in logging
   res$.msg_preprocessor <- new_msgs_preprocessor(
     res$.iter_preprocessor,
@@ -351,7 +373,6 @@
   # Now make a similar iterator across models. Conditioning on each unique
   # preprocessing candidate set, make an iterator for the model candidate sets
   # (if any)
-
   res <-
     res %>%
     dplyr::group_nest(.iter_preprocessor, keep = TRUE) %>%
@@ -370,14 +391,35 @@
       n = res$.num_models,
       res$.msg_preprocessor)

-  res %>%
+  res <- res %>%
     dplyr::select(-.num_models) %>%
     dplyr::relocate(dplyr::starts_with(".msg"))
+
+  # ----------------------------------------------------------------------------
+  # Finally, iterate across postprocessors. Conditioning on an .iter_config,
+  # make an iterator for each postprocessing candidate set (if any).
+  if (!any_parameters_postprocessor) {
+    return(res)
+  }
+
+  res <-
+    res %>%
+    dplyr::group_nest(.iter_config, keep = TRUE) %>%
+    dplyr::mutate(
+      data = purrr::map(data, make_iter_postprocessor, parameters_postprocessor$id)
+    ) %>%
+    tidyr::unnest(cols = data) %>%
+    dplyr::relocate(dplyr::starts_with(".iter"), dplyr::starts_with(".msg"))
+
+  res
 }

 make_iter_config <- function(dat) {
   # Compute labels for the models *within* each preprocessing loop.
-  num_submodels <- purrr::map_int(dat$.submodels, ~ length(unlist(.x)))
+  num_submodels <- purrr::map_int(
+    dat$.submodels,
+    function(.x) {if (length(.x) == 0) 0 else length(.x[[1]])}
+  )
   num_models <- sum(num_submodels + 1) # +1 for the model being trained
   .mod_label <- recipes::names0(num_models, "Model")
   .iter_config <- paste(dat$.lab_pre[1], .mod_label, sep = "_")
@@ -385,6 +427,38 @@ make_iter_config <- function(dat) {
   tibble::tibble(.iter_config = .iter_config)
 }

+make_iter_postprocessor <- function(data, post_params) {
+  nested_by_post <- "post" %in% names(data)
+  if (nested_by_post) {
+    data <- data %>% tidyr::unnest(post)
+  }
+
+  data %>%
+    dplyr::mutate(
+      .iter_postprocessor = seq_len(nrow(.)),
+      .msg_postprocessor = new_msgs_postprocessor(
+        i = .iter_postprocessor,
+        n = max(.iter_postprocessor),
+        msgs_model = .msg_model
+      ),
+      .iter_config_post = purrr::map2(
+        .iter_config,
+        .iter_postprocessor,
+        make_iter_config_post
+      )
+    ) %>%
+    dplyr::select(-.iter_config) %>%
+    tidyr::nest(post = c(dplyr::any_of(post_params), ".iter_postprocessor", ".msg_postprocessor", ".iter_config_post"))
+}
+
+make_iter_config_post <- function(iter_config, iter_postprocessor) {
+  paste0(
+    iter_config,
+    "_Postprocessor",
+    iter_postprocessor
+  )
+}
+
 # This generates a "dummy" grid_info object that has the same
 # structure as a grid-info object with no tunable recipe parameters
 # and no tunable model parameters.
@@ -420,6 +494,9 @@ new_msgs_preprocessor <- function(i, n) {
 new_msgs_model <- function(i, n, msgs_preprocessor) {
   paste0(msgs_preprocessor, ", model ", i, "/", n)
 }
+new_msgs_postprocessor <- function(i, n, msgs_model) {
+  paste0(msgs_model, ", postprocessor ", i, "/", n)
+}

 # c(1, 10) -> c("01", "10")
 format_with_padding <- function(x) {
@@ -467,3 +544,8 @@ set_workflow_recipe <- function(workflow, recipe) {
   workflow$pre$actions$recipe$recipe <- recipe
   workflow
 }
+
+set_workflow_tailor <- function(workflow, tailor) {
+  workflow$post$actions$tailor$tailor <- tailor
+  workflow
+}
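
For intuition about the new `finalize_workflow_postprocessor()` helper: like its preprocessor counterpart, it splices one row of candidate values into the workflow's tailor before anything is fitted. A rough usage sketch, assuming development versions where workflows provides `add_tailor()`/`extract_postprocessor()` and the tailor package is installed:

    library(workflows)
    library(tailor)

    tlr <- tailor() %>%
      adjust_probability_threshold(threshold = hardhat::tune())

    wflow <- workflow(Class ~ ., parsnip::logistic_reg())
    wflow <- add_tailor(wflow, tlr)

    # one candidate row, named by the parameter id
    wflow <- finalize_workflow_postprocessor(wflow, tibble::tibble(threshold = 0.4))
    extract_postprocessor(wflow)  # the threshold is now 0.4 rather than tune()
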
diff --git a/R/merge.R b/R/merge.R
index 7c6758c1..54cc4a1c 100644
--- a/R/merge.R
+++ b/R/merge.R
@@ -75,6 +75,12 @@ merge.model_spec <- function(x, y, ...) {
   merger(x, y, ...)
 }

+#' @export
+#' @rdname merge.recipe
+merge.tailor <- function(x, y, ...) {
+  merger(x, y, ...)
+}
+
 update_model <- function(grid, object, pset, step_id, nms, ...) {
   for (i in nms) {
     param_info <- pset %>% dplyr::filter(id == i & source == "model_spec")
@@ -108,6 +114,16 @@ update_recipe <- function(grid, object, pset, step_id, nms, ...) {
   object
 }
+update_tailor <- function(grid, object, pset, adjustment_id, nms, ...) {
+  for (i in nms) {
+    param_info <- pset %>% dplyr::filter(id == i & source == "tailor")
+    if (nrow(param_info) == 1) {
+      idx <- which(adjustment_id == param_info$component_id)
+      object$adjustments[[idx]][["arguments"]][[param_info$name]] <- grid[[i]]
+    }
+  }
+  object
+}

 merger <- function(x, y, ...) {
   if (!is.data.frame(y)) {
@@ -127,9 +143,12 @@ merger <- function(x, y, ...) {
   if (inherits(x, "recipe")) {
     updater <- update_recipe
     step_ids <- purrr::map_chr(x$steps, ~ .x$id)
-  } else {
+  } else if (inherits(x, "model_spec")) {
     updater <- update_model
     step_ids <- NULL
+  } else {
+    updater <- update_tailor
+    step_ids <- purrr::map_chr(x$adjustments, ~ class(.x)[1])
   }

   if (!any(grid_name %in% pset$id)) {
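
`merge.tailor()` mirrors `merge.recipe()`: `merger()` returns a tibble whose list column `x` holds one updated object per grid row, which is why `finalize_workflow_postprocessor()` above indexes `$x[[1]]`. A small sketch of what this enables, assuming tailor objects support `extract_parameter_set_dials()` as the rest of this changeset implies:

    library(tailor)

    tlr <- tailor() %>%
      adjust_probability_threshold(threshold = hardhat::tune())

    merged <- merge(tlr, tibble::tibble(threshold = c(0.3, 0.5)))
    # merged$x is a list of two tailors; update_tailor() wrote 0.3 and 0.5
    # into adjustments[[1]]$arguments$threshold of each
    length(merged$x)
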
diff --git a/R/min_grid.R b/R/min_grid.R
index 406c31ff..2b895003 100644
--- a/R/min_grid.R
+++ b/R/min_grid.R
@@ -326,3 +326,48 @@ min_grid.pls <- fit_max_value
 #' @export min_grid.poisson_reg
 #' @rdname min_grid
 min_grid.poisson_reg <- fit_max_value
+
+
+# When `min_grid()` is applied to grids with additional columns for
+# postprocessors, we need to nest the postprocessor columns into
+# .submodels to effectively enable the submodel trick.
+# See: https://gist.github.com/simonpcouch/28d984cdcc3fc6d22ff776ed8740004e
+nest_min_grid <- function(min_grid, post_params) {
+  if (!has_submodels(min_grid)) {
+    return(min_grid)
+  }
+  non_post_param_cols <- names(min_grid)[
+    !names(min_grid) %in% c(post_params, ".submodels")
+  ]
+  submodel_param_name <- names(min_grid$.submodels[[1]])
+
+  res <-
+    min_grid %>%
+    # unnest from `list(list())` to `list()`
+    tidyr::unnest(.submodels) %>%
+    # unnest from `list()` to vector
+    tidyr::unnest(.submodels)
+
+  tibble::tibble(
+    vctrs::vec_unique(res[non_post_param_cols]),
+    post = list(vctrs::vec_unique(res[post_params])),
+    .submodels = list(
+      res[c(post_params, ".submodels")] %>%
+        dplyr::rename(!!submodel_param_name := .submodels) %>%
+        dplyr::group_by(dplyr::across(dplyr::all_of(submodel_param_name))) %>%
+        dplyr::group_split()
+    )
+  )
+}
+
+has_submodels <- function(min_grid) {
+  if (!".submodels" %in% names(min_grid)) {
+    return(FALSE)
+  }
+
+  if (length(min_grid$.submodels[[1]]) == 0) {
+    return(FALSE)
+  }
+
+  TRUE
+}
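
To make the nesting concrete, here is a shape-level example of `nest_min_grid()` with made-up values, mimicking `min_grid()` output for a submodel parameter (`trees`) crossed with a tuned postprocessor `threshold`. Run with the package loaded (e.g. `devtools::load_all()`) so that `nest_min_grid()` is in scope:

    library(tibble)

    min_grid_in <- tibble(
      trees = 2000,
      threshold = c(0.3, 0.5),
      .submodels = list(list(trees = c(1, 1000)), list(trees = c(1, 1000)))
    )

    out <- nest_min_grid(min_grid_in, post_params = "threshold")
    # one row, i.e. one model fit: trees = 2000, `post` holds the unique
    # thresholds, and `.submodels` holds one tibble per submodel `trees`
    # value, each paired with both thresholds; predictions for all six
    # trees/threshold candidates then come from that single fit
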
diff --git a/tests/testthat/test-grid_helpers.R b/tests/testthat/test-grid_helpers.R
index dff242c9..88b997f0 100644
--- a/tests/testthat/test-grid_helpers.R
+++ b/tests/testthat/test-grid_helpers.R
@@ -342,3 +342,105 @@ test_that("compute_grid_info - recipe and model (no submodels but has inner grid
   expect_equal(nrow(res), 9)
   expect_equal(vctrs::vec_unique_count(res$.iter_config), nrow(grid))
 })
+
+test_that("compute_grid_info - model and postprocessor (no submodels)", {
+  library(workflows)
+  library(parsnip)
+  library(dials)
+  library(tailor)
+
+  spec <- boost_tree(mode = "regression", tree_depth = tune())
+  tlr <- tailor() %>% adjust_probability_threshold(threshold = tune())
+
+  wflow <- workflow()
+  wflow <- add_model(wflow, spec)
+  wflow <- add_formula(wflow, mpg ~ .)
+  wflow <- add_tailor(wflow, tlr)
+
+  grid <- grid_space_filling(extract_parameter_set_dials(wflow))
+  res <- compute_grid_info(wflow, grid)
+
+  expect_equal(res$.iter_preprocessor, rep(1, 5))
+  expect_equal(res$.msg_preprocessor, rep("preprocessor 1/1", 5))
+  expect_equal(res$tree_depth, grid$tree_depth)
+  expect_equal(res$.iter_model, 1:5)
+  expect_equal(res$.iter_config, as.list(paste0("Preprocessor1_Model", 1:5)))
+  expect_equal(res$.msg_model, paste0("preprocessor 1/1, model ", 1:5, "/5"))
+  expect_named(
+    res,
+    c(".iter_config", ".iter_preprocessor", ".iter_model",
+      ".msg_preprocessor", ".msg_model",
+      "tree_depth", "post", ".submodels"),
+    ignore.order = TRUE
+  )
+  expect_equal(nrow(res), 5)
+})
+
+test_that("compute_grid_info - model and postprocessor (with submodels)", {
+  # when a workflow has a model with submodels and a postprocessor, we want
+  # to hook into the submodel trick in the same way we would have before
+  library(workflows)
+  library(parsnip)
+  library(dials)
+  library(tailor)
+
+  spec <- boost_tree(mode = "regression", trees = tune())
+  tlr <- tailor() %>% adjust_probability_threshold(threshold = tune())
+
+  wflow <- workflow()
+  wflow <- add_model(wflow, spec)
+  wflow <- add_formula(wflow, mpg ~ .)
+  wflow <- add_tailor(wflow, tlr)
+
+  grid <- grid_regular(extract_parameter_set_dials(wflow), levels = 3)
+  res <- compute_grid_info(wflow, grid)
+
+  expect_equal(nrow(res), 1)
+  expect_equal(res$.iter_preprocessor, 1)
+  expect_equal(res$.msg_preprocessor, "preprocessor 1/1")
+  expect_equal(res$trees, max(grid$trees))
+  expect_equal(res$.iter_model, 1)
+  expect_equal(res$.iter_config, list(paste0("Preprocessor1_Model", 1:3)))
+  expect_equal(res$.msg_model, "preprocessor 1/1, model 1/1")
+
+  res_post <- res$post[[1]]
+  expect_equal(res_post$threshold, unique(grid$threshold))
+  expect_equal(res_post$.iter_postprocessor, 1:3)
+  expect_equal(
+    res_post$.msg_postprocessor,
+    paste0("preprocessor 1/1, model 1/1, postprocessor ", 1:3, "/3")
+  )
+  expect_equal(
+    res_post$.iter_config_post,
+    list(
+      paste0("Preprocessor1_Model", 1:3, "_Postprocessor1"),
+      paste0("Preprocessor1_Model", 1:3, "_Postprocessor2"),
+      paste0("Preprocessor1_Model", 1:3, "_Postprocessor3")
+    )
+  )
+  expect_named(
+    res,
+    c(".iter_config", ".iter_preprocessor", ".iter_model",
+      ".msg_preprocessor", ".msg_model", "trees", ".submodels", "post"),
+    ignore.order = TRUE
+  )
+})
+
+test_that("compute_grid_info - model and postprocessor (with submodels but irregular)", {
+  library(workflows)
+  library(parsnip)
+  library(dials)
+  library(tailor)
+
+  spec <- boost_tree(mode = "regression", trees = tune())
+  tlr <- tailor() %>% adjust_probability_threshold(threshold = tune())
+
+  wflow <- workflow()
+  wflow <- add_model(wflow, spec)
+  wflow <- add_formula(wflow, mpg ~ .)
+  wflow <- add_tailor(wflow, tlr)
+
+  grid <- grid_regular(extract_parameter_set_dials(wflow), levels = 3)
+  grid <- grid[c(1:2, 5:nrow(grid)), ]
+  res <- compute_grid_info(wflow, grid)
+
+  skip("does not work--removing some model fits shouldn't increase the number
+        of rows in the grid")
+})
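
Taken together, these changes let one tuning pass iterate over preprocessor, model, and postprocessor candidates while sharing model fits across postprocessor settings. An end-to-end sketch, again assuming development versions of workflows, tune, and tailor that include this support (`two_class_dat` is from modeldata):

    library(tidymodels)
    library(tailor)

    data(two_class_dat, package = "modeldata")

    wflow <- workflow(Class ~ ., logistic_reg(penalty = tune(), engine = "glmnet"))
    wflow <- add_tailor(
      wflow,
      tailor() %>% adjust_probability_threshold(threshold = tune())
    )

    set.seed(1)
    res <- tune_grid(
      wflow,
      resamples = vfold_cv(two_class_dat),
      grid = 10,
      metrics = metric_set(accuracy)
    )

    # `penalty` is a submodel parameter and `threshold` is nested into `post`,
    # so each resample needs only a handful of glmnet fits for all candidates
    collect_metrics(res)
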