Skip to content

Commit

Permalink
first pass at the post-processing container (#1)
Browse files Browse the repository at this point in the history
Co-authored-by: ‘topepo’ <‘[email protected]’>
Co-authored-by: simonpcouch <[email protected]>
Co-authored-by: Emil Hvitfeldt <[email protected]>
  • Loading branch information
4 people authored Apr 25, 2024
1 parent aa5ac35 commit 3874410
Show file tree
Hide file tree
Showing 46 changed files with 1,862 additions and 10 deletions.
22 changes: 16 additions & 6 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,28 @@ Authors@R: c(
person("Hannah", "Frick", , "[email protected]", role = "aut"),
person("Emil", "HvitFeldt", , "[email protected]", role = "aut"),
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
person(given = "Posit Software, PBC", role = c("cph", "fnd"))
person("Posit Software, PBC", role = c("cph", "fnd"))
)
Description: Sandbox for a postprocessor object.
License: MIT + file LICENSE
URL: https://github.com/tidymodels/container
BugReports: https://github.com/tidymodels/container/issues
Imports:
cli,
dplyr,
generics,
hardhat,
probably (>= 1.0.3.9000),
purrr,
rlang (>= 1.1.0),
tibble,
tidyselect
Suggests:
modeldata,
testthat (>= 3.0.0)
Remotes:
tidymodels/probably
Config/testthat/edition: 3
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
URL: https://github.com/tidymodels/container
BugReports: https://github.com/tidymodels/container/issues
Imports:
cli,
rlang (>= 1.1.0)
57 changes: 57 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,63 @@
# Generated by roxygen2: do not edit by hand

S3method(fit,container)
S3method(fit,equivocal_zone)
S3method(fit,numeric_calibration)
S3method(fit,numeric_range)
S3method(fit,predictions_custom)
S3method(fit,probability_calibration)
S3method(fit,probability_threshold)
S3method(predict,container)
S3method(predict,equivocal_zone)
S3method(predict,numeric_calibration)
S3method(predict,numeric_range)
S3method(predict,predictions_custom)
S3method(predict,probability_calibration)
S3method(predict,probability_threshold)
S3method(print,container)
S3method(print,equivocal_zone)
S3method(print,numeric_calibration)
S3method(print,numeric_range)
S3method(print,predictions_custom)
S3method(print,probability_calibration)
S3method(print,probability_threshold)
S3method(required_pkgs,equivocal_zone)
S3method(required_pkgs,numeric_calibration)
S3method(required_pkgs,numeric_range)
S3method(required_pkgs,predictions_custom)
S3method(required_pkgs,probability_calibration)
S3method(required_pkgs,probability_threshold)
S3method(tunable,equivocal_zone)
S3method(tunable,numeric_calibration)
S3method(tunable,numeric_range)
S3method(tunable,predictions_custom)
S3method(tunable,probability_calibration)
S3method(tunable,probability_threshold)
export("%>%")
export(adjust_equivocal_zone)
export(adjust_numeric_calibration)
export(adjust_numeric_range)
export(adjust_predictions_custom)
export(adjust_probability_calibration)
export(adjust_probability_threshold)
export(container)
export(extract_parameter_dials)
export(extract_parameter_set_dials)
export(fit)
export(required_pkgs)
export(tidy)
export(tunable)
export(tune_args)
import(rlang)
importFrom(cli,cli_abort)
importFrom(cli,cli_inform)
importFrom(cli,cli_warn)
importFrom(dplyr,"%>%")
importFrom(generics,fit)
importFrom(generics,required_pkgs)
importFrom(generics,tidy)
importFrom(generics,tunable)
importFrom(generics,tune_args)
importFrom(hardhat,extract_parameter_dials)
importFrom(hardhat,extract_parameter_set_dials)
importFrom(stats,predict)
118 changes: 118 additions & 0 deletions R/adjust-equivocal-zone.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#' Apply an equivocal zone to a binary classification model.
#'
#' @param x A [container()].
#' @param value A numeric value (between zero and 1/2) or [hardhat::tune()]. The
#' value is the size of the buffer around the threshold.
#' @param threshold A numeric value (between zero and one) or [hardhat::tune()].
#' @examples
#' library(dplyr)
#' library(modeldata)
#'
#' post_obj <-
#' container(mode = "classification") %>%
#' adjust_equivocal_zone(value = 1 / 4)
#'
#'
#' post_res <- fit(
#' post_obj,
#' two_class_example,
#' outcome = c(truth),
#' estimate = c(predicted),
#' probabilities = c(Class1, Class2)
#' )
#'
#' predict(post_res, two_class_example)
#' @export
adjust_equivocal_zone <- function(x, value = 0.1, threshold = 1 / 2) {
check_container(x)
if (!is_tune(value)) {
check_number_decimal(value, min = 0, max = 1 / 2)
}
if (!is_tune(threshold)) {
check_number_decimal(threshold, min = 10^-10, max = 1 - 10^-10)
}

op <-
new_operation(
"equivocal_zone",
inputs = "probability",
outputs = "class",
arguments = list(value = value, threshold = threshold),
results = list(),
trained = FALSE
)

new_container(
mode = x$mode,
type = x$type,
operations = c(x$operations, list(op)),
columns = x$dat,
ptype = x$ptype,
call = current_env()
)
}

#' @export
print.equivocal_zone <- function(x, ...) {
# check for tune() first

if (is_tune(x$arguments$value)) {
cli::cli_bullets(c("*" = "Add equivocal zone of optimized size."))
} else {
trn <- ifelse(x$trained, " [trained]", "")
cli::cli_bullets(c(
"*" = "Add equivocal zone of size
{signif(x$arguments$value, digits = 3)}.{trn}"
))
}
invisible(x)
}

#' @export
fit.equivocal_zone <- function(object, data, container = NULL, ...) {
new_operation(
class(object),
inputs = object$inputs,
outputs = object$outputs,
arguments = object$arguments,
results = list(),
trained = TRUE
)
}

#' @export
predict.equivocal_zone <- function(object, new_data, container, ...) {
est_nm <- container$columns$estimate
prob_nm <- container$columns$probabilities[1]
lvls <- levels(new_data[[est_nm]])
col_syms <- syms(prob_nm[1])
cls_pred <- probably::make_two_class_pred(
new_data[[prob_nm]],
levels = lvls,
buffer = object$arguments$value,
threshold = object$arguments$threshold
)
new_data[[est_nm]] <- cls_pred # todo convert to factor?
new_data
}

#' @export
required_pkgs.equivocal_zone <- function(x, ...) {
c("container", "probably")
}

#' @export
tunable.equivocal_zone <- function(x, ...) {
tibble::new_tibble(list(
name = "buffer",
call_info = list(list(pkg = "dials", fun = "buffer")),
source = "container",
component = "equivocal_zone",
component_id = "equivocal_zone"
))
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
99 changes: 99 additions & 0 deletions R/adjust-numeric-calibration.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#' Re-calibrate numeric predictions
#'
#' @param x A [container()].
#' @param calibrator A pre-trained calibration method from the \pkg{probably}
#' package, such as [probably::cal_estimate_linear()].
#' @examples
#' library(modeldata)
#' library(probably)
#' library(tibble)
#'
#' # create example data
#' set.seed(1)
#' dat <- tibble(y = rnorm(100), y_pred = y/2 + rnorm(100))
#'
#' dat
#'
#' # calibrate numeric predictions
#' reg_cal <- cal_estimate_linear(dat, truth = y, estimate = y_pred)
#'
#' # specify calibration
#' reg_ctr <-
#' container(mode = "regression") %>%
#' adjust_numeric_calibration(reg_cal)
#'
#' # "train" container
#' reg_ctr_trained <- fit(reg_ctr, dat, outcome = y, estimate = y_pred)
#'
#' predict(reg_ctr, dat)
#' @export
adjust_numeric_calibration <- function(x, calibrator) {
check_container(x)
check_required(calibrator)
if (!inherits(calibrator, "cal_regression")) {
cli_abort(
"{.arg calibrator} should be a \\
{.help [<cal_regression> object](probably::cal_estimate_linear)}, \\
not {.obj_type_friendly {calibrator}}."
)
}

op <-
new_operation(
"numeric_calibration",
inputs = "numeric",
outputs = "numeric",
arguments = list(calibrator = calibrator),
results = list(),
trained = FALSE
)

new_container(
mode = x$mode,
type = x$type,
operations = c(x$operations, list(op)),
columns = x$dat,
ptype = x$ptype,
call = current_env()
)
}

#' @export
print.numeric_calibration <- function(x, ...) {
trn <- ifelse(x$trained, " [trained]", "")
cli::cli_bullets(c("*" = "Re-calibrate numeric predictions.{trn}"))
invisible(x)
}

#' @export
fit.numeric_calibration <- function(object, data, container = NULL, ...) {
new_operation(
class(object),
inputs = object$inputs,
outputs = object$outputs,
arguments = object$arguments,
results = list(),
trained = TRUE
)
}

#' @export
predict.numeric_calibration <- function(object, new_data, container, ...) {
probably::cal_apply(new_data, object$argument$calibrator)
}

# todo probably needs required_pkgs methods for cal objects
#' @export
required_pkgs.numeric_calibration <- function(x, ...) {
c("container", "probably")
}

#' @export
tunable.numeric_calibration <- function(x, ...) {
no_param
}

# todo missing methods:
# todo tune_args
# todo tidy
# todo extract_parameter_set_dials
Loading

0 comments on commit 3874410

Please sign in to comment.