Skip to content

Commit

Permalink
units tests for tunable postproc and some work on #194
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo committed Dec 2, 2024
1 parent 6354f79 commit 241771f
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/testthat/helper-tunable.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
check_tunable <- function(x) {
expect_equal(names(x), c("name", "call_info", "source", "component", "component_id"))
expect_equal(class(x$name), "character")
expect_equal(class(x$call_info), "list")
expect_equal(class(x$source), "character")
expect_equal(class(x$component), "character")
expect_equal(class(x$component_id), "character")

for (i in seq_along(x$call_info)) {
check_call_info(x$call_info[[i]])
}

invisible(TRUE)
}

check_call_info <- function(x) {
if (all(is.null(x))) {
# it is possible that engine parameter do not have call info
return(invisible(TRUE))
}
expect_true(all(c("pkg", "fun") %in% names(x)))
expect_equal(class(x$pkg), "character")
expect_equal(class(x$fun), "character")
invisible(TRUE)
}
135 changes: 135 additions & 0 deletions tests/testthat/test-generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,138 @@ test_that("can compute required packages of a workflow - recipes", {

expect_true("pkg" %in% generics::required_pkgs(workflow))
})

# ------------------------------------------------------------------------------
# tunable()

test_that("workflow with no tunable parameters", {
skip_if_not_installed("modeldata")
library(modeldata)
data("Chicago")

rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
recipes::step_rm(date, ends_with("away"))
lm_model <- parsnip::linear_reg() %>% parsnip::set_engine("lm")
wflow_untunable <- workflow(rm_rec, lm_model)

wflow_info <- tunable(wflow_untunable)
check_tunable(wflow_info)
expect_equal(nrow(wflow_info), 0)
})


test_that("extract tuning from workflow with tunable recipe", {
skip_if_not_installed("modeldata")
library(modeldata)
data("Chicago")

spline_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
recipes::step_date(date) %>%
recipes::step_holiday(date) %>%
recipes::step_rm(date, ends_with("away")) %>%
recipes::step_impute_knn(recipes::all_predictors(),
neighbors = hardhat::tune("imputation")) %>%
recipes::step_other(recipes::all_nominal(), threshold = hardhat::tune()) %>%
recipes::step_dummy(recipes::all_nominal()) %>%
recipes::step_normalize(recipes::all_predictors()) %>%
recipes::step_bs(recipes::all_predictors(),
deg_free = hardhat::tune(), degree = hardhat::tune())
lm_model <- parsnip::linear_reg() %>%
parsnip::set_engine("lm")
wflow_tunable_recipe <- workflow(spline_rec, lm_model)

wflow_info <- tunable(wflow_tunable_recipe)
check_tunable(wflow_info)
expect_true(all(wflow_info$source == "recipe"))
})

test_that("extract tuning from workflow with tunable model", {
skip_if_not_installed("modeldata")
library(modeldata)
data("Chicago")

rm_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
recipes::step_rm(date, ends_with("away"))
bst_model <-
parsnip::boost_tree(mode = "classification", trees = hardhat::tune("funky name \n")) %>%
parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE)
wflow_tunable_model <- workflow(rm_rec, bst_model)

wflow_info <- tunable(wflow_tunable_model)
check_tunable(wflow_info)
expect_true(all(wflow_info$source == "model_spec"))
})

test_that("extract tuning from workflow with tunable postprocessor", {
wflow <- workflow()
wflow <- add_recipe(wflow, recipes::recipe(mpg ~ ., mtcars))
wflow <- add_model(wflow, parsnip::linear_reg())
wflow <- add_tailor(
wflow,
tailor::tailor() %>%
tailor::adjust_numeric_range(lower_limit = hardhat::tune())
)

wflow_info <- tunable(wflow)

check_tunable(wflow_info)
expect_true(all(wflow_info$source == "tailor"))
})

test_that("extract tuning from workflow with tunable recipe and model", {
skip_if_not_installed("modeldata")
library(modeldata)
data("Chicago")

spline_rec <- recipes::recipe(ridership ~ ., data = head(Chicago)) %>%
recipes::step_date(date) %>%
recipes::step_holiday(date) %>%
recipes::step_rm(date, ends_with("away")) %>%
recipes::step_impute_knn(recipes::all_predictors(),
neighbors = hardhat::tune("imputation")) %>%
recipes::step_other(recipes::all_nominal(), threshold = hardhat::tune()) %>%
recipes::step_dummy(recipes::all_nominal()) %>%
recipes::step_normalize(recipes::all_predictors()) %>%
recipes::step_bs(recipes::all_predictors(),
deg_free = hardhat::tune(), degree = hardhat::tune())
bst_model <-
parsnip::boost_tree(mode = "classification", trees = hardhat::tune("funky name \n")) %>%
parsnip::set_engine("C5.0", rules = hardhat::tune(), noGlobalPruning = TRUE)
wflow_tunable <- workflow(spline_rec, bst_model)

wflow_info <- tunable(wflow_tunable)
check_tunable(wflow_info)
expect_equal(
sort(unique(wflow_info$source)),
c("model_spec", "recipe")
)
})

test_that("extract tuning from workflow with tunable recipe, model, and tailor", {
wflow <- workflow()
wflow <- add_recipe(
wflow,
recipes::recipe(mpg ~ ., mtcars) %>%
recipes::step_impute_knn(
recipes::all_predictors(),
neighbors = hardhat::tune("imputation")
)
)
wflow <- add_model(
wflow,
parsnip::linear_reg(engine = "glmnet", penalty = tune())
)
wflow <- add_tailor(
wflow,
tailor::tailor() %>%
tailor::adjust_numeric_range(lower_limit = hardhat::tune())
)

wflow_info <- tunable(wflow)

check_tunable(wflow_info)
expect_equal(
sort(unique(wflow_info$source)),
c("model_spec", "recipe", "tailor")
)
})

0 comments on commit 241771f

Please sign in to comment.