-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
first pass at the post-processing container #1
Changes from 16 commits
69f1351
3b4f35f
56ca8e5
c1b173e
c62ac77
38b5662
8f0b4a2
162c212
c07348c
979841c
00efa2a
196a0ca
66f6e2e
2037357
aeca10d
7509ba0
025ae83
319ac5f
4164156
2689bdb
4db51af
6e811da
17e154a
fa1f95d
0eb7a04
561a795
b92edb0
809da4a
a5d5843
661a323
e3038ba
fbc9fbe
cfe9455
9e4f483
72d1007
ab07580
3cab3e0
917196f
5e6f981
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,18 +6,26 @@ Authors@R: c( | |
person("Hannah", "Frick", , "[email protected]", role = "aut"), | ||
person("Emil", "HvitFeldt", , "[email protected]", role = "aut"), | ||
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")), | ||
person(given = "Posit Software, PBC", role = c("cph", "fnd")) | ||
person("Posit Software, PBC", role = c("cph", "fnd")) | ||
) | ||
Description: Sandbox for a postprocessor object. | ||
License: MIT + file LICENSE | ||
URL: https://github.com/tidymodels/container | ||
BugReports: https://github.com/tidymodels/container/issues | ||
Imports: | ||
cli, | ||
dplyr, | ||
generics, | ||
hardhat, | ||
probably, | ||
purrr, | ||
rlang (>= 1.1.0), | ||
tibble, | ||
tidyselect | ||
Suggests: | ||
modeldata, | ||
testthat (>= 3.0.0) | ||
Config/testthat/edition: 3 | ||
Encoding: UTF-8 | ||
Roxygen: list(markdown = TRUE) | ||
RoxygenNote: 7.3.1 | ||
URL: https://github.com/tidymodels/container | ||
BugReports: https://github.com/tidymodels/container/issues | ||
Imports: | ||
cli, | ||
rlang (>= 1.1.0) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,63 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
S3method(fit,container) | ||
S3method(fit,equivocal_zone) | ||
S3method(fit,numeric_calibration) | ||
S3method(fit,numeric_range) | ||
S3method(fit,predictions_custom) | ||
S3method(fit,probability_calibration) | ||
S3method(fit,probability_threshold) | ||
S3method(predict,container) | ||
S3method(predict,equivocal_zone) | ||
S3method(predict,numeric_calibration) | ||
S3method(predict,numeric_range) | ||
S3method(predict,predictions_custom) | ||
S3method(predict,probability_calibration) | ||
S3method(predict,probability_threshold) | ||
S3method(print,container) | ||
S3method(print,equivocal_zone) | ||
S3method(print,numeric_calibration) | ||
S3method(print,numeric_range) | ||
S3method(print,predictions_custom) | ||
S3method(print,probability_calibration) | ||
S3method(print,probability_threshold) | ||
S3method(required_pkgs,equivocal_zone) | ||
S3method(required_pkgs,numeric_calibration) | ||
S3method(required_pkgs,numeric_range) | ||
S3method(required_pkgs,predictions_custom) | ||
S3method(required_pkgs,probability_calibration) | ||
S3method(required_pkgs,probability_threshold) | ||
S3method(tunable,equivocal_zone) | ||
S3method(tunable,numeric_calibration) | ||
S3method(tunable,numeric_range) | ||
S3method(tunable,predictions_custom) | ||
S3method(tunable,probability_calibration) | ||
S3method(tunable,probability_threshold) | ||
export("%>%") | ||
export(adjust_equivocal_zone) | ||
export(adjust_numeric_calibration) | ||
export(adjust_numeric_range) | ||
export(adjust_predictions_custom) | ||
export(adjust_probability_calibration) | ||
export(adjust_probability_threshold) | ||
export(container) | ||
export(extract_parameter_dials) | ||
export(extract_parameter_set_dials) | ||
export(fit) | ||
export(required_pkgs) | ||
export(tidy) | ||
export(tunable) | ||
export(tune_args) | ||
import(rlang) | ||
importFrom(cli,cli_abort) | ||
importFrom(cli,cli_inform) | ||
importFrom(cli,cli_warn) | ||
importFrom(dplyr,"%>%") | ||
importFrom(generics,fit) | ||
importFrom(generics,required_pkgs) | ||
importFrom(generics,tidy) | ||
importFrom(generics,tunable) | ||
importFrom(generics,tune_args) | ||
importFrom(hardhat,extract_parameter_dials) | ||
importFrom(hardhat,extract_parameter_set_dials) | ||
importFrom(stats,predict) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,12 @@ | ||
#' @import rlang | ||
#' @importFrom cli cli_abort cli_warn cli_inform | ||
#' @importFrom stats predict | ||
#' @keywords internal | ||
"_PACKAGE" | ||
|
||
## usethis namespace: start | ||
utils::globalVariables("data") | ||
## usethis namespace: end | ||
NULL | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
#' Declare post-processing for model predictions | ||
#' | ||
#' @param mode The model's mode, one of `"unknown"`, `"classification"`, or | ||
#' `"regression"`. Modes of `"censored regression"` are not currently supported. | ||
#' @param type The model sub-type. Possible values are `"unknown"`, `"regression"`, | ||
#' `"binary"`, or `"multiclass"`. | ||
#' @param outcome The name of the outcome variable. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment for all params: Right now the code specifies the names of the variables should be specified as character vectors. But I feel like we are moving away from strings and to have everything use {tidypredict} whether possible. if we decide to stick with character vectors, then the documentation should reflect that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you mean {tidyselect} instead of {tidypredict}? We do use tidyselect in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, tidyselect |
||
#' @param estimate The name of the point estimate (e.g. predicted class) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tripped up on this argument name. This is (likely) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah. I added more context in this man file. We use it since it is the name of analgous argument in yardstick metric functions. |
||
#' @param probabilities The names of class probability estimates (if any). For | ||
#' classification, these should be given in the order of the factor levels of | ||
#' the `estimate`. | ||
#' @param time The name of the predicted event time. (not yet supported) | ||
#' @param call The call to be displayed in warnings or errors. | ||
#' @examples | ||
#' | ||
#' container() | ||
#' @export | ||
container <- function(mode = "unknown", type = "unknown", outcome = character(0), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We’ll probably need to validate the type of model as well as it “species.” Specifically, we might need to differentiate between binary and multiclass classification. We can do this via an argument/attribute or by adding a more specific class to the container. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd prefer that users need not specify either of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, this is what I thought that we would do. |
||
estimate = character(0), probabilities = character(0), | ||
time = character(0), call = rlang::current_env()) { | ||
dat <- | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another open question is: when do we have the users specify the relevant names of the expected columns? Within a workflow, we will set these but there are a few differences between recipes and what we are doing here are:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. related: #1 (comment) |
||
list( | ||
outcome = outcome, | ||
type = type, | ||
estimate = estimate, | ||
probabilities = probabilities, | ||
time = time | ||
) | ||
new_container( | ||
mode, | ||
type, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. seems like we don't need |
||
operations = list(), | ||
columns = dat, | ||
ptype = tibble::tibble(), | ||
call = call | ||
) | ||
} | ||
|
||
new_container <- function(mode, type, operations, columns, ptype, call) { | ||
mode <- rlang::arg_match0(mode, c("unknown", "regression", "classification", "censored regression")) | ||
|
||
if ( mode == "regression" ) { | ||
type <- "regression" | ||
} | ||
|
||
type <- rlang::arg_match0(type, c("unknown", "regression", "binary", "multiclass")) | ||
|
||
if ( !is.list(operations) ) { | ||
cli::cli_abort("The {.arg operations} argument should be a list.", call = call) | ||
} | ||
|
||
is_oper <- purrr::map_lgl(operations, ~ inherits(.x, "operation")) | ||
if ( length(is_oper) > 0 & !any(is_oper) ) { | ||
bad_oper <- names(is_oper)[!is_oper] | ||
cli::cli_abort("The following {.arg operations} do not have the class \\ | ||
{.val operation}: {bad_oper}.", call = call) | ||
} | ||
|
||
# validate operation order and check duplicates | ||
validate_oper_order(operations, mode, call) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This gets invoked when operations are added so we can validate them as they go. |
||
|
||
|
||
# check columns | ||
|
||
res <- list(mode = mode, type = type, operations = operations, | ||
columns = columns, ptype = ptype) | ||
class(res) <- "container" | ||
res | ||
} | ||
|
||
#' @export | ||
print.container <- function(x, ...) { | ||
# todo emulate Emil's recipe printing | ||
|
||
num_op <- length(x$operations) | ||
cli::cli_inform("{x$type} post-processing object with {num_op} operation{?s}") | ||
|
||
if (num_op > 0) { | ||
cat("\n") | ||
res <- purrr::map(x$operations, ~ print(.x)) | ||
} | ||
|
||
invisible(x) | ||
} | ||
|
||
|
||
# ------------------------------------------------------------------------------ | ||
|
||
#' @export | ||
fit.container <- function(object, .data, outcome, estimate, probabilities = c(), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How does There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd vote we just use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah. I'm trying to avoid error messages about unsettable closures. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit pick, use |
||
time = c(), call = rlang::current_env(), ...) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are placeholders for survival models but we won't have any methods for this for a while. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, it is a placeholder for the predicted time. We probably need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
# ------------------------------------------------------------------------------ | ||
# set columns via tidyselect | ||
|
||
dat <- list() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 ab07580 |
||
dat$outcome <- names(tidyselect::eval_select(rlang::enquo(outcome), .data)) | ||
dat$estimate <- names(tidyselect::eval_select(rlang::enquo(estimate), .data)) | ||
|
||
probabilities <- tidyselect::eval_select(rlang::enquo(probabilities), .data) | ||
if (length(probabilities) > 0) { | ||
dat$probabilities <- names(probabilities) | ||
} else { | ||
dat$probabilities <- character(0) | ||
} | ||
|
||
time <- tidyselect::eval_select(rlang::enquo(time), .data) | ||
if (length(time) > 0) { | ||
dat$time <- names(time) | ||
} else { | ||
dat$time <- character(0) | ||
} | ||
|
||
.data <- .data[, names(.data) %in% unlist(dat)] | ||
.data <- tibble::as_tibble(.data) | ||
ptype <- .data[0,] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just a note for future us: I considered adding methods for applicability scores. We'd need to update the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
|
||
|
||
object <- set_container_type(object, .data[[ dat$outcome ]]) | ||
|
||
object <- new_container(object$mode, object$type, | ||
operations = object$operations, | ||
columns = dat, ptype = ptype, call = call) | ||
|
||
# ------------------------------------------------------------------------------ | ||
|
||
num_oper <- length(object$operations) | ||
for (op in 1:num_oper) { | ||
simonpcouch marked this conversation as resolved.
Show resolved
Hide resolved
|
||
object$operations[[op]] <- fit(object$operations[[op]], data, object) | ||
.data <- predict(object$operations[[op]], .data, object) | ||
} | ||
|
||
# todo Add a fitted container class? | ||
object | ||
} | ||
|
||
#' @export | ||
predict.container <- function(object, new_data, ...) { | ||
|
||
# validate levels/classes | ||
num_oper <- length(object$operations) | ||
for (op in 1:num_oper) { | ||
simonpcouch marked this conversation as resolved.
Show resolved
Hide resolved
|
||
new_data <- predict(object$operations[[op]], new_data, object) | ||
} | ||
tibble::as_tibble(new_data) | ||
} | ||
|
||
set_container_type <- function(object, y) { | ||
if (object$type != "unknown") { | ||
return(object) | ||
} | ||
if (is.factor(y)) { | ||
lvls <- levels(y) | ||
if (length(lvls) == 2) { | ||
object$type <- "binary" | ||
} else { | ||
object$type <- "multiclass" | ||
} | ||
} else if (is.numeric(y)) { | ||
object$type <- "regression" | ||
} else { | ||
cli::cli_abort("Only factor and numeric outcomes are currently supported.") | ||
} | ||
object | ||
} | ||
|
||
# todo: where to validate #levels? | ||
# todo setup eval_time | ||
# todo missing methods: | ||
# todo tune_args | ||
# todo tidy | ||
# todo extract_parameter_set_dials | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
#' Apply an equivocal zone to a binary classification model. | ||
#' | ||
#' @param x A [container()]. | ||
#' @param value A numeric value (between zero and 1/2) or [hardhat::tune()]. The | ||
#' value is the size of the buffer around the threshold. | ||
#' @param threshold A numeric value (between zero and one) or [hardhat::tune()]. | ||
#' @examples | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll need to add a Details section to explain this as well as what to look out for. The new class predictions have a different class (similar to a factor) that, to date, works with our tidymodels functions. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
#' library(dplyr) | ||
#' library(modeldata) | ||
#' | ||
#' post_obj <- | ||
#' container(mode = "classification") %>% | ||
#' adjust_equivocal_zone(value = 1 / 4) | ||
#' | ||
#' | ||
#' post_res <- fit( | ||
#' post_obj, | ||
#' two_class_example, | ||
#' outcome = c(truth), | ||
#' estimate = c(predicted), | ||
#' probabilities = c(Class1, Class2) | ||
#' ) | ||
#' | ||
#' predict(post_res, two_class_example) | ||
#' @export | ||
adjust_equivocal_zone <- function(x, value = 0.1, threshold = 1 / 2) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Users might include a step before this to adjust the threshold, so we should maybe add some code to inherit a previous threshold if one exists in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
if ( !is_tune(value) ) { | ||
check_number_decimal(value, min = 0, max = 1 / 2) | ||
} | ||
if ( !is_tune(threshold) ) { | ||
check_number_decimal(threshold, min = 10^-10, max = 1 - 10^-10) | ||
} | ||
|
||
op <- | ||
new_operation( | ||
"equivocal_zone", | ||
inputs = "probability", | ||
outputs = "class", | ||
arguments = list(value = value, threshold = threshold), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wish we had structured recipe steps with containers (ha) for inputs and computed results. Let's consider these names in case we do the same for recipes 2E (if that happens). |
||
results = list(trained = FALSE) | ||
) | ||
|
||
new_container( | ||
mode = x$mode, | ||
type = x$type, | ||
operations = c(x$operations, list(op)), | ||
columns = x$dat, | ||
ptype = x$ptype, | ||
call = rlang::current_env() | ||
) | ||
} | ||
|
||
#' @export | ||
print.equivocal_zone <- function(x, ...) { | ||
# check for tune() first | ||
|
||
if ( is_tune(x$arguments$value) ) { | ||
cli::cli_inform("Add equivocal zone to optimized value.") | ||
} else { | ||
trn <- ifelse(x$results$trained, " [trained]", "") | ||
cli::cli_inform(c("Add equivocal zone of size \\ | ||
{signif(x$arguments$value, digits = 3)}{trn}")) | ||
} | ||
invisible(x) | ||
} | ||
|
||
#' @export | ||
fit.equivocal_zone <- function(object, data, parent = NULL, ...) { | ||
topepo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
new_operation( | ||
class(object), | ||
inputs = object$inputs, | ||
outputs = object$outputs, | ||
arguments = object$arguments, | ||
results = list(trained = TRUE) | ||
) | ||
} | ||
|
||
#' @export | ||
predict.equivocal_zone <- function(object, new_data, parent, ...) { | ||
est_nm <- parent$columns$estimate | ||
prob_nm <- parent$columns$probabilities[1] | ||
lvls <- levels(new_data[[ est_nm ]]) | ||
col_syms <- rlang::syms(prob_nm[1]) | ||
cls_pred <- probably::make_two_class_pred(new_data[[prob_nm]], levels = lvls, | ||
buffer = object$arguments$value, | ||
threshold = object$arguments$threshold) | ||
new_data[[ est_nm ]] <- cls_pred # todo convert to factor? | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yardstick will convert equivocals to NA so we probably don't have to do it here. |
||
new_data | ||
} | ||
|
||
#' @export | ||
required_pkgs.equivocal_zone <- function(x, ...) { | ||
c("container", "probably") | ||
} | ||
|
||
#' @export | ||
tunable.equivocal_zone <- function(x, ...) { | ||
tibble::tibble( | ||
name = "buffer", | ||
call_info = list(list(pkg = "dials", fun = "buffer")), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A function to be created later (as are many others in the package). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
source = "container", | ||
component = "equivocal_zone", | ||
component_id = "equivocal_zone") | ||
} | ||
|
||
# todo missing methods: | ||
# todo tune_args | ||
# todo tidy | ||
# todo extract_parameter_set_dials |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm fine with being corrected. Would it be worth rethinking
"unknown"
as a possible mode, as in just allow mode to be"classification"
or"regression"
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can do that (and easily un-do it later). This was meant to be analogous to parsnip's unknown mode. However, I think that this situation is different and we can remove it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
917196f