Skip to content

Commit

Permalink
first implementation of InterpreStability (WiP)
Browse files Browse the repository at this point in the history
  • Loading branch information
cregouby committed Dec 21, 2023
1 parent 5e84ab6 commit fb431c5
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 14 deletions.
27 changes: 25 additions & 2 deletions R/explain.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,34 @@ explain_impl <- function(network, x, x_na_mask, with_stability = FALSE) {
c(M_explain_emb_dim, masks_emb_dim) %<-% network$forward_masks(x, x_na_mask)

if (with_stability) {
# TODO
# Compute InterpreStability value through 5-group, the lazy way
# define 5 sampled id groups
obs_group <- lapply(1:5, FUN = sample.int, n = nrow(x), size = ceiling(nrow(x)/5))
x_group <- map(obs_group, ~x[.x,] )
x_na_mask_group <- map(obs_group, ~x_na_mask[.x,] )

M_explain_mask_lst <- map2(x_group, x_na_mask_group, network$forward_masks)
M_explain <- map(M_explain_mask_lst, ~tabnet:::sum_embedding_masks(
mask = .x[[1]],
input_dim = network$input_dim,
cat_idx = network$cat_idxs,
cat_emb_dim = network$cat_emb_dim
)
)
m <- map(M_explain, ~as.numeric(as.matrix(.x$sum(dim = 1)$detach()$to(device = "cpu"))))
compute_feature_importance <- map(m, ~.x/sum(.x))

corr_mat <- torch::torch_stack(compute_feature_importance, dim = 1) |> as.matrix() |> t() |> cor(method = "pearson")
corr_mat <- corr_mat * (torch::torch_ones_like(corr_mat) - torch::torch_eye(n = dim(corr_mat)[1] ))
interprestability <- (corr_mat[corr_mat >=0.9]$sum() +
.8 * corr_mat[corr_mat < 0.9 & corr_mat >= 0.7]$sum() +
.6 * corr_mat[corr_mat < 0.7 & corr_mat >= 0.5]$sum() +
.4 * corr_mat[corr_mat < 0.5 & corr_mat >= 0.3]$sum() +
.2 * corr_mat[corr_mat < 0.3]$sum()) / (corr_mat$shape[1] * (corr_mat$shape[2] - 1))

} else {
interprestability <- NULL
}
# summarize the categorical embedding into 1 column
# per variable
Expand All @@ -124,13 +145,15 @@ explain_impl <- function(network, x, x_na_mask, with_stability = FALSE) {
cat_emb_dim = network$cat_emb_dim
)

list(M_explain = M_explain$to(device="cpu"), masks = to_device(masks, "cpu"))
list(M_explain = M_explain$to(device="cpu"),
masks = to_device(masks, "cpu"),
interprestability = interprestability)
}

# Compute normalized feature importances from the explanation masks.
#
# Calls `explain_impl(network, x, x_na_mask)`, sums the resulting `M_explain`
# tensor with `dim = 1`, and rescales the resulting numeric vector by its total.
#
# Arguments:
#   network, x, x_na_mask: forwarded unchanged to `explain_impl()`.
#
# Returns a list with two elements:
#   importance        - the summed mask values divided by their total.
#   interprestability - the `interprestability` element returned by
#                       `explain_impl()`, passed through as-is.
#
# NOTE(review): `with_stability` is not passed to `explain_impl()`, so its
# default applies; with the `with_stability = FALSE` default shown in its
# signature, `interprestability` would be NULL here - confirm this is intended.
compute_feature_importance <- function(network, x, x_na_mask) {
out <- explain_impl(network, x, x_na_mask)
# Sum M_explain with dim = 1, detach it, move it to the CPU and flatten it
# into a plain numeric vector (presumably one value per feature - confirm).
m <- as.numeric(as.matrix(out$M_explain$sum(dim = 1)$detach()$to(device = "cpu")))
# NOTE(review): the value of this bare expression is neither assigned nor
# returned (the list below is the function's value); it looks like the
# pre-change line of the diff.
m/sum(m)
# Normalize by the total and bundle with the stability score.
list(importance = m/sum(m), interprestability = out$interprestability)
}

# sum embeddings, taking their sizes into account.
Expand Down
14 changes: 9 additions & 5 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -676,22 +676,26 @@ tabnet_train_supervised <- function(obj, x, y, config = interpretabnet_config(),
1, train_ds$.length(), min(importance_sample_size, train_ds$.length()),
dtype = torch::torch_long()
))
importances <- tibble::tibble(
variables = colnames(x),
importance = compute_feature_importance(
cfi <- compute_feature_importance(
network,
train_ds$.getbatch(batch =indexes)$x$to(device = "cpu"),
train_ds$.getbatch(batch =indexes)$x_na_mask$to(device = "cpu"))
train_ds$.getbatch(batch =indexes)$x_na_mask$to(device = "cpu")
)
importances <- tibble::tibble(
variables = colnames(x),
importance = cfi$importance
)
} else {
importances <- NULL
cfi <- list(interprestability = NULL)
}
list(
network = network,
metrics = metrics,
config = config,
checkpoints = checkpoints,
importances = importances
importances = importances,
interprestability = cfi$interprestability
)
}

Expand Down
14 changes: 8 additions & 6 deletions R/pretraining.R
Original file line number Diff line number Diff line change
Expand Up @@ -247,20 +247,22 @@ tabnet_train_unsupervised <- function(x, config = tabnet_config(), epoch_shift =
1, train_ds$.length(), min(importance_sample_size, train_ds$.length()),
dtype = torch::torch_long()
))
cfi <- compute_feature_importance(
network,
train_ds$.getbatch(batch =indexes)$x$to(device = "cpu"),
train_ds$.getbatch(batch =indexes)$x_na_mask$to(device = "cpu")
)
importances <- tibble::tibble(
variables = colnames(x),
importance = compute_feature_importance(
network,
train_ds$.getbatch(batch =indexes)$x$to(device = "cpu"),
train_ds$.getbatch(batch =indexes)$x_na_mask$to(device = "cpu")
)
importance = cfi$importance
)

list(
network = network,
metrics = metrics,
config = config,
checkpoints = checkpoints,
importances = importances
importances = importances,
interprestability = cfi$interprestability
)
}
2 changes: 1 addition & 1 deletion tests/testthat/test-explain.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ test_that("explain provides correct result with data.frame", {

ex <- tabnet_explain(fit, x)

expect_length(ex, 2)
expect_length(ex, 3)
expect_length(ex[[2]], 1)
expect_equal(nrow(ex[[1]]), nrow(x))
expect_equal(nrow(ex[[2]][[1]]), nrow(x))
Expand Down

0 comments on commit fb431c5

Please sign in to comment.