diff --git a/NEWS.md b/NEWS.md index 243e95ee..33c6f39f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,13 +2,14 @@ ## New features -* Allow model fine-tuning through passing a pre-trained model to tabnet_fit() (@cregouby, #26) +* Allow model fine-tuning through passing a pre-trained model to `tabnet_fit()` (@cregouby, #26) * Explicit error in case of missing values (@cregouby, #24) -* Better handling of larger datasets when running `tabnet_explain`. -* Add tabnet_pretrain() for unsupervised pretraining (@cregouby, #29) -* Add autoplot() of model loss among epochs (@cregouby, #36) -* Added a `config` argument to fit/pretrain functions so one can pass a pre-made config list. (#42) -* Add `mask_type` configuration option with `entmax` additional to `sparsemax` (@cmcmaster1, #48) +* Better handling of larger datasets when running `tabnet_explain()`. +* Add `tabnet_pretrain()` for unsupervised pretraining (@cregouby, #29) +* Add `autoplot()` of model loss across epochs (@cregouby, #36) +* Added a `config` argument to `tabnet_fit()` / `tabnet_pretrain()` so one can pass a pre-made config list. 
(#42) +* In `tabnet_config()`, new `mask_type` option with `entmax` in addition to the default `sparsemax` (@cmcmaster1, #48) +* In `tabnet_config()`, `loss` now also accepts a function (@cregouby, #55) ## Bugfixes diff --git a/R/model.R b/R/model.R index 309028d7..1e887825 100644 --- a/R/model.R +++ b/R/model.R @@ -171,6 +171,20 @@ tabnet_config <- function(batch_size = 256, ) } +resolve_loss <- function(loss, dtype) { + if (is.function(loss)) + loss_fn <- loss + else if (loss %in% c("mse", "auto") && !dtype == torch::torch_long()) + loss_fn <- torch::nn_mse_loss() + else if (loss %in% c("bce", "cross_entropy", "auto") && dtype == torch::torch_long()) + loss_fn <- torch::nn_cross_entropy_loss() + else + rlang::abort(paste0(loss," is not a valid loss for outcome of type ",dtype)) + + loss_fn +} + + batch_to_device <- function(batch, device) { batch <- list(x = batch$x, y = batch$y) lapply(batch, function(x) { @@ -263,18 +277,7 @@ tabnet_initialize <- function(x, y, config = tabnet_config()) { data <- resolve_data(x, y) # resolve loss - if (config$loss == "auto") { - if (data$y$dtype == torch::torch_long()) - config$loss <- "cross_entropy" - else - config$loss <- "mse" - } - - if (config$loss == "mse") - config$loss_fn <- torch::nn_mse_loss() - else if (config$loss %in% c("bce", "cross_entropy")) - config$loss_fn <- torch::nn_cross_entropy_loss() - + config$loss_fn <- resolve_loss(config$loss, data$y$dtype) # create network network <- tabnet_nn( @@ -364,19 +367,9 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s ) } - if (config$loss == "auto") { - if (data$y$dtype == torch::torch_long()) - config$loss <- "cross_entropy" - else - config$loss <- "mse" - } - # resolve loss - if (config$loss == "mse") { - config$loss_fn <- torch::nn_mse_loss() - } else if (config$loss %in% c("bce", "cross_entropy")) { - config$loss_fn <- torch::nn_cross_entropy_loss() - } + config$loss_fn <- resolve_loss(config$loss, data$y$dtype) + # restore network from model 
and send it to device network <- obj$fit$network @@ -440,10 +433,10 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s network$eval() if (has_valid) { - for (batch in torch::enumerate(valid_dl)) { + coro::loop(for (batch in valid_dl) { m <- valid_batch(network, batch_to_device(batch, device), config) valid_metrics <- c(valid_metrics, m) - } + }) metrics[[epoch]][["valid"]] <- transpose_metrics(valid_metrics) } diff --git a/tests/testthat/test-hardhat.R b/tests/testthat/test-hardhat.R index fd3386eb..31abc04c 100644 --- a/tests/testthat/test-hardhat.R +++ b/tests/testthat/test-hardhat.R @@ -356,6 +356,7 @@ test_that("fit works with entmax mask-type", { ) }) + test_that("fit raise an error with non-supported mask-type", { library(recipes) @@ -371,3 +372,72 @@ test_that("fit raise an error with non-supported mask-type", { ) }) + +test_that("config$loss=`auto` adapt to recipe outcome str()", { + + library(recipes) + data("attrition", package = "modeldata") + ids <- sample(nrow(attrition), 256) + + # nominal outcome + rec <- recipe(EnvironmentSatisfaction ~ ., data = attrition[ids, ]) %>% + step_normalize(all_numeric(), -all_outcomes()) + fit_auto <- tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, + config = tabnet_config( loss="auto")) + expect_equal(fit_auto$fit$config$loss_fn, torch::nn_cross_entropy_loss()) + + # numerical outcome + rec <- recipe(MonthlyIncome ~ ., data = attrition[ids, ]) %>% + step_normalize(all_numeric(), -all_outcomes()) + fit_auto <- tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, + config = tabnet_config( loss="auto")) + expect_equal(fit_auto$fit$config$loss_fn, torch::nn_mse_loss()) + +}) + +test_that("config$loss not adapted to recipe outcome raise an explicit error", { + + library(recipes) + data("attrition", package = "modeldata") + ids <- sample(nrow(attrition), 256) + + # nominal outcome with numerical loss + rec <- recipe(EnvironmentSatisfaction ~ ., data = attrition[ids, ]) %>% + 
step_normalize(all_numeric(), -all_outcomes()) + expect_error(tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, + config = tabnet_config( loss="mse")), + regexp = "is not a valid loss for outcome of type" + ) + # numerical outcome + rec <- recipe(MonthlyIncome ~ ., data = attrition[ids, ]) %>% + step_normalize(all_numeric(), -all_outcomes()) + expect_error(tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, + config = tabnet_config( loss="cross_entropy")), + regexp = "is not a valid loss for outcome of type" + ) +}) + + +test_that("config$loss can be a function", { + + library(recipes) + data("attrition", package = "modeldata") + ids <- sample(nrow(attrition), 256) + + # nominal outcome loss + rec <- recipe(EnvironmentSatisfaction ~ ., data = attrition[ids, ]) %>% + step_normalize(all_numeric(), -all_outcomes()) + fit_auto <- tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, + config = tabnet_config( loss=torch::nn_nll_loss())) + expect_equivalent(fit_auto$fit$config$loss_fn, torch::nn_nll_loss()) + + # numerical outcome loss + rec <- recipe(MonthlyIncome ~ ., data = attrition[ids, ]) %>% + step_normalize(all_numeric(), -all_outcomes()) + fit_auto <- tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, + config = tabnet_config( loss=torch::nn_poisson_nll_loss())) + expect_equivalent(fit_auto$fit$config$loss_fn, torch::nn_poisson_nll_loss()) + +}) + +