Skip to content

Commit

Permalink
stanfit wrapper working for rstan
Browse files Browse the repository at this point in the history
  • Loading branch information
santikka committed Jul 11, 2024
1 parent 41f5ab8 commit 58543f1
Show file tree
Hide file tree
Showing 11 changed files with 186 additions and 29 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ Config/testthat/edition: 3
Encoding: UTF-8
Roxygen: list(markdown = TRUE, roclets = c ("namespace", "rd",
"srr::srr_stats_roclet"))
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
LazyData: true
LazyDataCompression: xz
Additional_repositories: https://mc-stan.org/r-packages/
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# dynamite 1.5.4

* Model fitting using `cmdstanr` backend no longer relies on `rstan::read_stan_csv()` to construct the fit object. Instead, the resulting `CmdStanMCMC` object is used directly. This should provide a substantial performance improvement in some instances.

# dynamite 1.5.3

* Restored and updated the main package vignette. The vignette now also contains a real data example and information on multiple imputation.
Expand Down
13 changes: 4 additions & 9 deletions R/as_data_table.R
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,11 @@ as.data.table.dynamitefit <- function(x, keep.rownames = FALSE,
""
)
if (type %in% c("xi", "corr_nu", "corr_psi")) {
draws <- rstan::extract(
x$stanfit,
pars = type,
permuted = FALSE
)
draws <- get_draws(x$stanfit, pars = type)
} else {
draws <- rstan::extract(
draws <- get_draws(
x$stanfit,
pars = paste0(type, "_", response, ycat),
permuted = FALSE
pars = paste0(type, "_", response, ycat)
)
}
idx <- which(names(x$stan$responses) %in% response)
Expand Down Expand Up @@ -269,7 +264,7 @@ as.data.table.dynamitefit <- function(x, keep.rownames = FALSE,
any(
grepl(
paste0("^", y["parameter"], "$"),
x$stanfit@sim$pars_oi
get_pars_oi(x$stanfit)
)
)
})
Expand Down
6 changes: 3 additions & 3 deletions R/dynamite.R
Original file line number Diff line number Diff line change
Expand Up @@ -494,11 +494,11 @@ dynamite_sampling <- function(sampling, backend, model_code, model,
dots,
threads_per_chain = onlyif(threads_per_chain > 1L, threads_per_chain)
)
sampling_out <- with(e, {
out <- with(e, {
do.call(model$sample, args)
})
out <- rstan::read_stan_csv(sampling_out$output_files())
out@stanmodel <- methods::new("stanmodel", model_code = model_code)
#out <- rstan::read_stan_csv(sampling_out$output_files())
#out@stanmodel <- methods::new("stanmodel", model_code = model_code)
}
}
out
Expand Down
3 changes: 2 additions & 1 deletion R/getters.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ get_code.dynamitefit <- function(x, blocks = NULL, ...) {
...
)$model_code
} else {
out <- x$stanfit@stanmodel@model_code[1L]
out <- get_model_code(x$stanfit)
}
get_code_(out, blocks)
}
Expand Down Expand Up @@ -269,6 +269,7 @@ get_parameter_dims.dynamitefit <- function(x, ...) {
)
pars_text <- get_code(x, blocks = "parameters")
pars <- get_parameters(pars_text)
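# TODO: rstan::get_inits() presumably only works for stanfit objects; a backend-agnostic way to obtain the inits is needed for CmdStanMCMC fits.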
out <- rstan::get_inits(x$stanfit)[[1L]]
out <- out[names(out) %in% pars]
lapply(
Expand Down
2 changes: 1 addition & 1 deletion R/loo.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ loo.dynamitefit <- function(x, separate_channels = FALSE, thin = 1L, ...) {
checkmate::test_int(x = thin, lower = 1L, upper = ndraws(x)),
"Argument {.arg thin} must be a single positive {.cls integer}."
)
n_chains <- x$stanfit@sim$chains
n_chains <- get_nchains(x$stanfit)
n_draws <- ndraws(x)
idx_draws <- seq.int(1L, n_draws, by = thin)
# need equal number of samples per chain
Expand Down
13 changes: 7 additions & 6 deletions R/mcmc_diagnostics.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ mcmc_diagnostics.dynamitefit <- function(x, n = 3L, ...) {
if (is.null(x$stanfit)) {
cat("No Stan model fit is available.")
} else {
algorithm <- x$stanfit@stan_args[[1L]]$algorithm
algorithm <- get_algorithm(x$stanfit)
stopifnot_(
algorithm %in% c("NUTS", "hmc"),
"MCMC diagnostics are only meaningful for samples from MCMC.
Expand Down Expand Up @@ -104,16 +104,17 @@ hmc_diagnostics.dynamitefit <- function(x, ...) {
if (is.null(x$stanfit)) {
cat("No Stan model fit is available.")
} else {
algorithm <- x$stanfit@stan_args[[1L]]$algorithm
algorithm <- get_algorithm(x$stanfit)
stopifnot_(
algorithm %in% c("NUTS", "hmc"),
"MCMC diagnostics are only meaningful for samples from MCMC.
Model was estimated using the ", algorithm, "algorithm."
)
n_draws <- ndraws(x)
n_divs <- rstan::get_num_divergent(x$stanfit)
n_trees <- rstan::get_num_max_treedepth(x$stanfit)
bfmis <- rstan::get_bfmi(x$stanfit)
diags <- get_diagnostics(x$stanfit)
n_divs <- diags$num_divergent
n_trees <- diags$num_max_treedepth
bfmis <- diags$ebfmi
all_ok <- n_divs == 0L && n_trees == 0L && all(bfmis > 0.2)
cat("NUTS sampler diagnostics:\n")
all_ok_str <- ifelse_(
Expand All @@ -131,7 +132,7 @@ hmc_diagnostics.dynamitefit <- function(x, ...) {
""
)
cat(div_str)
mt <- x$stanfit@stan_args[[1L]]$control$max_treedepth
mt <- get_max_treedepth(x$stanfit)
mt <- ifelse_(is.null(mt), 10, mt)
trees_str <- ifelse_(
n_trees > 0L,
Expand Down
5 changes: 1 addition & 4 deletions R/ndraws.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,5 @@ ndraws.dynamitefit <- function(x) {
!is.null(x$stanfit),
"No Stan model fit is available."
)
as.integer(
(x$stanfit@sim$n_save[1L] - x$stanfit@sim$warmup2[1L]) *
x$stanfit@sim$chains
)
as.integer(get_ndraws(x$stanfit))
}
4 changes: 2 additions & 2 deletions R/print.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ print.dynamitefit <- function(x, full_diagnostics = FALSE, ...) {
)
if (!is.null(x$stanfit)) {
cat("\n")
mcmc_algorithm <- x$stanfit@stan_args[[1L]]$algorithm %in% c("NUTS", "hmc")
mcmc_algorithm <- get_algorithm(x$stanfit) %in% c("NUTS", "hmc")
if (mcmc_algorithm) {
hmc_diagnostics(x)
}
Expand Down Expand Up @@ -83,7 +83,7 @@ print.dynamitefit <- function(x, full_diagnostics = FALSE, ...) {
sumr$variable[max_rhat], ")",
sep = ""
)
runtimes <- rstan::get_elapsed_time(x$stanfit)
runtimes <- get_elapsed_time(x$stanfit)
if (nrow(runtimes) > 2L) {
rs <- rowSums(runtimes)
cat("\n\nElapsed time (seconds) for fastest and slowest chains:\n")
Expand Down
159 changes: 159 additions & 0 deletions R/stan_utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,162 @@ stan_supports_glm_likelihood <- function(family, backend, common_intercept) {
(identical(family$name, "cumulative") && identical(family$link, "logit"))
)
}


# Wrapper methods for backends --------------------------------------------

#' Extract the parameters of interest (`pars_oi`) from a Stan model fit
#'
#' S3 generic; dispatches on the class of the fit object.
#'
#' @param x A `stanfit` (from `rstan`) or a `CmdStanMCMC`
#'   (from `cmdstanr`) object.
#' @noRd
get_pars_oi <- function(x) UseMethod("get_pars_oi", x)

#' Extract the Stan program code from a Stan model fit
#'
#' S3 generic; dispatches on the class of the fit object.
#'
#' @param x A `stanfit` (from `rstan`) or a `CmdStanMCMC`
#'   (from `cmdstanr`) object.
#' @noRd
get_model_code <- function(x) UseMethod("get_model_code", x)

#' Get the number of chains of a Stan model fit
#'
#' S3 generic; dispatches on the class of the fit object.
#'
#' @param x A `stanfit` (from `rstan`) or a `CmdStanMCMC`
#'   (from `cmdstanr`) object.
#' @noRd
get_num_chains <- function(x) {
  UseMethod("get_num_chains")
}

#' Alias of `get_num_chains()`
#'
#' Some callers use the shorter name `get_nchains()`; this thin wrapper makes
#' that name resolve to the same S3 dispatch as `get_num_chains()`.
#'
#' @param x A `stanfit` (from `rstan`) or a `CmdStanMCMC`
#'   (from `cmdstanr`) object.
#' @noRd
get_nchains <- function(x) {
  get_num_chains(x)
}

#' Extract the name of the algorithm used to produce a Stan model fit
#'
#' S3 generic; dispatches on the class of the fit object.
#'
#' @param x A `stanfit` (from `rstan`) or a `CmdStanMCMC`
#'   (from `cmdstanr`) object.
#' @noRd
get_algorithm <- function(x) UseMethod("get_algorithm", x)

#' Extract sampler diagnostics from a Stan model fit
#'
#' S3 generic; dispatches on the class of the fit object.
#'
#' @param x A `stanfit` (from `rstan`) or a `CmdStanMCMC`
#'   (from `cmdstanr`) object.
#' @noRd
get_diagnostics <- function(x) UseMethod("get_diagnostics", x)

#' Extract the maximum treedepth setting of a Stan model fit
#'
#' S3 generic; dispatches on the class of the fit object.
#'
#' @param x A `stanfit` (from `rstan`) or a `CmdStanMCMC`
#'   (from `cmdstanr`) object.
#' @noRd
get_max_treedepth <- function(x) UseMethod("get_max_treedepth", x)

#' Count the draws contained in a Stan model fit
#'
#' S3 generic; dispatches on the class of the fit object.
#'
#' @param x A `stanfit` (from `rstan`) or a `CmdStanMCMC`
#'   (from `cmdstanr`) object.
#' @noRd
get_ndraws <- function(x) UseMethod("get_ndraws", x)

#' Extract the draws of a Stan model fit
#'
#' S3 generic; dispatches on the class of the fit object.
#'
#' @param x A `stanfit` (from `rstan`) or a `CmdStanMCMC`
#'   (from `cmdstanr`) object.
#' @param ... Further arguments passed on to the method that is dispatched.
#' @noRd
get_draws <- function(x, ...) UseMethod("get_draws", x)

#' Extract the elapsed sampling time of a Stan model fit
#'
#' S3 generic; dispatches on the class of the fit object.
#'
#' @param x A `stanfit` (from `rstan`) or a `CmdStanMCMC`
#'   (from `cmdstanr`) object.
#' @noRd
get_elapsed_time <- function(x) UseMethod("get_elapsed_time", x)

# `stanfit` method: reads `pars_oi` from the `sim` slot of the fit.
get_pars_oi.stanfit <- function(x) {
  sim <- x@sim
  sim$pars_oi
}

# `CmdStanMCMC` method: reads `stan_variables` from the fit metadata.
get_pars_oi.CmdStanMCMC <- function(x) {
  meta <- x$metadata()
  meta$stan_variables
}

# `stanfit` method: first element of the `model_code` slot of the
# `stanmodel` stored in the fit.
get_model_code.stanfit <- function(x) {
  stan_model <- x@stanmodel
  stan_model@model_code[1L]
}

# `CmdStanMCMC` method: `$code()` yields the Stan program as a character
# vector with one element per line, whereas the `stanfit` method returns the
# program as a single string. Collapse the lines so that both backends
# return the model code in the same shape.
get_model_code.CmdStanMCMC <- function(x) {
  paste(x$code(), collapse = "\n")
}

# `stanfit` method: reads `chains` from the `sim` slot of the fit.
get_num_chains.stanfit <- function(x) {
  sim <- x@sim
  sim$chains
}

# `CmdStanMCMC` method: delegates to the `num_chains()` method of the fit.
get_num_chains.CmdStanMCMC <- function(x) {
  n_chains <- x$num_chains()
  n_chains
}

# `stanfit` method: reads `algorithm` from the first element of the
# `stan_args` slot.
get_algorithm.stanfit <- function(x) {
  first_args <- x@stan_args[[1L]]
  first_args$algorithm
}

# `CmdStanMCMC` method: reads `algorithm` from the fit metadata.
get_algorithm.CmdStanMCMC <- function(x) {
  meta <- x$metadata()
  meta$algorithm
}

# `stanfit` method: collects the divergence count, the max-treedepth count
# and the BFMI values reported by `rstan` into a named list.
get_diagnostics.stanfit <- function(x) {
  n_divergent <- rstan::get_num_divergent(x)
  n_max_treedepth <- rstan::get_num_max_treedepth(x)
  energy_bfmi <- rstan::get_bfmi(x)
  list(
    num_divergent = n_divergent,
    num_max_treedepth = n_max_treedepth,
    ebfmi = energy_bfmi
  )
}

# `CmdStanMCMC` method: returns a list with the same element names as the
# `stanfit` method (`num_divergent`, `num_max_treedepth`, `ebfmi`).
#
# `$diagnostic_summary()` takes the names of the diagnostics as its first
# argument, so the fit object must not be passed to it. It reports the
# divergence and treedepth counts per chain; these are summed so that the
# counts are scalars, which is what the scalar `&&` comparison in
# `hmc_diagnostics()` requires. `ebfmi` is kept per chain. `quiet = TRUE`
# suppresses the messages of cmdstanr because the caller prints its own.
get_diagnostics.CmdStanMCMC <- function(x) {
  diags <- x$diagnostic_summary(quiet = TRUE)
  list(
    num_divergent = sum(diags$num_divergent),
    num_max_treedepth = sum(diags$num_max_treedepth),
    ebfmi = diags$ebfmi
  )
}

# `stanfit` method: reads `control$max_treedepth` from the first element of
# the `stan_args` slot; `NULL` when it is not present.
get_max_treedepth.stanfit <- function(x) {
  first_args <- x@stan_args[[1L]]
  first_args$control$max_treedepth
}

# `CmdStanMCMC` method: reads `max_treedepth` from the fit metadata.
get_max_treedepth.CmdStanMCMC <- function(x) {
  meta <- x$metadata()
  meta$max_treedepth
}

# `stanfit` method: saved iterations of the first chain minus its saved
# warmup iterations, multiplied by the number of chains.
get_ndraws.stanfit <- function(x) {
  sim <- x@sim
  kept_per_chain <- sim$n_save[1L] - sim$warmup2[1L]
  kept_per_chain * sim$chains
}

# `CmdStanMCMC` method: number of post-warmup draws over all chains.
#
# The number of chains is taken from `x$num_chains()`, the same accessor
# that `get_num_chains.CmdStanMCMC()` uses, instead of the `num_chains`
# metadata field.
# NOTE(review): the metadata field presumably mirrors the per-process
# CmdStan argument rather than the number of chains that were run; confirm.
#
# Thinning is taken into account when the metadata reports it: with a
# thinning interval `thin`, `ceiling(iter_sampling / thin)` draws are kept
# per chain. A missing `thin` is treated as no thinning.
get_ndraws.CmdStanMCMC <- function(x) {
  m <- x$metadata()
  thin <- m$thin
  if (is.null(thin)) {
    thin <- 1L
  }
  ceiling(m$iter_sampling / thin) * x$num_chains()
}

# `stanfit` method: extracts the unpermuted draws of the parameters `pars`
# with `rstan::extract()` and converts them with `posterior::as_draws()`.
#
# `...` is accepted (and ignored) so that the signature of the method is
# consistent with the `get_draws(x, ...)` generic.
get_draws.stanfit <- function(x, pars, ...) {
  posterior::as_draws(
    rstan::extract(
      x,
      pars = pars,
      permuted = FALSE
    )
  )
}

# `CmdStanMCMC` method: returns the draws of the parameters `pars` through
# the `$draws()` method of the fit. The argument of `$draws()` is named
# `variables`.
#
# `...` is accepted (and ignored) so that the signature of the method is
# consistent with the `get_draws(x, ...)` generic.
get_draws.CmdStanMCMC <- function(x, pars, ...) {
  x$draws(variables = pars)
}

# `stanfit` method: delegates to `rstan::get_elapsed_time()`.
get_elapsed_time.stanfit <- function(x) {
  elapsed <- rstan::get_elapsed_time(x)
  elapsed
}

# `CmdStanMCMC` method: per-chain elapsed times as a numeric matrix with the
# columns `warmup` and `sample`, one row per chain.
#
# `x$time()$chains` is a data frame that also holds the chain id and the
# total time. The caller applies `rowSums()` to the result to obtain the
# total time per chain, so only the warmup and sampling columns are kept;
# otherwise the id and the total would be added to the sum as well. The
# column and row names follow those of `rstan::get_elapsed_time()`.
get_elapsed_time.CmdStanMCMC <- function(x) {
  chain_times <- x$time()$chains
  out <- as.matrix(chain_times[, c("warmup", "sampling"), drop = FALSE])
  dimnames(out) <- list(
    paste0("chain:", chain_times$chain_id),
    c("warmup", "sample")
  )
  out
}
4 changes: 2 additions & 2 deletions tests/testthat/test-cmdstanr.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ test_that("stanc_options argument works", {
gaussian_example,
"time",
"id",
parallel_chains = 2,
chains = 2,
#parallel_chains = 2,
chains = 1,
refresh = 0,
backend = "cmdstanr",
stanc_options = list("O0"),
Expand Down

0 comments on commit 58543f1

Please sign in to comment.