Skip to content

Commit

Permalink
Merge pull request #55 from mlverse/bugfix/#35_other_losses
Browse files Browse the repository at this point in the history
Bugfix/#35 other losses
  • Loading branch information
dfalbel authored Jun 22, 2021
2 parents b4b6765 + 1353063 commit 2d93f60
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 32 deletions.
13 changes: 7 additions & 6 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

## New features

* Allow model fine-tuning through passing a pre-trained model to tabnet_fit() (@cregouby, #26)
* Allow model fine-tuning through passing a pre-trained model to `tabnet_fit()` (@cregouby, #26)
* Explicit error in case of missing values (@cregouby, #24)
* Better handling of larger datasets when running `tabnet_explain`.
* Add tabnet_pretrain() for unsupervised pretraining (@cregouby, #29)
* Add autoplot() of model loss among epochs (@cregouby, #36)
* Added a `config` argument to fit/pretrain functions so one can pass a pre-made config list. (#42)
* Add `mask_type` configuration option with `entmax` additional to `sparsemax` (@cmcmaster1, #48)
* Better handling of larger datasets when running `tabnet_explain()`.
* Add `tabnet_pretrain()` for unsupervised pretraining (@cregouby, #29)
* Add `autoplot()` of model loss among epochs (@cregouby, #36)
* Added a `config` argument to `fit() / pretrain()` so one can pass a pre-made config list. (#42)
* In `tabnet_config()`, new `mask_type` option with `entmax` additional to default `sparsemax` (@cmcmaster1, #48)
* In `tabnet_config()`, `loss` now also takes function (@cregouby, #55)

## Bugfixes

Expand Down
45 changes: 19 additions & 26 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,20 @@ tabnet_config <- function(batch_size = 256,
)
}

resolve_loss <- function(loss, dtype) {
if (is.function(loss))
loss_fn <- loss
else if (loss %in% c("mse", "auto") && !dtype == torch::torch_long())
loss_fn <- torch::nn_mse_loss()
else if (loss %in% c("bce", "cross_entropy", "auto") && dtype == torch::torch_long())
loss_fn <- torch::nn_cross_entropy_loss()
else
rlang::abort(paste0(loss," is not a valid loss for outcome of type ",dtype))

loss_fn
}


batch_to_device <- function(batch, device) {
batch <- list(x = batch$x, y = batch$y)
lapply(batch, function(x) {
Expand Down Expand Up @@ -263,18 +277,7 @@ tabnet_initialize <- function(x, y, config = tabnet_config()) {
data <- resolve_data(x, y)

# resolve loss
if (config$loss == "auto") {
if (data$y$dtype == torch::torch_long())
config$loss <- "cross_entropy"
else
config$loss <- "mse"
}

if (config$loss == "mse")
config$loss_fn <- torch::nn_mse_loss()
else if (config$loss %in% c("bce", "cross_entropy"))
config$loss_fn <- torch::nn_cross_entropy_loss()

config$loss_fn <- resolve_loss(config$loss, data$y$dtype)

# create network
network <- tabnet_nn(
Expand Down Expand Up @@ -364,19 +367,9 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s
)
}

if (config$loss == "auto") {
if (data$y$dtype == torch::torch_long())
config$loss <- "cross_entropy"
else
config$loss <- "mse"
}

# resolve loss
if (config$loss == "mse") {
config$loss_fn <- torch::nn_mse_loss()
} else if (config$loss %in% c("bce", "cross_entropy")) {
config$loss_fn <- torch::nn_cross_entropy_loss()
}
config$loss_fn <- resolve_loss(config$loss, data$y$dtype)

# restore network from model and send it to device
network <- obj$fit$network

Expand Down Expand Up @@ -440,10 +433,10 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s

network$eval()
if (has_valid) {
for (batch in torch::enumerate(valid_dl)) {
coro::loop(for (batch in valid_dl) {
m <- valid_batch(network, batch_to_device(batch, device), config)
valid_metrics <- c(valid_metrics, m)
}
})
metrics[[epoch]][["valid"]] <- transpose_metrics(valid_metrics)
}

Expand Down
70 changes: 70 additions & 0 deletions tests/testthat/test-hardhat.R
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ test_that("fit works with entmax mask-type", {
)

})

test_that("fit raise an error with non-supported mask-type", {

library(recipes)
Expand All @@ -371,3 +372,72 @@ test_that("fit raise an error with non-supported mask-type", {
)

})

test_that("config$loss=`auto` adapt to recipe outcome str()", {

library(recipes)
data("attrition", package = "modeldata")
ids <- sample(nrow(attrition), 256)

# nominal outcome
rec <- recipe(EnvironmentSatisfaction ~ ., data = attrition[ids, ]) %>%
step_normalize(all_numeric(), -all_outcomes())
fit_auto <- tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE,
config = tabnet_config( loss="auto"))
expect_equal(fit_auto$fit$config$loss_fn, torch::nn_cross_entropy_loss())

# numerical outcome
rec <- recipe(MonthlyIncome ~ ., data = attrition[ids, ]) %>%
step_normalize(all_numeric(), -all_outcomes())
fit_auto <- tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE,
config = tabnet_config( loss="auto"))
expect_equal(fit_auto$fit$config$loss_fn, torch::nn_mse_loss())

})

test_that("config$loss not adapted to recipe outcome raise an explicit error", {

library(recipes)
data("attrition", package = "modeldata")
ids <- sample(nrow(attrition), 256)

# nominal outcome with numerical loss
rec <- recipe(EnvironmentSatisfaction ~ ., data = attrition[ids, ]) %>%
step_normalize(all_numeric(), -all_outcomes())
expect_error(tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE,
config = tabnet_config( loss="mse")),
regexp = "is not a valid loss for outcome of type"
)
# numerical outcome
rec <- recipe(MonthlyIncome ~ ., data = attrition[ids, ]) %>%
step_normalize(all_numeric(), -all_outcomes())
expect_error(tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE,
config = tabnet_config( loss="cross_entropy")),
regexp = "is not a valid loss for outcome of type"
)
})


test_that("config$loss can be a function", {

library(recipes)
data("attrition", package = "modeldata")
ids <- sample(nrow(attrition), 256)

# nominal outcome loss
rec <- recipe(EnvironmentSatisfaction ~ ., data = attrition[ids, ]) %>%
step_normalize(all_numeric(), -all_outcomes())
fit_auto <- tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE,
config = tabnet_config( loss=torch::nn_nll_loss()))
expect_equivalent(fit_auto$fit$config$loss_fn, torch::nn_nll_loss())

# numerical outcome loss
rec <- recipe(MonthlyIncome ~ ., data = attrition[ids, ]) %>%
step_normalize(all_numeric(), -all_outcomes())
fit_auto <- tabnet_fit(rec, attrition, epochs = 1, verbose = TRUE,
config = tabnet_config( loss=torch::nn_poisson_nll_loss()))
expect_equivalent(fit_auto$fit$config$loss_fn, torch::nn_poisson_nll_loss())

})


0 comments on commit 2d93f60

Please sign in to comment.