diff --git a/NAMESPACE b/NAMESPACE index d6b6b37f..973e19dc 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -11,12 +11,14 @@ S3method(handler_predict,default) S3method(handler_predict,glm) S3method(handler_predict,lm) S3method(handler_predict,ranger) +S3method(handler_predict,tabnet_fit) S3method(handler_predict,train) S3method(handler_predict,workflow) S3method(handler_predict,xgb.Booster) S3method(handler_startup,Learner) S3method(handler_startup,default) S3method(handler_startup,ranger) +S3method(handler_startup,tabnet_fit) S3method(handler_startup,train) S3method(handler_startup,workflow) S3method(handler_startup,xgb.Booster) @@ -32,12 +34,14 @@ S3method(vetiver_create_description,default) S3method(vetiver_create_description,glm) S3method(vetiver_create_description,lm) S3method(vetiver_create_description,ranger) +S3method(vetiver_create_description,tabnet_fit) S3method(vetiver_create_description,train) S3method(vetiver_create_description,workflow) S3method(vetiver_create_description,xgb.Booster) S3method(vetiver_create_meta,Learner) S3method(vetiver_create_meta,default) S3method(vetiver_create_meta,ranger) +S3method(vetiver_create_meta,tabnet_fit) S3method(vetiver_create_meta,train) S3method(vetiver_create_meta,workflow) S3method(vetiver_create_meta,xgb.Booster) @@ -46,6 +50,7 @@ S3method(vetiver_prepare_model,default) S3method(vetiver_prepare_model,glm) S3method(vetiver_prepare_model,lm) S3method(vetiver_prepare_model,ranger) +S3method(vetiver_prepare_model,tabnet_fit) S3method(vetiver_prepare_model,train) S3method(vetiver_prepare_model,workflow) S3method(vetiver_ptype,Learner) @@ -53,6 +58,7 @@ S3method(vetiver_ptype,default) S3method(vetiver_ptype,glm) S3method(vetiver_ptype,lm) S3method(vetiver_ptype,ranger) +S3method(vetiver_ptype,tabnet_fit) S3method(vetiver_ptype,train) S3method(vetiver_ptype,workflow) S3method(vetiver_ptype,xgb.Booster) diff --git a/NEWS.md b/NEWS.md index 7e1b4a73..3a214578 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # vetiver (development version) +* Added support for tabnet models (#124 @cregouby) + * Trailing slashes are now removed from `vetiver_endpoint()` (#134). # vetiver 0.1.7 diff --git a/R/tabnet.R b/R/tabnet.R new file mode 100644 index 00000000..5f398d78 --- /dev/null +++ b/R/tabnet.R @@ -0,0 +1,54 @@ +#' @rdname vetiver_create_description +#' @export +vetiver_create_description.tabnet_fit <- function(model) { + paste0( + "A tabnet `nn_module` containing ", + format(sum(sapply(model$fit$network$parameters, function(x) prod(x$shape))), nsmall = 0, big.mark = ",", scientific = FALSE), + " parameters." + ) +} + +#' @rdname vetiver_create_meta +#' @export +vetiver_create_meta.tabnet_fit <- function(model, metadata) { + vetiver_meta(metadata, required_pkgs = "tabnet") +} + +#' @rdname vetiver_create_description +#' @export +vetiver_prepare_model.tabnet_fit <- function(model) { + butcher::butcher(model) +} + +#' @rdname vetiver_create_ptype +#' @export +vetiver_ptype.tabnet_fit <- function(model, ...) { + rlang::check_dots_used() + dots <- list(...) + check_ptype_data(dots) + ptype <- vctrs::vec_ptype(dots$ptype_data) + tibble::as_tibble(ptype) +} + +#' @rdname handler_startup +#' @export +handler_startup.tabnet_fit <- function(vetiver_model) { + attach_pkgs("tabnet") +} + +#' @rdname handler_startup +#' @export +handler_predict.tabnet_fit <- function(vetiver_model, ...) { + + ptype <- vetiver_model$blueprint$ptypes + + function(req) { + new_data <- req$body + if (!is_null(ptype)) { + new_data <- vetiver_type_convert(new_data, ptype) + new_data <- hardhat::scream(new_data, ptype) + } + ret <- predict(vetiver_model, data = new_data, ...) + list(.pred = ret$.pred) + } +} diff --git a/man/handler_startup.Rd b/man/handler_startup.Rd index 64ba6b9d..83bd98e3 100644 --- a/man/handler_startup.Rd +++ b/man/handler_startup.Rd @@ -1,6 +1,6 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/caret.R, R/glm.R, R/handlers.R, R/lm.R, -% R/mlr3.R, R/ranger.R, R/tidymodels.R, R/xgboost.R +% R/mlr3.R, R/ranger.R, R/tabnet.R, R/tidymodels.R, R/xgboost.R \name{handler_startup.train} \alias{handler_startup.train} \alias{handler_predict.train} @@ -14,6 +14,8 @@ \alias{handler_predict.Learner} \alias{handler_startup.ranger} \alias{handler_predict.ranger} +\alias{handler_startup.tabnet_fit} +\alias{handler_predict.tabnet_fit} \alias{handler_startup.workflow} \alias{handler_predict.workflow} \alias{handler_startup.xgb.Booster} @@ -44,6 +46,10 @@ handler_predict(vetiver_model, ...) \method{handler_predict}{ranger}(vetiver_model, ...) +\method{handler_startup}{tabnet_fit}(vetiver_model) + +\method{handler_predict}{tabnet_fit}(vetiver_model, ...) + \method{handler_startup}{workflow}(vetiver_model) \method{handler_predict}{workflow}(vetiver_model, ...) diff --git a/man/vetiver-package.Rd b/man/vetiver-package.Rd index 030ae4d6..f4314bb7 100644 --- a/man/vetiver-package.Rd +++ b/man/vetiver-package.Rd @@ -6,7 +6,7 @@ \alias{vetiver-package} \title{vetiver: Version, Share, Deploy, and Monitor Models} \description{ -\if{html}{\figure{logo.png}{options: style='float: right' alt='logo' width='120'}} +\if{html}{\figure{logo.png}{options: align='right' alt='logo' width='120'}} The goal of 'vetiver' is to provide fluent tooling to version, share, deploy, and monitor a trained model. Functions handle both recording and checking the model's input data prototype, and predicting from a remote API endpoint. The 'vetiver' package is extensible, with generics that can support many kinds of models. } diff --git a/man/vetiver_compute_metrics.Rd b/man/vetiver_compute_metrics.Rd index dd07eda6..1cb7926d 100644 --- a/man/vetiver_compute_metrics.Rd +++ b/man/vetiver_compute_metrics.Rd @@ -71,7 +71,15 @@ epoch time of \verb{1970-01-01 00:00:00}, \emph{in the time zone of the index}. This is generally used to define the anchor time to count from, which is relevant when the every value is \verb{> 1}.} -\item{before, after}{\verb{[integer(1) / Inf]} +\item{before}{\verb{[integer(1) / Inf]} + +The number of values before or after the current element to +include in the sliding window. Set to \code{Inf} to select all elements +before or after the current element. Negative values are allowed, which +allows you to "look forward" from the current element if used as the +\code{.before} value, or "look backwards" if used as \code{.after}.} + +\item{after}{\verb{[integer(1) / Inf]} The number of values before or after the current element to include in the sliding window. Set to \code{Inf} to select all elements diff --git a/man/vetiver_create_description.Rd b/man/vetiver_create_description.Rd index 051a8566..17af36de 100644 --- a/man/vetiver_create_description.Rd +++ b/man/vetiver_create_description.Rd @@ -1,6 +1,6 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/caret.R, R/glm.R, R/lm.R, R/mlr3.R, -% R/prepare.R, R/ranger.R, R/tidymodels.R, R/xgboost.R +% R/prepare.R, R/ranger.R, R/tabnet.R, R/tidymodels.R, R/xgboost.R \name{vetiver_create_description.train} \alias{vetiver_create_description.train} \alias{vetiver_prepare_model.train} @@ -16,6 +16,8 @@ \alias{vetiver_prepare_model.default} \alias{vetiver_create_description.ranger} \alias{vetiver_prepare_model.ranger} +\alias{vetiver_create_description.tabnet_fit} +\alias{vetiver_prepare_model.tabnet_fit} \alias{vetiver_create_description.workflow} \alias{vetiver_prepare_model.workflow} \alias{vetiver_create_description.xgb.Booster} @@ -49,6 +51,10 @@ vetiver_prepare_model(model) \method{vetiver_prepare_model}{ranger}(model) +\method{vetiver_create_description}{tabnet_fit}(model) + +\method{vetiver_prepare_model}{tabnet_fit}(model) + \method{vetiver_create_description}{workflow}(model) \method{vetiver_prepare_model}{workflow}(model) diff --git a/man/vetiver_create_meta.Rd b/man/vetiver_create_meta.Rd index 7af45fa6..43b1299a 100644 --- a/man/vetiver_create_meta.Rd +++ b/man/vetiver_create_meta.Rd @@ -1,6 +1,6 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/caret.R, R/meta.R, R/mlr3.R, R/ranger.R, -% R/tidymodels.R, R/xgboost.R +% R/tabnet.R, R/tidymodels.R, R/xgboost.R \name{vetiver_create_meta.train} \alias{vetiver_create_meta.train} \alias{vetiver_meta} @@ -8,6 +8,7 @@ \alias{vetiver_create_meta.default} \alias{vetiver_create_meta.Learner} \alias{vetiver_create_meta.ranger} +\alias{vetiver_create_meta.tabnet_fit} \alias{vetiver_create_meta.workflow} \alias{vetiver_create_meta.xgb.Booster} \title{Metadata constructors for \code{vetiver_model()} object} @@ -24,6 +25,8 @@ vetiver_create_meta(model, metadata) \method{vetiver_create_meta}{ranger}(model, metadata) +\method{vetiver_create_meta}{tabnet_fit}(model, metadata) + \method{vetiver_create_meta}{workflow}(model, metadata) \method{vetiver_create_meta}{xgb.Booster}(model, metadata) diff --git a/man/vetiver_create_ptype.Rd b/man/vetiver_create_ptype.Rd index f1f43191..51dd2326 100644 --- a/man/vetiver_create_ptype.Rd +++ b/man/vetiver_create_ptype.Rd @@ -1,6 +1,6 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/caret.R, R/glm.R, R/lm.R, R/mlr3.R, R/ptype.R, -% R/ranger.R, R/tidymodels.R, R/xgboost.R +% R/ranger.R, R/tabnet.R, R/tidymodels.R, R/xgboost.R \name{vetiver_ptype.train} \alias{vetiver_ptype.train} \alias{vetiver_ptype.glm} @@ -10,6 +10,7 @@ \alias{vetiver_ptype.default} \alias{vetiver_create_ptype} \alias{vetiver_ptype.ranger} +\alias{vetiver_ptype.tabnet_fit} \alias{vetiver_ptype.workflow} \alias{vetiver_ptype.xgb.Booster} \title{Create a vetiver input data prototype} @@ -30,6 +31,8 @@ vetiver_create_ptype(model, save_ptype, ...) \method{vetiver_ptype}{ranger}(model, ...) +\method{vetiver_ptype}{tabnet_fit}(model, ...) + \method{vetiver_ptype}{workflow}(model, ...) \method{vetiver_ptype}{xgb.Booster}(model, ...) @@ -80,9 +83,4 @@ vetiver_ptype(cars_lm) ## can also turn off `ptype` vetiver_create_ptype(cars_lm, FALSE) -\dontshow{if (rlang::is_installed("ranger")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} -## some models require that you pass in training features -cars_rf <- ranger::ranger(mpg ~ ., data = mtcars) -vetiver_ptype(cars_rf, ptype_data = mtcars[,-1]) -\dontshow{\}) # examplesIf} } diff --git a/man/vetiver_create_rsconnect_bundle.Rd b/man/vetiver_create_rsconnect_bundle.Rd index be440534..acd107d8 100644 --- a/man/vetiver_create_rsconnect_bundle.Rd +++ b/man/vetiver_create_rsconnect_bundle.Rd @@ -46,23 +46,6 @@ The two functions \code{vetiver_create_rsconnect_bundle()} and \code{\link[=vetiver_deploy_rsconnect]{vetiver_deploy_rsconnect()}} are alternatives to each other, providing different strategies for deploying a vetiver model API to RStudio Connect. } -\examples{ -\dontshow{if (rlang::is_installed("connectapi")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} -library(pins) -b <- board_temp(versioned = TRUE) -cars_lm <- lm(mpg ~ ., data = mtcars) -v <- vetiver_model(cars_lm, "cars_linear") -vetiver_pin_write(b, v) - -## when you pin to RStudio Connect, your pin name will be typically be like: -## "user.name/cars_linear" -vetiver_create_rsconnect_bundle( - b, - "cars_linear", - predict_args = list(debug = TRUE) -) -\dontshow{\}) # examplesIf} -} \seealso{ \code{\link[=vetiver_write_plumber]{vetiver_write_plumber()}}, \code{\link[=vetiver_deploy_rsconnect]{vetiver_deploy_rsconnect()}} } diff --git a/man/vetiver_dashboard.Rd b/man/vetiver_dashboard.Rd index 7f66a962..8dcf25e8 100644 --- a/man/vetiver_dashboard.Rd +++ b/man/vetiver_dashboard.Rd @@ -41,7 +41,7 @@ helper function \code{pin_example_kc_housing_model()} to set up demonstration model and metrics pins needed for the monitoring demo. This function will: \itemize{ \item fit an example model to training data -\item pin the vetiver model to your own \code{\link[pins:board_folder]{pins::board_local()}} +\item pin the vetiver model to your own \code{\link[pins:board_local]{pins::board_local()}} \item compute metrics from testing data \item pin these metrics to the same local board } diff --git a/man/vetiver_deploy_rsconnect.Rd b/man/vetiver_deploy_rsconnect.Rd index 457d04d1..110bd2ef 100644 --- a/man/vetiver_deploy_rsconnect.Rd +++ b/man/vetiver_deploy_rsconnect.Rd @@ -44,31 +44,6 @@ The two functions \code{vetiver_deploy_rsconnect()} and \code{\link[=vetiver_create_rsconnect_bundle]{vetiver_create_rsconnect_bundle()}} are alternatives to each other, providing different strategies for deploying a vetiver model API to RStudio Connect. } -\examples{ -\dontshow{if (rlang::is_installed("rsconnect")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} -library(pins) -b <- board_temp(versioned = TRUE) -cars_lm <- lm(mpg ~ ., data = mtcars) -v <- vetiver_model(cars_lm, "cars_linear") -vetiver_pin_write(b, v) - -if (FALSE) { -## pass args for predicting: -vetiver_deploy_rsconnect( - b, - "user.name/cars_linear", - predict_args = list(debug = TRUE) -) - -## specify an account name through `...`: -vetiver_deploy_rsconnect( - b, - "user.name/cars_linear", - account = "user.name" -) -} -\dontshow{\}) # examplesIf} -} \seealso{ \code{\link[=vetiver_write_plumber]{vetiver_write_plumber()}}, \code{\link[=vetiver_create_rsconnect_bundle]{vetiver_create_rsconnect_bundle()}} } diff --git a/man/vetiver_pr_predict.Rd b/man/vetiver_pr_predict.Rd index 071c7b9b..532799b2 100644 --- a/man/vetiver_pr_predict.Rd +++ b/man/vetiver_pr_predict.Rd @@ -24,7 +24,7 @@ vetiver_pr_predict( \item{...}{Other arguments passed to \code{predict()}, such as prediction \code{type}} } \description{ -\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}} +\verb{r lifecycle::badge("deprecated")} This function was deprecated to use \link{vetiver_api} directly instead. } diff --git a/man/vetiver_write_docker.Rd b/man/vetiver_write_docker.Rd index ff5b9ade..e792ddc4 100644 --- a/man/vetiver_write_docker.Rd +++ b/man/vetiver_write_docker.Rd @@ -45,22 +45,3 @@ After creating a Plumber file with \code{\link[=vetiver_write_plumber]{vetiver_w \code{vetiver_write_docker()} to create a Dockerfile plus a \code{vetiver_renv.lock} file for a pinned \code{\link[=vetiver_model]{vetiver_model()}}. } -\examples{ -\dontshow{if (interactive() || identical(Sys.getenv("IN_PKGDOWN"), "true")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} - -library(pins) -tmp_plumber <- tempfile() -b <- board_temp(versioned = TRUE) -cars_lm <- lm(mpg ~ ., data = mtcars) -v <- vetiver_model(cars_lm, "cars_linear") -vetiver_pin_write(b, v) -vetiver_write_plumber(b, "cars_linear", file = tmp_plumber) - -## default port -vetiver_write_docker(v, tmp_plumber, tempdir()) -## port from env variable -vetiver_write_docker(v, tmp_plumber, tempdir(), - port = 'as.numeric(Sys.getenv("PORT"))', - expose = FALSE) -\dontshow{\}) # examplesIf} -} diff --git a/tests/testthat/_snaps/tabnet.md b/tests/testthat/_snaps/tabnet.md new file mode 100644 index 00000000..6933bc3d --- /dev/null +++ b/tests/testthat/_snaps/tabnet.md @@ -0,0 +1,41 @@ +# can print tabnet model + + Code + v + Output + + -- cars3 - model for deployment + A tabnet `nn_module` containing 6,301 parameters. using 10 features + +# error for no ptype_data with tabnet + + Code + vetiver_model(car_tn, "cars3") + Condition + Error in `vetiver_create_description()`: + ! object 'car_tn' not found + +# create plumber.R for tabnet + + Code + cat(readr::read_lines(tmp), sep = "\n") + Output + # Generated by the vetiver package; edit with care + + library(pins) + library(plumber) + library(rapidoc) + library(vetiver) + + # Packages needed to generate model predictions + if (FALSE) { + library(tabnet) + } + b <- board_folder(path = "") + v <- vetiver_pin_read(b, "cars3") + + #* @plumber + function(pr) { + pr %>% vetiver_api(v) + } + diff --git a/tests/testthat/test-tabnet.R b/tests/testthat/test-tabnet.R new file mode 100644 index 00000000..9209bade --- /dev/null +++ b/tests/testthat/test-tabnet.R @@ -0,0 +1,74 @@ +library(pins) +library(plumber) +skip_if_not_installed("tabnet") + +set.seed(321) +cars_tn <- tabnet::tabnet_fit(mpg ~ ., data = mtcars, epoch=30) +v <- vetiver_model(cars_tn, "cars3", ptype_data = mtcars[,-1]) + +test_that("can print tabnet model", { + expect_snapshot(v) +}) + +test_that("error for no ptype_data with tabnet", { + expect_snapshot(vetiver_model(car_tn, "cars3"), error = TRUE) +}) + +test_that("can predict tabnet model", { + preds <- predict(v, mtcars[,-1]) + expect_equal(length(preds$.pred), 32) + expect_equal(mean(preds$.pred), 44, tolerance = 0.1) +}) + + +test_that("can pin an tabnet model", { + b <- board_temp() + vetiver_pin_write(b, v) + pinned <- pin_read(b, "cars3") + expect_equal(pinned$model, butcher::butcher(cars_tn)) + expect_equal( + pinned$ptype, + vctrs::vec_slice(tibble::as_tibble(mtcars[,-1]), 0) + ) + expect_equal( + pinned$required_pkgs, + "tabnet" + ) +}) + +test_that("default endpoint for tabnet", { + p <- pr() %>% vetiver_api(v) + p_routes <- p$routes[-1] + expect_equal(names(p_routes), c("ping", "predict")) + expect_equal(purrr::map_chr(p_routes, "verbs"), + c(ping = "GET", predict = "POST")) +}) + +test_that("default OpenAPI spec", { + v$metadata <- list(url = "potatoes") + p <- pr() %>% vetiver_api(v) + car_spec <- p$getApiSpec() + expect_equal(car_spec$info$description, + "A tabnet `nn_module` containing 6,301 parameters.") + post_spec <- car_spec$paths$`/predict`$post + expect_equal(names(post_spec), c("summary", "requestBody", "responses")) + expect_equal(as.character(post_spec$summary), + "Return predictions from model using 10 features") + get_spec <- car_spec$paths$`/pin-url`$get + expect_equal(as.character(get_spec$summary), + "Get URL of pinned vetiver model") + +}) + +test_that("create plumber.R for tabnet", { + skip_on_cran() + b <- board_folder(path = tmp_dir) + vetiver_pin_write(b, v) + tmp <- tempfile() + vetiver_write_plumber(b, "cars3", file = tmp) + expect_snapshot( + cat(readr::read_lines(tmp), sep = "\n"), + transform = redact_vetiver + ) +}) +