From de3d54d5c41da2540bcc80a47ac0b823b86b9923 Mon Sep 17 00:00:00 2001 From: Christophe Regouby Date: Sun, 20 Jun 2021 13:00:21 +0200 Subject: [PATCH 1/6] add more tests around config$loss shift left the config$loss_fn resolution --- R/model.R | 42 ++++++++------------- tests/testthat/test-hardhat.R | 70 +++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 27 deletions(-) diff --git a/R/model.R b/R/model.R index 309028d7..c7fde4ca 100644 --- a/R/model.R +++ b/R/model.R @@ -140,11 +140,26 @@ tabnet_config <- function(batch_size = 256, if (is.null(decision_width)) decision_width <- attention_width + # resolve loss + if (config$loss == "auto") { + if (data$y$dtype == torch::torch_long()) + config$loss <- "cross_entropy" + else + config$loss <- "mse" + } + + if (config$loss == "mse") + config$loss_fn <- torch::nn_mse_loss() + else if (config$loss %in% c("bce", "cross_entropy")) + config$loss_fn <- torch::nn_cross_entropy_loss() + + list( batch_size = batch_size, lambda_sparse = penalty, clip_value = clip_value, loss = loss, + loss_fn = loss_fn, epochs = epochs, drop_last = drop_last, n_d = decision_width, @@ -262,20 +277,6 @@ tabnet_initialize <- function(x, y, config = tabnet_config()) { # training data data <- resolve_data(x, y) - # resolve loss - if (config$loss == "auto") { - if (data$y$dtype == torch::torch_long()) - config$loss <- "cross_entropy" - else - config$loss <- "mse" - } - - if (config$loss == "mse") - config$loss_fn <- torch::nn_mse_loss() - else if (config$loss %in% c("bce", "cross_entropy")) - config$loss_fn <- torch::nn_cross_entropy_loss() - - # create network network <- tabnet_nn( input_dim = data$input_dim, @@ -364,19 +365,6 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s ) } - if (config$loss == "auto") { - if (data$y$dtype == torch::torch_long()) - config$loss <- "cross_entropy" - else - config$loss <- "mse" - } - - # resolve loss - if (config$loss == "mse") { - config$loss_fn <- 
torch::nn_mse_loss() - } else if (config$loss %in% c("bce", "cross_entropy")) { - config$loss_fn <- torch::nn_cross_entropy_loss() - } # restore network from model and send it to device network <- obj$fit$network diff --git a/tests/testthat/test-hardhat.R b/tests/testthat/test-hardhat.R index fd3386eb..d0111f80 100644 --- a/tests/testthat/test-hardhat.R +++ b/tests/testthat/test-hardhat.R @@ -356,6 +356,7 @@ test_that("fit works with entmax mask-type", { ) }) + test_that("fit raise an error with non-supported mask-type", { library(recipes) @@ -371,3 +372,72 @@ test_that("fit raise an error with non-supported mask-type", { ) }) + +test_that("config$loss=`auto` adapt to recipe outcome str()", { + + library(recipes) + data("attrition", package = "modeldata") + ids <- sample(nrow(attrition), 256) + + # nominal outcome + rec <- recipe(EnvironmentSatisfaction ~ ., data = attrition[ids, ]) %>% + step_normalize(all_numeric(), -all_outcomes()) + fit_auto <- tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, + config = tabnet_config( loss="auto")) + expect_equal(fit_auto$fit$config$loss_fn, torch::nn_cross_entropy_loss()) + + # numerical outcome + rec <- recipe(MonthlyIncome ~ ., data = attrition[ids, ]) %>% + step_normalize(all_numeric(), -all_outcomes()) + fit_auto <- tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, + config = tabnet_config( loss="auto")) + expect_equal(fit_auto$fit$config$loss_fn, torch::nn_mse_loss()) + +}) + +test_that("config$loss not adapted to recipe outcome raise an explicit error", { + + library(recipes) + data("attrition", package = "modeldata") + ids <- sample(nrow(attrition), 256) + + # nominal outcome with numerical loss + rec <- recipe(EnvironmentSatisfaction ~ ., data = attrition[ids, ]) %>% + step_normalize(all_numeric(), -all_outcomes()) + fit_auto <- tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, + config = tabnet_config( loss="mse")) + expect_equal(fit_auto$fit$config$loss_fn, torch::nn_mse_loss()) + + # numerical 
outcome + rec <- recipe(MonthlyIncome ~ ., data = attrition[ids, ]) %>% + step_normalize(all_numeric(), -all_outcomes()) + fit_auto <- tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, + config = tabnet_config( loss="cross_entropy")) + expect_equal(fit_auto$fit$config$loss_fn, torch::nn_cross_entropy_loss()) + +}) + + +test_that("config$loss can be a function", { + + library(recipes) + data("attrition", package = "modeldata") + ids <- sample(nrow(attrition), 256) + + # nominal outcome loss + rec <- recipe(EnvironmentSatisfaction ~ ., data = attrition[ids, ]) %>% + step_normalize(all_numeric(), -all_outcomes()) + fit_auto <- tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, + config = tabnet_config( loss=torch::nn_kl_div_loss())) + expect_equal(fit_auto$fit$config$loss_fn, torch::nn_kl_div_loss()) + + # numerical outcome loss + rec <- recipe(MonthlyIncome ~ ., data = attrition[ids, ]) %>% + step_normalize(all_numeric(), -all_outcomes()) + fit_auto <- tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, + config = tabnet_config( loss=torch::nn_smooth_l1_loss())) + expect_equal(fit_auto$fit$config$loss_fn, torch::nn_smooth_l1_loss()) + +}) + + From a8b1c8a84c45d015198d3eca2ea5460162761e89 Mon Sep 17 00:00:00 2001 From: Christophe Regouby Date: Sun, 20 Jun 2021 13:22:17 +0200 Subject: [PATCH 2/6] move resolve_loss as a function add consistency check between loss and outcome dtype allow loss as function --- R/model.R | 42 +++++++++++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/R/model.R b/R/model.R index c7fde4ca..999c703a 100644 --- a/R/model.R +++ b/R/model.R @@ -140,26 +140,11 @@ tabnet_config <- function(batch_size = 256, if (is.null(decision_width)) decision_width <- attention_width - # resolve loss - if (config$loss == "auto") { - if (data$y$dtype == torch::torch_long()) - config$loss <- "cross_entropy" - else - config$loss <- "mse" - } - - if (config$loss == "mse") - config$loss_fn <- 
torch::nn_mse_loss() - else if (config$loss %in% c("bce", "cross_entropy")) - config$loss_fn <- torch::nn_cross_entropy_loss() - - list( batch_size = batch_size, lambda_sparse = penalty, clip_value = clip_value, loss = loss, - loss_fn = loss_fn, epochs = epochs, drop_last = drop_last, n_d = decision_width, @@ -186,6 +171,27 @@ tabnet_config <- function(batch_size = 256, ) } +resolve_loss <- function(loss, dtype) { + if (loss == "auto") { + if (dtype == torch::torch_long()) + loss <- "cross_entropy" + else + loss <- "mse" + } + + if (is.function(loss)) + loss_fn <- loss + else if (loss == "mse" && !dtype == torch::torch_long()) + loss_fn <- torch::nn_mse_loss() + else if (loss %in% c("bce", "cross_entropy") && dtype == torch::torch_long()) + loss_fn <- torch::nn_cross_entropy_loss() + else + rlang::abort(paste0(loss," is not available for outcome of type ",dtype)) + + loss_fn +} + + batch_to_device <- function(batch, device) { batch <- list(x = batch$x, y = batch$y) lapply(batch, function(x) { @@ -277,6 +283,9 @@ tabnet_initialize <- function(x, y, config = tabnet_config()) { # training data data <- resolve_data(x, y) + # resolve loss + config$loss_fn <- resolve_loss(config$loss, data$y$dtype) + # create network network <- tabnet_nn( input_dim = data$input_dim, @@ -365,6 +374,9 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s ) } + # resolve loss + config$loss_fn <- resolve_loss(config$loss, data$y$dtype) + # restore network from model and send it to device network <- obj$fit$network From d289ca01f2fecd71596cfcabc065a4c6495e8f69 Mon Sep 17 00:00:00 2001 From: Christophe Regouby Date: Sun, 20 Jun 2021 13:46:23 +0200 Subject: [PATCH 3/6] get rid of enumerate(dl) --- R/model.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/model.R b/R/model.R index 999c703a..e9d2afa9 100644 --- a/R/model.R +++ b/R/model.R @@ -440,10 +440,10 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s 
network$eval() if (has_valid) { - for (batch in torch::enumerate(valid_dl)) { + coro::loop(for (batch in valid_dl) { m <- valid_batch(network, batch_to_device(batch, device), config) valid_metrics <- c(valid_metrics, m) - } + }) metrics[[epoch]][["valid"]] <- transpose_metrics(valid_metrics) } From f3f2f7719b5d2374b3ac66d20f8b3bc94bf4b5ae Mon Sep 17 00:00:00 2001 From: Christophe Regouby Date: Sun, 20 Jun 2021 13:59:28 +0200 Subject: [PATCH 4/6] fix the test to loss "auto" to not interfere with function() adapt the tests expectations --- R/model.R | 13 +++---------- tests/testthat/test-hardhat.R | 24 ++++++++++++------------ 2 files changed, 15 insertions(+), 22 deletions(-) diff --git a/R/model.R b/R/model.R index e9d2afa9..1e887825 100644 --- a/R/model.R +++ b/R/model.R @@ -172,21 +172,14 @@ tabnet_config <- function(batch_size = 256, } resolve_loss <- function(loss, dtype) { - if (loss == "auto") { - if (dtype == torch::torch_long()) - loss <- "cross_entropy" - else - loss <- "mse" - } - if (is.function(loss)) loss_fn <- loss - else if (loss == "mse" && !dtype == torch::torch_long()) + else if (loss %in% c("mse", "auto") && !dtype == torch::torch_long()) loss_fn <- torch::nn_mse_loss() - else if (loss %in% c("bce", "cross_entropy") && dtype == torch::torch_long()) + else if (loss %in% c("bce", "cross_entropy", "auto") && dtype == torch::torch_long()) loss_fn <- torch::nn_cross_entropy_loss() else - rlang::abort(paste0(loss," is not available for outcome of type ",dtype)) + rlang::abort(paste0(loss," is not a valid loss for outcome of type ",dtype)) loss_fn } diff --git a/tests/testthat/test-hardhat.R b/tests/testthat/test-hardhat.R index d0111f80..31abc04c 100644 --- a/tests/testthat/test-hardhat.R +++ b/tests/testthat/test-hardhat.R @@ -404,17 +404,17 @@ test_that("config$loss not adapted to recipe outcome raise an explicit error", { # nominal outcome with numerical loss rec <- recipe(EnvironmentSatisfaction ~ ., data = attrition[ids, ]) %>% 
step_normalize(all_numeric(), -all_outcomes()) - fit_auto <- tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, - config = tabnet_config( loss="mse")) - expect_equal(fit_auto$fit$config$loss_fn, torch::nn_mse_loss()) - + expect_error(tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, + config = tabnet_config( loss="mse")), + regexp = "is not a valid loss for outcome of type" + ) # numerical outcome rec <- recipe(MonthlyIncome ~ ., data = attrition[ids, ]) %>% step_normalize(all_numeric(), -all_outcomes()) - fit_auto <- tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, - config = tabnet_config( loss="cross_entropy")) - expect_equal(fit_auto$fit$config$loss_fn, torch::nn_cross_entropy_loss()) - + expect_error(tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, + config = tabnet_config( loss="cross_entropy")), + regexp = "is not a valid loss for outcome of type" + ) }) @@ -428,15 +428,15 @@ test_that("config$loss can be a function", { rec <- recipe(EnvironmentSatisfaction ~ ., data = attrition[ids, ]) %>% step_normalize(all_numeric(), -all_outcomes()) fit_auto <- tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, - config = tabnet_config( loss=torch::nn_kl_div_loss())) - expect_equal(fit_auto$fit$config$loss_fn, torch::nn_kl_div_loss()) + config = tabnet_config( loss=torch::nn_nll_loss())) + expect_equivalent(fit_auto$fit$config$loss_fn, torch::nn_nll_loss()) # numerical outcome loss rec <- recipe(MonthlyIncome ~ ., data = attrition[ids, ]) %>% step_normalize(all_numeric(), -all_outcomes()) fit_auto <- tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE, - config = tabnet_config( loss=torch::nn_smooth_l1_loss())) - expect_equal(fit_auto$fit$config$loss_fn, torch::nn_smooth_l1_loss()) + config = tabnet_config( loss=torch::nn_poisson_nll_loss())) + expect_equivalent(fit_auto$fit$config$loss_fn, torch::nn_poisson_nll_loss()) }) From 0dcb563bc027f31db39655ba9946fee044524d8f Mon Sep 17 00:00:00 2001 From: Christophe Regouby Date: Sun, 20 Jun 2021 
14:52:31 +0200 Subject: [PATCH 5/6] add loss as a function PR --- NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEWS.md b/NEWS.md index 243e95ee..225c35a9 100644 --- a/NEWS.md +++ b/NEWS.md @@ -9,6 +9,7 @@ * Add autoplot() of model loss among epochs (@cregouby, #36) * Added a `config` argument to fit/pretrain functions so one can pass a pre-made config list. (#42) * Add `mask_type` configuration option with `entmax` additional to `sparsemax` (@cmcmaster1, #48) +* Allow `tabnet_config(loss=)` to be a function (@cregouby, #55) ## Bugfixes From 1353063ea726e31d5fcbd17631808fdda5794c82 Mon Sep 17 00:00:00 2001 From: Christophe Regouby Date: Sun, 20 Jun 2021 17:15:43 +0200 Subject: [PATCH 6/6] improve formatting of NEWS --- NEWS.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/NEWS.md b/NEWS.md index 225c35a9..33c6f39f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,14 +2,14 @@ ## New features -* Allow model fine-tuning through passing a pre-trained model to tabnet_fit() (@cregouby, #26) +* Allow model fine-tuning through passing a pre-trained model to `tabnet_fit()` (@cregouby, #26) * Explicit error in case of missing values (@cregouby, #24) -* Better handling of larger datasets when running `tabnet_explain`. -* Add tabnet_pretrain() for unsupervised pretraining (@cregouby, #29) -* Add autoplot() of model loss among epochs (@cregouby, #36) -* Added a `config` argument to fit/pretrain functions so one can pass a pre-made config list. (#42) -* Add `mask_type` configuration option with `entmax` additional to `sparsemax` (@cmcmaster1, #48) -* Allow `tabnet_config(loss=)` to be a function (@cregouby, #55) +* Better handling of larger datasets when running `tabnet_explain()`. +* Add `tabnet_pretrain()` for unsupervised pretraining (@cregouby, #29) +* Add `autoplot()` of model loss among epochs (@cregouby, #36) +* Added a `config` argument to `fit() / pretrain()` so one can pass a pre-made config list. 
(#42) +* In `tabnet_config()`, new `mask_type` option with `entmax` additional to default `sparsemax` (@cmcmaster1, #48) +* In `tabnet_config()`, `loss` now also takes function (@cregouby, #55) ## Bugfixes