Skip to content

Commit

Permalink
fix fixed intercept cumulative, don't require ordered factor
Browse files Browse the repository at this point in the history
  • Loading branch information
santikka committed Apr 26, 2024
1 parent 888208b commit 4538705
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 51 deletions.
4 changes: 2 additions & 2 deletions R/prepare_stan_input.R
Original file line number Diff line number Diff line change
Expand Up @@ -1184,10 +1184,10 @@ prepare_channel_student <- function(y, Y, channel, sampling,
prepare_channel_cumulative <- function(y, Y, channel, sampling,
sd_x, resp_class, priors) {
stopifnot_(
all(c("ordered", "factor") %in% resp_class),
"factor" %in% resp_class,
c(
"Response variable {.var {y}} is invalid:",
`x` = "Cumulative family supports only {.cls ordered factor} variables."
`x` = "Cumulative family supports only {.cls factor} variables."
)
)
resp_levels <- attr(resp_class, "levels")
Expand Down
2 changes: 2 additions & 0 deletions R/stan_utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ stan_array <- function(backend, type, name, arr_dims,
)
)
}

#' Create A Backward Compatible Stan Array for Function Arguments
#' @noRd
stan_array_arg <- function(backend, type, name, n_dims = 0, data = FALSE) {
Expand All @@ -53,6 +54,7 @@ stan_array_arg <- function(backend, type, name, n_dims = 0, data = FALSE) {
paste0(data, type, "[",commas, "] ", name)
)
}

#' Is Array Keyword Syntax Supported By Current Stan Version
#'
#' @param backend Either `"rstan"` or `"cmdstanr"`.
Expand Down
66 changes: 34 additions & 32 deletions R/stanblocks.R
Original file line number Diff line number Diff line change
Expand Up @@ -362,37 +362,6 @@ create_parameters_lines <- function(idt, backend, cvars, cgvars) {
cgvars$univariate <- univariate
}
lines_wrap("parameters", family, idt, backend, cgvars)
} else if (is_cumulative(family)) {
# the linear predictor without intercept
has_varying_intercept <- cvars[[1L]]$has_varying_intercept
cvars[[1L]]$has_fixed_intercept <- FALSE
cvars[[1L]]$has_varying_intercept <- FALSE
par_main <- lines_wrap(
"parameters", "default", idt, backend, cvars[[1L]]
)
# time-varying intercepts only
cvars[[1L]]$has_random_intercept <- FALSE
cvars[[1L]]$has_fixed <- FALSE
cvars[[1L]]$has_varying <- FALSE
cvars[[1L]]$has_random <- FALSE
cvars[[1L]]$has_lfactor <- FALSE
cvars[[1L]]$has_varying_intercept <- has_varying_intercept
par_alpha <- ulapply(
seq_len(cvars[[1L]]$S - 1L),
function(s) {
cvars[[1L]]$ydim <- cvars[[1L]]$y
cvars[[1L]]$y <- paste0(cvars[[1L]]$y, "_", s)
cvars[[1L]]$pos_omega_alpha <- s > 1L
lines_wrap(
"parameters", "default", idt, backend, cvars[[1L]]
)
}
)
cvars[[1L]]$default <- paste_rows(
par_main,
par_alpha,
.parse = FALSE
)
} else {
if (is_categorical(family)) {
cvars[[1L]]$default <- lapply(
Expand All @@ -405,7 +374,40 @@ create_parameters_lines <- function(idt, backend, cvars, cgvars) {
)
}
)
} else {
} else if (is_cumulative(family)) {
# the linear predictor without intercept
def_args <- cvars[[1L]]
has_varying_intercept <- def_args$has_varying_intercept
def_args$has_fixed_intercept <- FALSE
def_args$has_varying_intercept <- FALSE
par_main <- lines_wrap(
"parameters", "default", idt, backend, def_args
)
# time-varying intercepts only
def_args$has_random_intercept <- FALSE
def_args$has_fixed <- FALSE
def_args$has_varying <- FALSE
def_args$has_random <- FALSE
def_args$has_lfactor <- FALSE
def_args$has_varying_intercept <- has_varying_intercept
par_alpha <- ulapply(
seq_len(def_args$S - 1L),
function(s) {
def_args$ydim <- def_args$y
def_args$y <- paste0(def_args$y, "_", s)
def_args$pos_omega_alpha <- s > 1L
lines_wrap(
"parameters", "default", idt, backend, def_args
)
}
)
cvars[[1L]]$default <- paste_rows(
par_main,
par_alpha,
.parse = FALSE
)
}
else {
cvars[[1L]]$default <- lines_wrap(
"parameters", "default", idt, backend, cvars[[1L]]
)
Expand Down
46 changes: 29 additions & 17 deletions R/stanblocks_families.R
Original file line number Diff line number Diff line change
Expand Up @@ -1742,23 +1742,35 @@ transformed_parameters_lines_categorical <- function(default, idt, ...) {
}

transformed_parameters_lines_cumulative <- function(y, categories,
default, idt, ...) {
S <- length(categories)
declare_alpha <- glue::glue("array[T] ordered[S_{y} - 1] alpha_y;")
alpha_loop <- vapply(
seq(2L, S - 1L),
function(s) {
glue::glue("alpha_{y}[t, {s}] = alpha_{y}[t, {s - 1}] + alpha_{y}_{s}[t];")
},
character(1L)
)
state_alpha <- paste_rows(
"for (t in 1:T) {{",
"alpha_{y}[t, 1] = alpha_{y}_1[t];",
alpha_loop,
"}}",
.indent = idt(c(1, 2, 2, 1))
)
has_varying_intercept,
default, idt,
backend, ...) {
declare_alpha <- ""
state_alpha <- ""
if (has_varying_intercept) {
S <- length(categories)
declare_alpha <- glue::glue(
stan_array(
backend, "real", "alpha_{y}", "T", "", "S_{y} - 1"
)
)
alpha_loop <- vapply(
seq(2L, S - 1L),
function(s) {
glue::glue(
"alpha_{y}[t, {s}] = alpha_{y}[t, {s - 1}] + alpha_{y}_{s}[t];"
)
},
character(1L)
)
state_alpha <- paste_rows(
"for (t in 1:T) {{",
"alpha_{y}[t, 1] = alpha_{y}_1[t];",
alpha_loop,
"}}",
.indent = idt(c(1, 2, 2, 1))
)
}
list(
declarations = paste_rows(
ulapply(default, "[[", "declarations"),
Expand Down

0 comments on commit 4538705

Please sign in to comment.