enable tuning postprocessors #966

Draft pull request: wants to merge 2 commits into base: main.

Changes from all commits
81 changes: 56 additions & 25 deletions R/grid_code_paths.R
@@ -375,10 +375,12 @@ tune_grid_loop_iter <- function(split,

model_params <- vctrs::vec_slice(params, params$source == "model_spec")
preprocessor_params <- vctrs::vec_slice(params, params$source == "recipe")
postprocessor_params <- vctrs::vec_slice(params, params$source == "tailor")

param_names <- params$id
model_param_names <- model_params$id
preprocessor_param_names <- preprocessor_params$id
postprocessor_param_names <- postprocessor_params$id

# inline rsample::assessment so that we can pass indices to `predict_model()`
assessment_rows <- as.integer(split, data = "assessment")
@@ -542,34 +544,62 @@ tune_grid_loop_iter <- function(split,
# if the postprocessor does not require training, then `calibration` will
# be NULL and nothing other than the column names is learned from
# `assessment`.
workflow_with_post <- .fit_post(workflow, calibration %||% assessment)

workflow_with_post <- .fit_finalize(workflow_with_post)
# --------------------------------------------------------------------------
# Postprocessor loop
iter_postprocessors <- iter_grid_info_model[[".iter_postprocessor"]]

# run extract function on workflow with trained postprocessor
elt_extract <- .catch_and_log(
extract_details(workflow_with_post, control$extract),
control,
split_labels,
paste(iter_msg_model, "(extracts)"),
bad_only = TRUE,
notes = out_notes
)
elt_extract <- make_extracts(elt_extract, iter_grid, split_labels, .config = iter_config)
out_extracts <- append_extracts(out_extracts, elt_extract)
workflow_pre_and_model <- workflow

# generate predictions on the assessment set from the model and apply the
# post-processor to those predictions to generate updated predictions
iter_predictions <- .catch_and_log(
predict_model(assessment, assessment_rows, workflow_with_post, iter_grid,
metrics, iter_submodels, metrics_info = metrics_info,
eval_time = eval_time),
control,
split_labels,
paste(iter_msg_model, "(predictions with post-processor)"),
bad_only = TRUE,
notes = out_notes
)
for (iter_postprocessor in iter_postprocessors) {
workflow <- workflow_pre_and_model

iter_grid_info_postprocessor <- vctrs::vec_slice(
iter_grid_info_model,
iter_grid_info_model$.iter_postprocessor == iter_postprocessor
)

iter_grid_postprocessor <- iter_grid_info_postprocessor[, postprocessor_param_names]

iter_msg_postprocessor <- iter_grid_info_postprocessor[[".msg_postprocessor"]]
iter_config <- iter_grid_info_postprocessor[[".iter_config_post"]][[1L]]

workflow <- finalize_workflow_postprocessor(workflow, iter_grid_postprocessor)

workflow_with_post <- .fit_post(workflow, calibration %||% assessment)

workflow_with_post <- .fit_finalize(workflow_with_post)

iter_grid <- dplyr::bind_cols(
iter_grid_preprocessor,
iter_grid_model,
iter_grid_postprocessor
)

# run extract function on workflow with trained postprocessor
elt_extract <- .catch_and_log(
extract_details(workflow_with_post, control$extract),
control,
split_labels,
paste(iter_msg_model, "(extracts)"),
bad_only = TRUE,
notes = out_notes
)
elt_extract <- make_extracts(elt_extract, iter_grid, split_labels, .config = iter_config)
out_extracts <- append_extracts(out_extracts, elt_extract)

# generate predictions on the assessment set from the model and apply the
# post-processor to those predictions to generate updated predictions
iter_predictions <- .catch_and_log(
predict_model(assessment, assessment_rows, workflow_with_post, iter_grid,
metrics, iter_submodels, metrics_info = metrics_info,
eval_time = eval_time),
control,
split_labels,
paste(iter_msg_postprocessor, "(predictions with post-processor)"),
bad_only = TRUE,
notes = out_notes
)

# now, assess those predictions with performance metrics
}
@@ -595,6 +625,7 @@ tune_grid_loop_iter <- function(split,
control = control,
.config = iter_config_metrics
)
} # postprocessor loop
} # model loop
} # preprocessor loop

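The net effect of the change above is a third level of nesting: the postprocessor loop sits inside the model loop so that predictions from one model fit are reused across every postprocessor candidate. A schematic sketch of that nesting with toy stand-ins (not tune's actual API):

# Toy schematic of the preprocessor/model/postprocessor nesting: the
# expensive fit and predict happen once per model candidate, and each
# postprocessor candidate only adjusts the resulting predictions.
preprocessors <- list(identity)
models <- list(function(dat) lm(mpg ~ wt, data = dat))
thresholds <- c(0.25, 0.50, 0.75)  # stand-ins for tailor candidates

for (preprocess in preprocessors) {
  dat <- preprocess(mtcars)
  for (make_fit in models) {
    fit <- make_fit(dat)        # fit once per model candidate
    preds <- predict(fit, dat)  # predict once per model candidate
    for (threshold in thresholds) {
      # "postprocess" cheaply: cap predictions at a candidate quantile
      adjusted <- pmin(preds, quantile(preds, threshold))
    }
  }
}
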
100 changes: 91 additions & 9 deletions R/grid_helpers.R
@@ -1,7 +1,6 @@

predict_model <- function(new_data, orig_rows, workflow, grid, metrics,
submodels = NULL, metrics_info, eval_time = NULL) {

model <- extract_fit_parsnip(workflow)

forged <- forge_from_workflow(new_data, workflow)
@@ -260,6 +259,22 @@ finalize_workflow_preprocessor <- function(workflow, grid_preprocessor) {
workflow
}

#' @export
#' @rdname tune-internal-functions
finalize_workflow_postprocessor <- function(workflow, grid_postprocessor) {
# Already finalized, nothing to tune
if (ncol(grid_postprocessor) == 0L) {
return(workflow)
}

postprocessor <- workflows::extract_postprocessor(workflow)
postprocessor <- merge(postprocessor, grid_postprocessor)$x[[1]]

workflow <- set_workflow_tailor(workflow, postprocessor)

workflow
}

# ------------------------------------------------------------------------------

# For any type of tuning, and for fit-resamples, we generate a unified
Expand Down Expand Up @@ -310,16 +325,23 @@ compute_grid_info <- function(workflow, grid) {
grid <- tibble::as_tibble(grid)

parameters <- hardhat::extract_parameter_set_dials(workflow)
parameters_model <- dplyr::filter(parameters, source == "model_spec")

parameters_preprocessor <- dplyr::filter(parameters, source == "recipe")
parameters_model <- dplyr::filter(parameters, source == "model_spec")
parameters_postprocessor <- dplyr::filter(parameters, source == "tailor")

any_parameters_model <- nrow(parameters_model) > 0
any_parameters_preprocessor <- nrow(parameters_preprocessor) > 0

res <- min_grid(extract_spec_parsnip(workflow), grid)
any_parameters_model <- nrow(parameters_model) > 0
any_parameters_postprocessor <- nrow(parameters_postprocessor) > 0

syms_pre <- rlang::syms(parameters_preprocessor$id)
syms_mod <- rlang::syms(parameters_model$id)
syms_post <- rlang::syms(parameters_postprocessor$id)

res <- min_grid(extract_spec_parsnip(workflow), grid)
if (any_parameters_postprocessor) {
res <- nest_min_grid(res, parameters_postprocessor$id)
}

# ----------------------------------------------------------------------------
# Create an order of execution to train the preprocessor (if any). This will
@@ -340,7 +362,7 @@
res$.lab_pre <- "Preprocessor1"
}

# Make the label shown in the grid and in loggining
# Make the label shown in the grid and in logging
res$.msg_preprocessor <-
new_msgs_preprocessor(
res$.iter_preprocessor,
@@ -351,7 +373,6 @@
# Now make a similar iterator across models. Conditioning on each unique
# preprocessing candidate set, make an iterator for the model candidate sets
# (if any)

res <-
res %>%
dplyr::group_nest(.iter_preprocessor, keep = TRUE) %>%
Expand All @@ -370,21 +391,74 @@ compute_grid_info <- function(workflow, grid) {
n = res$.num_models,
res$.msg_preprocessor)

res %>%
res <- res %>%
dplyr::select(-.num_models) %>%
dplyr::relocate(dplyr::starts_with(".msg"))

# ----------------------------------------------------------------------------
# Finally, iterate across postprocessors. Conditioning on an .iter_config,
# make an iterator for each postprocessing candidate set (if any).
if (!any_parameters_postprocessor) {
return(res)
}

res <-
res %>%
dplyr::group_nest(.iter_config, keep = TRUE) %>%
dplyr::mutate(
data = purrr::map(data, make_iter_postprocessor, parameters_postprocessor$id)
) %>%
tidyr::unnest(cols = data) %>%
dplyr::relocate(dplyr::starts_with(".iter"), dplyr::starts_with(".msg"))

res
}

make_iter_config <- function(dat) {
# Compute labels for the models *within* each preprocessing loop.
num_submodels <- purrr::map_int(dat$.submodels, ~ length(unlist(.x)))
num_submodels <- purrr::map_int(
dat$.submodels,
function(.x) if (length(.x) == 0) 0L else length(.x[[1]])
)
num_models <- sum(num_submodels + 1) # +1 for the model being trained
.mod_label <- recipes::names0(num_models, "Model")
.iter_config <- paste(dat$.lab_pre[1], .mod_label, sep = "_")
.iter_config <- vctrs::vec_chop(.iter_config, sizes = num_submodels + 1)
tibble::tibble(.iter_config = .iter_config)
}
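
The vctrs::vec_chop() call above splits the flat vector of model labels into one chunk per trained model, sized to cover the model itself plus its submodels. A quick illustration with hypothetical labels:

# two trained models: the first carries one submodel (2 labels), the second none
vctrs::vec_chop(
  c("Preprocessor1_Model1", "Preprocessor1_Model2", "Preprocessor1_Model3"),
  sizes = c(2, 1)
)
#> [[1]]
#> [1] "Preprocessor1_Model1" "Preprocessor1_Model2"
#>
#> [[2]]
#> [1] "Preprocessor1_Model3"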

make_iter_postprocessor <- function(data, post_params) {
nested_by_post <- "post" %in% names(data)
if (nested_by_post) {
data <- data %>% unnest(post)
}

data %>%
mutate(
.iter_postprocessor = seq_len(nrow(.)),
.msg_postprocessor = new_msgs_postprocessor(
i = .iter_postprocessor,
n = max(.iter_postprocessor),
msgs_model = .msg_model
),
.iter_config_post = purrr::map2(
.iter_config,
.iter_postprocessor,
make_iter_config_post
)
) %>%
select(-.iter_config) %>%
nest(post = c(any_of(post_params), ".iter_postprocessor", ".msg_postprocessor", ".iter_config_post"))
Contributor Author: We actually want postprocessor parameter names both inside and outside of .submodels, but this results in them only being inside of .submodels (see the sketch after this function).

}
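
To illustrate the comment above, hypothetical row shapes for a grid with a submodel parameter trees and a tailor parameter threshold (constructed tibbles for illustration, not actual tune output):

library(tibble)

# Shape this currently produces: `threshold` lives only inside
# `post` and `.submodels`.
produced <- tibble(
  trees = 100,
  post = list(tibble(threshold = c(0.4, 0.6))),
  .submodels = list(list(tibble(trees = 10, threshold = c(0.4, 0.6))))
)

# Desired shape: `threshold` also appears as a top-level column for the
# configuration that is fitted directly.
desired <- tibble(
  trees = 100,
  threshold = c(0.4, 0.6),
  .submodels = list(list(tibble(trees = 10, threshold = c(0.4, 0.6))))
)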

make_iter_config_post <- function(iter_config, iter_postprocessor) {
paste0(
iter_config,
"_Postprocessor",
iter_postprocessor
Contributor Author: Still needs a format_with_padding() (see the sketch after this function).

)
}
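
Per the comment above, a sketch of what the padded variant might look like (hypothetical, not part of this PR; format_with_padding() pads against the widest value in the vector, so it needs the full vector of indices):

make_iter_config_post_padded <- function(iter_config, iter_postprocessors) {
  paste0(
    iter_config,
    "_Postprocessor",
    # pad to a common width so "Postprocessor02" sorts before "Postprocessor10"
    format_with_padding(iter_postprocessors)
  )
}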

# This generates a "dummy" grid_info object that has the same
# structure as a grid-info object with no tunable recipe parameters
# and no tunable model parameters.
Expand Down Expand Up @@ -420,6 +494,9 @@ new_msgs_preprocessor <- function(i, n) {
new_msgs_model <- function(i, n, msgs_preprocessor) {
paste0(msgs_preprocessor, ", model ", i, "/", n)
}
new_msgs_postprocessor <- function(i, n, msgs_model) {
paste0(msgs_model, ", postprocessor ", i, "/", n)
}

# c(1, 10) -> c("01", "10")
format_with_padding <- function(x) {
Expand Down Expand Up @@ -467,3 +544,8 @@ set_workflow_recipe <- function(workflow, recipe) {
workflow$pre$actions$recipe$recipe <- recipe
workflow
}

set_workflow_tailor <- function(workflow, tailor) {
workflow$post$actions$tailor$tailor <- tailor
workflow
}
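
A usage sketch for the pieces added in this file, assuming workflows::add_tailor() and tailor::adjust_probability_threshold() (illustrative, not a test from this PR):

library(tailor)
library(workflows)

# a tailor with a tunable probability threshold
post <- tailor() %>%
  adjust_probability_threshold(threshold = tune::tune())

wf <- workflow() %>%
  add_formula(Class ~ .) %>%
  add_model(parsnip::logistic_reg()) %>%
  add_tailor(post)

# finalize_workflow_postprocessor() merges one grid row into the tailor and
# stores the result back on the workflow via set_workflow_tailor()
wf <- finalize_workflow_postprocessor(wf, tibble::tibble(threshold = 0.6))
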
21 changes: 20 additions & 1 deletion R/merge.R
@@ -75,6 +75,12 @@ merge.model_spec <- function(x, y, ...) {
merger(x, y, ...)
}

#' @export
#' @rdname merge.recipe
merge.tailor <- function(x, y, ...) {
merger(x, y, ...)
}

update_model <- function(grid, object, pset, step_id, nms, ...) {
for (i in nms) {
param_info <- pset %>% dplyr::filter(id == i & source == "model_spec")
@@ -108,6 +114,16 @@ update_recipe <- function(grid, object, pset, step_id, nms, ...) {
object
}

update_tailor <- function(grid, object, pset, adjustment_id, nms, ...) {
for (i in nms) {
param_info <- pset %>% dplyr::filter(id == i & source == "tailor")
if (nrow(param_info) == 1) {
idx <- which(adjustment_id == param_info$component_id)
object$adjustments[[idx]][["arguments"]][[param_info$name]] <- grid[[i]]
}
}
object
}

merger <- function(x, y, ...) {
if (!is.data.frame(y)) {
@@ -127,9 +143,12 @@
if (inherits(x, "recipe")) {
updater <- update_recipe
step_ids <- purrr::map_chr(x$steps, ~ .x$id)
} else {
} else if (inherits(x, "model_spec")) {
updater <- update_model
step_ids <- NULL
} else {
updater <- update_tailor
step_ids <- purrr::map_chr(x$adjustments, ~ class(.x)[1])
}

if (!any(grid_name %in% pset$id)) {
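
A sketch of what the new merge() method returns, mirroring the documented behavior of merge.recipe(): a tibble with a list-column x holding one updated object per grid row (assumes tailor::adjust_probability_threshold()):

library(tailor)

# a tailor with a tunable threshold
post <- adjust_probability_threshold(tailor(), threshold = tune::tune())

merged <- merge(post, data.frame(threshold = c(0.4, 0.6)))
merged$x[[1]]  # tailor updated by update_tailor() with threshold = 0.4
merged$x[[2]]  # tailor updated with threshold = 0.6
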
45 changes: 45 additions & 0 deletions R/min_grid.R
@@ -326,3 +326,48 @@ min_grid.pls <- fit_max_value
#' @export min_grid.poisson_reg
#' @rdname min_grid
min_grid.poisson_reg <- fit_max_value


# When `min_grid()` is applied to grids with additional columns for
# postprocessors, we need to nest the postprocessor columns into
# .submodels to effectively enable the submodel trick.
# See: https://gist.github.com/simonpcouch/28d984cdcc3fc6d22ff776ed8740004e
nest_min_grid <- function(min_grid, post_params) {
if (!has_submodels(min_grid)) {
return(min_grid)
}
non_post_param_cols <- names(min_grid)[
!names(min_grid) %in% c(post_params, ".submodels")
]
submodel_param_name <- names(min_grid$.submodels[[1]])

res <-
min_grid %>%
# unnest from `list(list())` to `list()`
unnest(.submodels) %>%
# unnest from `list()` to vector
unnest(.submodels)

tibble(
vctrs::vec_unique(res[non_post_param_cols]),
post = list(vctrs::vec_unique(res[post_params])),
.submodels = list(
res[c(post_params, ".submodels")] %>%
rename(!!submodel_param_name := .submodels) %>%
group_by(across(all_of(submodel_param_name))) %>%
group_split()
)
)
}

has_submodels <- function(min_grid) {
if (!".submodels" %in% names(min_grid)) {
return(FALSE)
}

if (length(min_grid$.submodels[[1]]) == 0) {
return(FALSE)
}

TRUE
}
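
For reference, hypothetical shapes before and after the nesting (the linked gist walks through a fuller example):

library(tibble)

# What min_grid() might hand over with a tailor column `threshold` still at
# the top level: one row per (model, postprocessor) candidate.
min_grid_raw <- tibble(
  trees = c(100, 100),
  threshold = c(0.4, 0.6),
  .submodels = list(list(trees = c(10, 50)), list(trees = c(10, 50)))
)

# nest_min_grid(min_grid_raw, "threshold") collapses this to one row per
# model fit, nesting the postprocessor candidates into `post` and
# `.submodels`, so each model is trained once and reused across candidates.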