Skip to content

Commit

Permalink
Handle fitting failures on individual slices
Browse files Browse the repository at this point in the history
  • Loading branch information
jpfitzinger committed Jan 21, 2023
1 parent e763a32 commit e721b80
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion R/fit_groups.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
test_samples <- rsample::complement(splits)
res_row$model_object[[1]]$set_args(weights = wts[train_samples])
res_row$model_object[[1]]$fit(df_train)
if (nrow(df_test) > 0 & !row$return_slices) {
if (nrow(df_test) > 0 & !row$return_slices & !is.null(res_row$model_object[[1]]$object)) {
pred <- predict.tidyfit.models(res_row, df_test, .keep_grid_id = TRUE)
metrics <- .eval_metrics(pred, res_row$model_object[[1]]$mode, weights = wts[test_samples])
} else {
Expand Down
14 changes: 7 additions & 7 deletions R/post_process.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
.mask, .weights, gr_vars) {
if (!.return_slices & .cv == "none") {
df <- df %>%
dplyr::select(-.data$slice_id)
dplyr::select(-"slice_id")
}

if (.cv != "none") {
# Select optimal hyperparameter setting
df_no_cv <- df %>%
dplyr::filter(!sapply(df$model_object, function(mod) mod$cv)) %>%
dplyr::select(-.data$slice_id)
dplyr::select(-"slice_id")

df <- df %>%
dplyr::filter(sapply(df$model_object, function(mod) mod$cv))
Expand All @@ -29,24 +29,24 @@
if (!all(is.na(df_slices$metric))) {
df_slices <- df_slices %>%
dplyr::group_by(.data$model, .data$grid_id, .add = TRUE) %>%
dplyr::mutate(metric = mean(.data$metric)) %>%
dplyr::mutate(metric = mean(.data$metric, na.rm = TRUE)) %>%
dplyr::ungroup(.data$grid_id) %>%
dplyr::filter(.data$metric == min(.data$metric)) %>%
dplyr::filter(.data$metric == min(.data$metric, na.rm = TRUE)) %>%
dplyr::filter(.data$grid_id == unique(.data$grid_id)[1])
}

if (.return_slices) {
df <- df_slices %>%
dplyr::bind_rows(df_no_cv) %>%
dplyr::select(-.data$metric)
dplyr::select(-"metric")
} else {
df <- df_slices %>%
dplyr::ungroup() %>%
dplyr::select(!!gr_vars, .data$grid_id, .data$model) %>%
dplyr::select(!!gr_vars, "grid_id", "model") %>%
dplyr::distinct() %>%
dplyr::left_join(df %>% dplyr::ungroup() %>% dplyr::filter(.data$slice_id == "FULL"), by = c(gr_vars, "grid_id", "model")) %>%
dplyr::bind_rows(df_no_cv) %>%
dplyr::select(-.data$metric, -.data$slice_id)
dplyr::select(-"metric", -"slice_id")
}
}
}
Expand Down

0 comments on commit e721b80

Please sign in to comment.