-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
first pass at the post-processing container (#1)
Co-authored-by: ‘topepo’ <‘[email protected]’> Co-authored-by: simonpcouch <[email protected]> Co-authored-by: Emil Hvitfeldt <[email protected]>
- Loading branch information
1 parent
aa5ac35
commit 3874410
Showing
46 changed files
with
1,862 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,18 +6,28 @@ Authors@R: c( | |
person("Hannah", "Frick", , "[email protected]", role = "aut"), | ||
person("Emil", "HvitFeldt", , "[email protected]", role = "aut"), | ||
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")), | ||
person(given = "Posit Software, PBC", role = c("cph", "fnd")) | ||
person("Posit Software, PBC", role = c("cph", "fnd")) | ||
) | ||
Description: Sandbox for a postprocessor object. | ||
License: MIT + file LICENSE | ||
URL: https://github.com/tidymodels/container | ||
BugReports: https://github.com/tidymodels/container/issues | ||
Imports: | ||
cli, | ||
dplyr, | ||
generics, | ||
hardhat, | ||
probably (>= 1.0.3.9000), | ||
purrr, | ||
rlang (>= 1.1.0), | ||
tibble, | ||
tidyselect | ||
Suggests: | ||
modeldata, | ||
testthat (>= 3.0.0) | ||
Remotes: | ||
tidymodels/probably | ||
Config/testthat/edition: 3 | ||
Encoding: UTF-8 | ||
Roxygen: list(markdown = TRUE) | ||
RoxygenNote: 7.3.1 | ||
URL: https://github.com/tidymodels/container | ||
BugReports: https://github.com/tidymodels/container/issues | ||
Imports: | ||
cli, | ||
rlang (>= 1.1.0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,63 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
S3method(fit,container) | ||
S3method(fit,equivocal_zone) | ||
S3method(fit,numeric_calibration) | ||
S3method(fit,numeric_range) | ||
S3method(fit,predictions_custom) | ||
S3method(fit,probability_calibration) | ||
S3method(fit,probability_threshold) | ||
S3method(predict,container) | ||
S3method(predict,equivocal_zone) | ||
S3method(predict,numeric_calibration) | ||
S3method(predict,numeric_range) | ||
S3method(predict,predictions_custom) | ||
S3method(predict,probability_calibration) | ||
S3method(predict,probability_threshold) | ||
S3method(print,container) | ||
S3method(print,equivocal_zone) | ||
S3method(print,numeric_calibration) | ||
S3method(print,numeric_range) | ||
S3method(print,predictions_custom) | ||
S3method(print,probability_calibration) | ||
S3method(print,probability_threshold) | ||
S3method(required_pkgs,equivocal_zone) | ||
S3method(required_pkgs,numeric_calibration) | ||
S3method(required_pkgs,numeric_range) | ||
S3method(required_pkgs,predictions_custom) | ||
S3method(required_pkgs,probability_calibration) | ||
S3method(required_pkgs,probability_threshold) | ||
S3method(tunable,equivocal_zone) | ||
S3method(tunable,numeric_calibration) | ||
S3method(tunable,numeric_range) | ||
S3method(tunable,predictions_custom) | ||
S3method(tunable,probability_calibration) | ||
S3method(tunable,probability_threshold) | ||
export("%>%") | ||
export(adjust_equivocal_zone) | ||
export(adjust_numeric_calibration) | ||
export(adjust_numeric_range) | ||
export(adjust_predictions_custom) | ||
export(adjust_probability_calibration) | ||
export(adjust_probability_threshold) | ||
export(container) | ||
export(extract_parameter_dials) | ||
export(extract_parameter_set_dials) | ||
export(fit) | ||
export(required_pkgs) | ||
export(tidy) | ||
export(tunable) | ||
export(tune_args) | ||
import(rlang) | ||
importFrom(cli,cli_abort) | ||
importFrom(cli,cli_inform) | ||
importFrom(cli,cli_warn) | ||
importFrom(dplyr,"%>%") | ||
importFrom(generics,fit) | ||
importFrom(generics,required_pkgs) | ||
importFrom(generics,tidy) | ||
importFrom(generics,tunable) | ||
importFrom(generics,tune_args) | ||
importFrom(hardhat,extract_parameter_dials) | ||
importFrom(hardhat,extract_parameter_set_dials) | ||
importFrom(stats,predict) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
#' Apply an equivocal zone to a binary classification model. | ||
#' | ||
#' @param x A [container()]. | ||
#' @param value A numeric value (between zero and 1/2) or [hardhat::tune()]. The | ||
#' value is the size of the buffer around the threshold. | ||
#' @param threshold A numeric value (between zero and one) or [hardhat::tune()]. | ||
#' @examples | ||
#' library(dplyr) | ||
#' library(modeldata) | ||
#' | ||
#' post_obj <- | ||
#' container(mode = "classification") %>% | ||
#' adjust_equivocal_zone(value = 1 / 4) | ||
#' | ||
#' | ||
#' post_res <- fit( | ||
#' post_obj, | ||
#' two_class_example, | ||
#' outcome = c(truth), | ||
#' estimate = c(predicted), | ||
#' probabilities = c(Class1, Class2) | ||
#' ) | ||
#' | ||
#' predict(post_res, two_class_example) | ||
#' @export | ||
adjust_equivocal_zone <- function(x, value = 0.1, threshold = 1 / 2) { | ||
check_container(x) | ||
if (!is_tune(value)) { | ||
check_number_decimal(value, min = 0, max = 1 / 2) | ||
} | ||
if (!is_tune(threshold)) { | ||
check_number_decimal(threshold, min = 10^-10, max = 1 - 10^-10) | ||
} | ||
|
||
op <- | ||
new_operation( | ||
"equivocal_zone", | ||
inputs = "probability", | ||
outputs = "class", | ||
arguments = list(value = value, threshold = threshold), | ||
results = list(), | ||
trained = FALSE | ||
) | ||
|
||
new_container( | ||
mode = x$mode, | ||
type = x$type, | ||
operations = c(x$operations, list(op)), | ||
columns = x$dat, | ||
ptype = x$ptype, | ||
call = current_env() | ||
) | ||
} | ||
|
||
#' @export | ||
print.equivocal_zone <- function(x, ...) { | ||
# check for tune() first | ||
|
||
if (is_tune(x$arguments$value)) { | ||
cli::cli_bullets(c("*" = "Add equivocal zone of optimized size.")) | ||
} else { | ||
trn <- ifelse(x$trained, " [trained]", "") | ||
cli::cli_bullets(c( | ||
"*" = "Add equivocal zone of size | ||
{signif(x$arguments$value, digits = 3)}.{trn}" | ||
)) | ||
} | ||
invisible(x) | ||
} | ||
|
||
#' @export | ||
fit.equivocal_zone <- function(object, data, container = NULL, ...) { | ||
new_operation( | ||
class(object), | ||
inputs = object$inputs, | ||
outputs = object$outputs, | ||
arguments = object$arguments, | ||
results = list(), | ||
trained = TRUE | ||
) | ||
} | ||
|
||
#' @export | ||
predict.equivocal_zone <- function(object, new_data, container, ...) { | ||
est_nm <- container$columns$estimate | ||
prob_nm <- container$columns$probabilities[1] | ||
lvls <- levels(new_data[[est_nm]]) | ||
col_syms <- syms(prob_nm[1]) | ||
cls_pred <- probably::make_two_class_pred( | ||
new_data[[prob_nm]], | ||
levels = lvls, | ||
buffer = object$arguments$value, | ||
threshold = object$arguments$threshold | ||
) | ||
new_data[[est_nm]] <- cls_pred # todo convert to factor? | ||
new_data | ||
} | ||
|
||
#' @export | ||
required_pkgs.equivocal_zone <- function(x, ...) { | ||
c("container", "probably") | ||
} | ||
|
||
#' @export | ||
tunable.equivocal_zone <- function(x, ...) { | ||
tibble::new_tibble(list( | ||
name = "buffer", | ||
call_info = list(list(pkg = "dials", fun = "buffer")), | ||
source = "container", | ||
component = "equivocal_zone", | ||
component_id = "equivocal_zone" | ||
)) | ||
} | ||
|
||
# todo missing methods: | ||
# todo tune_args | ||
# todo tidy | ||
# todo extract_parameter_set_dials |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
#' Re-calibrate numeric predictions | ||
#' | ||
#' @param x A [container()]. | ||
#' @param calibrator A pre-trained calibration method from the \pkg{probably} | ||
#' package, such as [probably::cal_estimate_linear()]. | ||
#' @examples | ||
#' library(modeldata) | ||
#' library(probably) | ||
#' library(tibble) | ||
#' | ||
#' # create example data | ||
#' set.seed(1) | ||
#' dat <- tibble(y = rnorm(100), y_pred = y/2 + rnorm(100)) | ||
#' | ||
#' dat | ||
#' | ||
#' # calibrate numeric predictions | ||
#' reg_cal <- cal_estimate_linear(dat, truth = y, estimate = y_pred) | ||
#' | ||
#' # specify calibration | ||
#' reg_ctr <- | ||
#' container(mode = "regression") %>% | ||
#' adjust_numeric_calibration(reg_cal) | ||
#' | ||
#' # "train" container | ||
#' reg_ctr_trained <- fit(reg_ctr, dat, outcome = y, estimate = y_pred) | ||
#' | ||
#' predict(reg_ctr, dat) | ||
#' @export | ||
adjust_numeric_calibration <- function(x, calibrator) { | ||
check_container(x) | ||
check_required(calibrator) | ||
if (!inherits(calibrator, "cal_regression")) { | ||
cli_abort( | ||
"{.arg calibrator} should be a \\ | ||
{.help [<cal_regression> object](probably::cal_estimate_linear)}, \\ | ||
not {.obj_type_friendly {calibrator}}." | ||
) | ||
} | ||
|
||
op <- | ||
new_operation( | ||
"numeric_calibration", | ||
inputs = "numeric", | ||
outputs = "numeric", | ||
arguments = list(calibrator = calibrator), | ||
results = list(), | ||
trained = FALSE | ||
) | ||
|
||
new_container( | ||
mode = x$mode, | ||
type = x$type, | ||
operations = c(x$operations, list(op)), | ||
columns = x$dat, | ||
ptype = x$ptype, | ||
call = current_env() | ||
) | ||
} | ||
|
||
#' @export | ||
print.numeric_calibration <- function(x, ...) { | ||
trn <- ifelse(x$trained, " [trained]", "") | ||
cli::cli_bullets(c("*" = "Re-calibrate numeric predictions.{trn}")) | ||
invisible(x) | ||
} | ||
|
||
#' @export | ||
fit.numeric_calibration <- function(object, data, container = NULL, ...) { | ||
new_operation( | ||
class(object), | ||
inputs = object$inputs, | ||
outputs = object$outputs, | ||
arguments = object$arguments, | ||
results = list(), | ||
trained = TRUE | ||
) | ||
} | ||
|
||
#' @export | ||
predict.numeric_calibration <- function(object, new_data, container, ...) { | ||
probably::cal_apply(new_data, object$argument$calibrator) | ||
} | ||
|
||
# todo probably needs required_pkgs methods for cal objects | ||
#' @export | ||
required_pkgs.numeric_calibration <- function(x, ...) { | ||
c("container", "probably") | ||
} | ||
|
||
#' @export | ||
tunable.numeric_calibration <- function(x, ...) { | ||
no_param | ||
} | ||
|
||
# todo missing methods: | ||
# todo tune_args | ||
# todo tidy | ||
# todo extract_parameter_set_dials |
Oops, something went wrong.