Skip to content

Commit

Permalink
Merge pull request #1376 from dajmcdon/strings2factors
Browse files Browse the repository at this point in the history
Strings2factors
  • Loading branch information
EmilHvitfeldt authored Oct 1, 2024
2 parents 3b12331 + 85b39ff commit 104d30d
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 32 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

* `recipe()`, `prep()`, and `bake()` now work with sparse matrices. (#1364, #1368, #1369)

* `prep.recipe(..., strings_as_factors = TRUE)` now only converts string variables that have role "predictor" or "outcome". (@dajmcdon, #1358, #1376)

# recipes 1.1.0

## Improvements
Expand Down
48 changes: 31 additions & 17 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ get_rhs_vars <- function(formula, data, no_lhs = FALSE) {
## `inline_check` stops when in-line functions are used.

outcomes_names <- all.names(
rlang::f_lhs(formula),
functions = FALSE,
rlang::f_lhs(formula),
functions = FALSE,
unique = TRUE
)

Expand All @@ -33,7 +33,7 @@ get_rhs_vars <- function(formula, data, no_lhs = FALSE) {
functions = FALSE,
unique = TRUE
)

if (any(predictors_names == ".")) {
predictors_names <- predictors_names[predictors_names != "."]
predictors_names <- c(predictors_names, colnames(data))
Expand Down Expand Up @@ -182,6 +182,20 @@ has_lvls <- function(info) {
!vapply(info, function(x) all(is.na(x$values)), c(logic = TRUE))
}

# Blank out the recorded factor levels of variables that are neither
# predictors nor outcomes.
#
# `lvls` is a named list with one entry per column; `var_info` must provide
# `variable` and `role` columns. A variable keeps its entry in `lvls` when at
# least one of its roles is "outcome" or "predictor". Every other variable
# listed in `var_info` has its entry replaced by the "no levels" placeholder
# `list(values = NA, ordered = NA)`. The (possibly modified) list is returned.
kill_levels <- function(lvls, var_info) {
  has_std_role <- var_info$role %in% c("outcome", "predictor")
  std_vars <- var_info$variable[has_std_role]
  no_levels <- list(values = NA, ordered = NA)

  # setdiff() already de-duplicates, so a variable that appears under several
  # non-standard roles is only visited once.
  for (nm in setdiff(var_info$variable, std_vars)) {
    lvls[[nm]] <- no_levels
  }
  lvls
}


strings2factors <- function(x, info) {
check_lvls <- has_lvls(info)
if (!any(check_lvls)) {
Expand Down Expand Up @@ -279,10 +293,10 @@ merge_term_info <- function(.new, .old) {
#' supported by all steps.
#'
#' @param ... Arguments pass in from a call to `step`.
#'
#' @return `ellipse_check()`: If not empty, a list of quosures. If empty, an
#'
#' @return `ellipse_check()`: If not empty, a list of quosures. If empty, an
#' error is thrown.
#'
#'
#' @keywords internal
#' @rdname recipes-internal
#' @export
Expand Down Expand Up @@ -311,9 +325,9 @@ ellipse_check <- function(...) {
#' recipe (e.g. `terms` in most steps).
#' @param trained A logical for whether the step has been trained.
#' @param width An integer denoting where the output should be wrapped.
#'
#'
#' @return `printer()`: `NULL`, invisibly.
#'
#'
#' @keywords internal
#' @rdname recipes-internal
#' @export
Expand Down Expand Up @@ -502,16 +516,16 @@ check_type <- function(dat, quant = TRUE, types = NULL, call = caller_env()) {
## Support functions

#' Check to see if a step or check has been trained
#'
#'
#' `is_trained()` is a helper function that returns a single logical to
#' indicate whether a recipe is trained or not.
#'
#'
#' @param x a step object.
#' @return `is_trained()`: A single logical.
#'
#'
#' @seealso [developer_functions]
#' @keywords internal
#'
#'
#' @rdname recipes-internal
#' @export
is_trained <- function(x) {
Expand All @@ -521,14 +535,14 @@ is_trained <- function(x) {

#' Convert Selectors to Character
#'
#' `sel2char()` takes a list of selectors (e.g. `terms` in most steps) and
#' `sel2char()` takes a list of selectors (e.g. `terms` in most steps) and
#' returns a character vector version for printing.
#'
#'
#' @param x A list of selectors
#' @return `sel2char()`: A character vector.
#'
#'
#' @seealso [developer_functions]
#'
#'
#' @keywords internal
#' @rdname recipes-internal
#' @export
Expand Down Expand Up @@ -967,7 +981,7 @@ recipes_remove_cols <- function(new_data, object, col_names = character()) {
#' This helper function is meant to be used in `prep()` methods to identify
#' predictors and outcomes by names.
#'
#' @param info data.frame with variable information of columns.
#' @param info data.frame with variable information of columns.
#'
#' @return Character vector of column names.
#' @keywords internal
Expand Down
22 changes: 12 additions & 10 deletions R/recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ recipe.default <- function(x, ...) {
#' Dots are allowed as are simple multivariate outcome terms (i.e. no need for
#' `cbind`; see Examples). A model formula may not be the best choice for
#' high-dimensional data with many columns, because of problems with memory.
#' @param x,data A data frame, tibble, or sparse matrix from the `Matrix`
#' @param x,data A data frame, tibble, or sparse matrix from the `Matrix`
#' package of the *template* data set. See [sparse_data] for more information
#' about use of sparse data.
#' (see below).
Expand Down Expand Up @@ -321,8 +321,8 @@ prep <- function(x, ...) {
#' For a recipe with at least one preprocessing operation, estimate the required
#' parameters from a training set that can be later applied to other data
#' sets.
#' @param training A data frame, tibble, or sparse matrix from the `Matrix`
#' package, that will be used to estimate parameters for preprocessing. See
#' @param training A data frame, tibble, or sparse matrix from the `Matrix`
#' package, that will be used to estimate parameters for preprocessing. See
#' [sparse_data] for more information about use of sparse data.
#' @param fresh A logical indicating whether already trained operation should be
#' re-trained. If `TRUE`, you should pass in a data set to the argument
Expand All @@ -339,9 +339,10 @@ prep <- function(x, ...) {
#' the final recipe size large. When `verbose = TRUE`, a message is written
#' with the approximate object size in memory but may be an underestimate
#' since it does not take environments into account.
#' @param strings_as_factors A logical: should character columns be converted to
#' factors? This affects the preprocessed training set (when
#' `retain = TRUE`) as well as the results of `bake.recipe`.
#' @param strings_as_factors A logical: should character columns that have role
#' "predictor" or "outcome" be converted to factors? This affects the
#' preprocessed training set (when `retain = TRUE`) as well as the results of
#' `bake.recipe`.
#' @return A recipe whose step objects have been updated with the required
#' quantities (e.g. parameter estimates, model objects, etc). Also, the
#' `term_info` object is likely to be modified as the operations are
Expand Down Expand Up @@ -403,9 +404,9 @@ prep.recipe <-

# Record the original levels for later checking
orig_lvls <- lapply(training, get_levels)

if (strings_as_factors) {
lvls <- lapply(training, get_levels)
lvls <- kill_levels(lvls, x$var_info)
training <- strings2factors(training, lvls)
} else {
lvls <- NULL
Expand Down Expand Up @@ -545,6 +546,7 @@ prep.recipe <-
## The steps may have changed the data so reassess the levels
if (strings_as_factors) {
lvls <- lapply(training, get_levels)
lvls <- kill_levels(lvls, x$term_info)
check_lvls <- has_lvls(lvls)
if (!any(check_lvls)) lvls <- NULL
} else {
Expand Down Expand Up @@ -604,10 +606,10 @@ bake <- function(object, ...) {
#' [prep()], apply the computations to new data.
#' @param object A trained object such as a [recipe()] with at least
#' one preprocessing operation.
#' @param new_data A data frame, tibble, or sparse matrix from the `Matrix`
#' package to which the preprocessing will be applied. If `NULL` is given to
#' @param new_data A data frame, tibble, or sparse matrix from the `Matrix`
#' package to which the preprocessing will be applied. If `NULL` is given to
#' `new_data`, the pre-processed _training data_ will be returned (assuming
#' that `prep(retain = TRUE)` was used). See [sparse_data] for more
#' that `prep(retain = TRUE)` was used). See [sparse_data] for more
#' information about use of sparse data.
#' @param ... One or more selector functions to choose which variables will be
#' returned by the function. See [selections()] for more details.
Expand Down
7 changes: 3 additions & 4 deletions man/roles.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/_snaps/cut.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Error in `step_cut()`:
Caused by error in `prep()`:
x All columns selected for the step should be double or integer.
* 1 factor variable found: `cat_var`
* 1 string variable found: `cat_var`

---

Expand Down
3 changes: 3 additions & 0 deletions tests/testthat/test-dummy_multi_choice.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ test_that("factor levels are preserved", {
# Infrastructure ---------------------------------------------------------------

test_that("bake method errors when needed non-standard role columns are missing", {
# lang_1 is not converted automatically because it has a non-standard role
# but it is used like a factor variable. See also `?step_string2factor`
languages <- languages %>% mutate(lang_1 = factor(lang_1))
rec <- recipe(~., data = languages) %>%
step_dummy_multi_choice(lang_1, lang_2, lang_3) %>%
update_role(lang_1, new_role = "potato") %>%
Expand Down
54 changes: 54 additions & 0 deletions tests/testthat/test-misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,57 @@ test_that("validate_training_data errors are thrown", {
prep(mtcars)
)
})

test_that("vars without role in predictor/outcome avoid string processing", {
  # One integer column and four character columns.
  dat <- tibble(
    real_pred = 1:5,
    chr_pred_and_lime = letters[1:5],
    chr_outcome = letters[1:5],
    chr_only_lemon = letters[1:5],
    chr_only_lime = letters[1:5]
  )

  # Hand-built variable info: one role per column ...
  var_info <- full_join(
    get_types(dat),
    tibble(variable = names(dat), source = "original"),
    by = "variable"
  )
  var_info$role <- c("predictor", "predictor", "outcome", "lemon", "lime")
  # ... plus a second, non-standard role for `chr_pred_and_lime`.
  extra_role <- var_info[2, ]
  extra_role$role <- "lime"
  var_info <- add_row(var_info, extra_role)

  orig_lvls <- lapply(dat, get_levels)
  # The converted data is not inspected further in this test.
  training <- strings2factors(dat, orig_lvls)

  # Expected level records: placeholder for "no levels" and the record
  # shared by all four character columns.
  no_lvls <- list(values = NA, ordered = NA)
  chr_lvls <- list(values = letters[1:5], ordered = FALSE, factor = FALSE)

  # Before roles are taken into account, every character column has levels.
  original_expectation <- setNames(
    c(FALSE, TRUE, TRUE, TRUE, TRUE),
    names(dat)
  )
  expect_identical(has_lvls(orig_lvls), original_expectation)
  expect_identical(orig_lvls[["real_pred"]], no_lvls)
  expect_identical(orig_lvls[["chr_pred_and_lime"]], chr_lvls)
  expect_identical(orig_lvls[["chr_outcome"]], chr_lvls)
  # levels are still recorded here, i.e. gets converted to fctr
  expect_identical(orig_lvls[["chr_only_lemon"]], chr_lvls)
  # levels are still recorded here, i.e. gets converted to fctr
  expect_identical(orig_lvls[["chr_only_lime"]], chr_lvls)

  # After kill_levels(), only the lemon/lime-only columns lose their levels.
  new_lvls <- kill_levels(orig_lvls, var_info)
  new_expect <- original_expectation
  new_expect[c("chr_only_lemon", "chr_only_lime")] <- FALSE
  expect_identical(has_lvls(new_lvls), new_expect)
  expect_identical(new_lvls[["real_pred"]], orig_lvls[["real_pred"]])
  # chr predictor gets converted, despite also having another role
  expect_identical(
    new_lvls[["chr_pred_and_lime"]],
    orig_lvls[["chr_pred_and_lime"]]
  )
  expect_identical(new_lvls[["chr_outcome"]], orig_lvls[["chr_outcome"]])
  # non-predictor / non-outcome var remains chr, we don't log the levels
  expect_identical(new_lvls[["chr_only_lemon"]], no_lvls)
  expect_identical(new_lvls[["chr_only_lime"]], no_lvls)
})

0 comments on commit 104d30d

Please sign in to comment.