Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/tabnet model #124

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ S3method(handler_predict,default)
S3method(handler_predict,glm)
S3method(handler_predict,lm)
S3method(handler_predict,ranger)
S3method(handler_predict,tabnet_fit)
S3method(handler_predict,train)
S3method(handler_predict,workflow)
S3method(handler_predict,xgb.Booster)
S3method(handler_startup,Learner)
S3method(handler_startup,default)
S3method(handler_startup,ranger)
S3method(handler_startup,tabnet_fit)
S3method(handler_startup,train)
S3method(handler_startup,workflow)
S3method(handler_startup,xgb.Booster)
Expand All @@ -32,12 +34,14 @@ S3method(vetiver_create_description,default)
S3method(vetiver_create_description,glm)
S3method(vetiver_create_description,lm)
S3method(vetiver_create_description,ranger)
S3method(vetiver_create_description,tabnet_fit)
S3method(vetiver_create_description,train)
S3method(vetiver_create_description,workflow)
S3method(vetiver_create_description,xgb.Booster)
S3method(vetiver_create_meta,Learner)
S3method(vetiver_create_meta,default)
S3method(vetiver_create_meta,ranger)
S3method(vetiver_create_meta,tabnet_fit)
S3method(vetiver_create_meta,train)
S3method(vetiver_create_meta,workflow)
S3method(vetiver_create_meta,xgb.Booster)
Expand All @@ -46,13 +50,15 @@ S3method(vetiver_prepare_model,default)
S3method(vetiver_prepare_model,glm)
S3method(vetiver_prepare_model,lm)
S3method(vetiver_prepare_model,ranger)
S3method(vetiver_prepare_model,tabnet_fit)
S3method(vetiver_prepare_model,train)
S3method(vetiver_prepare_model,workflow)
S3method(vetiver_ptype,Learner)
S3method(vetiver_ptype,default)
S3method(vetiver_ptype,glm)
S3method(vetiver_ptype,lm)
S3method(vetiver_ptype,ranger)
S3method(vetiver_ptype,tabnet_fit)
S3method(vetiver_ptype,train)
S3method(vetiver_ptype,workflow)
S3method(vetiver_ptype,xgb.Booster)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# vetiver (development version)

* Added support for tabnet models (#124 @cregouby)

* Trailing slashes are now removed from `vetiver_endpoint()` (#134).

# vetiver 0.1.7
Expand Down
54 changes: 54 additions & 0 deletions R/tabnet.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#' @rdname vetiver_create_description
#' @export
vetiver_create_description.tabnet_fit <- function(model) {
paste0(
"A tabnet `nn_module` containing ",
format(sum(sapply(model$fit$network$parameters, function(x) prod(x$shape))), nsmall = 0, big.mark = ",", scientific = FALSE),
" parameters."
)
}

#' @rdname vetiver_create_meta
#' @export
vetiver_create_meta.tabnet_fit <- function(model, metadata) {
vetiver_meta(metadata, required_pkgs = "tabnet")
}

#' @rdname vetiver_create_description
#' @export
vetiver_prepare_model.tabnet_fit <- function(model) {
butcher::butcher(model)
}

#' @rdname vetiver_create_ptype
#' @export
vetiver_ptype.tabnet_fit <- function(model, ...) {
rlang::check_dots_used()
dots <- list(...)
check_ptype_data(dots)
ptype <- vctrs::vec_ptype(dots$ptype_data)
tibble::as_tibble(ptype)
}

#' @rdname handler_startup
#' @export
handler_startup.tabnet_fit <- function(vetiver_model) {
attach_pkgs("tabnet")
}

#' @rdname handler_startup
#' @export
handler_predict.tabnet_fit <- function(vetiver_model, ...) {

ptype <- vetiver_model$blueprint$ptypes

function(req) {
new_data <- req$body
if (!is_null(ptype)) {
new_data <- vetiver_type_convert(new_data, ptype)
new_data <- hardhat::scream(new_data, ptype)
}
ret <- predict(vetiver_model, data = new_data, ...)
list(.pred = ret$.pred)
}
}
8 changes: 7 additions & 1 deletion man/handler_startup.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/vetiver-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion man/vetiver_compute_metrics.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion man/vetiver_create_description.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/vetiver_create_meta.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 4 additions & 6 deletions man/vetiver_create_ptype.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 0 additions & 17 deletions man/vetiver_create_rsconnect_bundle.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/vetiver_dashboard.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 0 additions & 25 deletions man/vetiver_deploy_rsconnect.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/vetiver_pr_predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 0 additions & 19 deletions man/vetiver_write_docker.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 41 additions & 0 deletions tests/testthat/_snaps/tabnet.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# can print tabnet model

Code
v
Output

-- cars3 - <tabnet_fit> model for deployment
A tabnet `nn_module` containing 6,301 parameters. using 10 features

# error for no ptype_data with tabnet

Code
vetiver_model(car_tn, "cars3")
Condition
Error in `vetiver_create_description()`:
! object 'car_tn' not found

# create plumber.R for tabnet

Code
cat(readr::read_lines(tmp), sep = "\n")
Output
# Generated by the vetiver package; edit with care

library(pins)
library(plumber)
library(rapidoc)
library(vetiver)

# Packages needed to generate model predictions
if (FALSE) {
library(tabnet)
}
b <- board_folder(path = "<redacted>")
v <- vetiver_pin_read(b, "cars3")

#* @plumber
function(pr) {
pr %>% vetiver_api(v)
}

Loading