Skip to content

Commit

Permalink
add support for draw_indices in spread/gather_draws, closes #323
Browse files Browse the repository at this point in the history
  • Loading branch information
mjskay committed Apr 23, 2024
1 parent e9f8c5f commit 4f2eddd
Show file tree
Hide file tree
Showing 18 changed files with 177 additions and 47 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ Language: en-US
BugReports: https://github.com/mjskay/tidybayes/issues/new
URL: https://mjskay.github.io/tidybayes/, https://github.com/mjskay/tidybayes/
VignetteBuilder: knitr
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
LazyData: true
Encoding: UTF-8
Collate:
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ S3method(epred_rvars,brmsfit)
S3method(epred_rvars,default)
S3method(epred_rvars,stanreg)
S3method(fitted_draws,default)
S3method(flip_aes,"function")
S3method(flip_aes,character)
S3method(flip_aes,data.frame)
S3method(gather_emmeans_draws,default)
S3method(gather_emmeans_draws,emm_list)
S3method(get_variables,default)
Expand Down
8 changes: 7 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
# tidybayes (development version)

Buf fixes:
New features:

* Add support for `draw_indices` parameter in `spread_draws()` and
`gather_draws()`. (#323)


Bug fixes:

* Support for matrix columns in `nest_rvars()` and `unnest_rvars()`. (#316)

Expand Down
10 changes: 1 addition & 9 deletions R/compare_levels.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,7 @@ comparison_types = within(list(), {
#' in the output `variable` column instead converting the unevaluated
#' expression to a string. You can also use [emmeans_comparison()] to generate
#' a comparison function based on contrast methods from the `emmeans` package.
#' @param draw_indices Character vector of column names in `data` that
#' should be treated as indices when making the comparison (i.e. values of
#' `variable` within each level of `by` will be compared at each
#' unique combination of levels of `draw_indices`). Columns in `draw_indices`
#' not found in `data` are ignored. The default is `c(".chain",".iteration",".draw")`,
#' which are the same names used for chain/iteration/draw indices returned by
#' [spread_draws()] or [gather_draws()]; thus if you are using `compare_levels`
#' with [spread_draws()] or [gather_draws()] you generally should not need to change this
#' value.
#' @template param-draw_indices
#' @param ignore_groups character vector of names of groups to ignore by
#' default in the input grouping. This is primarily provided to make it
#' easier to pipe output of [add_epred_draws()] into this function,
Expand Down
3 changes: 3 additions & 0 deletions R/flip_aes.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,20 @@ flip_aes = function(x, lookup = flip_aes_lookup) {
UseMethod("flip_aes")
}

#' @export
flip_aes.character = function(x, lookup = flip_aes_lookup) {
flipped = lookup[x]
x[!is.na(flipped)] = flipped[!is.na(flipped)]
x
}

#' @export
flip_aes.data.frame = function(x, lookup = flip_aes_lookup) {
names(x) = flip_aes(names(x), lookup = lookup)
x
}

#' @export
flip_aes.function = function(x, lookup = flip_aes_lookup) {
name = force(deparse(substitute(x)))
function(...) {
Expand Down
17 changes: 15 additions & 2 deletions R/gather_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,26 @@
#' @importFrom dplyr bind_rows group_by_at
#' @importFrom rlang enquos
#' @export
gather_draws = function(model, ..., regex = FALSE, sep = "[, ]", ndraws = NULL, seed = NULL, n) {
gather_draws = function(
model,
...,
regex = FALSE,
sep = "[, ]",
ndraws = NULL,
seed = NULL,
draw_indices = c(".chain", ".iteration", ".draw"),
n
) {
ndraws = .Deprecated_argument_alias(ndraws, n)

draws = sample_draws_from_model_(model, ndraws, seed)

draw_indices = intersect(draw_indices, names(draws))
tidysamples = lapply(enquos(...), function(variable_spec) {
gather_variables(spread_draws_(draws, variable_spec, regex = regex, sep = sep))
gather_variables(
spread_draws_(draws, variable_spec, regex = regex, sep = sep, draw_indices = draw_indices),
exclude = c(draw_indices, ".row")
)
})

#get the groups from all the samples --- when we bind them together,
Expand Down
57 changes: 46 additions & 11 deletions R/spread_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ globalVariables(c(".."))
#' @param sep Separator used to separate dimensions in variable names, as a regular expression.
#' @template param-ndraws
#' @template param-seed
#' @template param-draw_indices
#' @template param-deprecated-n
#' @return A data frame.
#' @author Matthew Kay
Expand Down Expand Up @@ -232,13 +233,29 @@ globalVariables(c(".."))
#' @importFrom dplyr inner_join group_by_at
#' @rdname spread_draws
#' @export
spread_draws = function(model, ..., regex = FALSE, sep = "[, ]", ndraws = NULL, seed = NULL, n) {
spread_draws = function(
model,
...,
regex = FALSE,
sep = "[, ]",
ndraws = NULL,
seed = NULL,
draw_indices = c(".chain", ".iteration", ".draw"),
n
) {
ndraws = .Deprecated_argument_alias(ndraws, n)

draws = sample_draws_from_model_(model, ndraws, seed)

draw_indices = intersect(draw_indices, names(draws))
tidysamples = lapply(enquos(...), function(variable_spec) {
spread_draws_(draws, variable_spec, regex = regex, sep = sep)
spread_draws_(
draws,
variable_spec,
regex = regex,
sep = sep,
draw_indices = draw_indices
)
})

#get the groups from all the samples --- when we join them together,
Expand All @@ -260,15 +277,21 @@ spread_draws = function(model, ..., regex = FALSE, sep = "[, ]", ndraws = NULL,
#' @importFrom dplyr mutate group_by_at
#' @importFrom tidyr spread
#' @importFrom rlang has_name
spread_draws_ = function(draws, variable_spec, regex = FALSE, sep = "[, ]") {
spread_draws_ = function(
draws,
variable_spec,
regex = FALSE,
sep = "[, ]",
draw_indices = c(".chain", ".iteration", ".draw")
) {
#parse a variable spec in the form variable_name[dimension_name_1, dimension_name_2, ..] | wide_dimension
spec = parse_variable_spec(variable_spec)
variable_names = spec[[1]]
dimension_names = spec[[2]]
wide_dimension_name = spec[[3]]

#extract the draws into a long format data frame
long_draws = spread_draws_long_(draws, variable_names, dimension_names, regex = regex, sep = sep)
long_draws = spread_draws_long_(draws, variable_names, dimension_names, regex = regex, sep = sep, draw_indices = draw_indices)

#convert variable and/or dimensions back into usable data types
#that were set on the model using recover_types
Expand Down Expand Up @@ -309,7 +332,14 @@ spread_draws_ = function(draws, variable_spec, regex = FALSE, sep = "[, ]") {
## dimension_names: a character vector of dimension names
#' @importFrom tidyr spread separate gather
#' @importFrom dplyr summarise_all group_by_at
spread_draws_long_ = function(draws, variable_names, dimension_names, regex = FALSE, sep = "[, ]") {
spread_draws_long_ = function(
draws,
variable_names,
dimension_names,
regex = FALSE,
sep = "[, ]",
draw_indices = c(".chain", ".iteration", ".draw")
) {
if (!regex) {
variable_names = escape_regex(variable_names)
}
Expand All @@ -326,7 +356,7 @@ spread_draws_long_ = function(draws, variable_names, dimension_names, regex = FA
}

variable_names = colnames(draws)[variable_names_index]
unnest_legacy(draws[, c(".chain", ".iteration", ".draw", variable_names)])
unnest_legacy(draws[, c(draw_indices, variable_names)])
}
else {
dimension_sep_regex = sep
Expand Down Expand Up @@ -399,19 +429,19 @@ spread_draws_long_ = function(draws, variable_names, dimension_names, regex = FA
# some dimensions were requested to be nested as list columns containing arrays.
# thus we have to ADD CHAIN INFO then UNNEST, then NEST DIMENSIONS then SPREAD
# 2. ADD CHAIN INFO
nested_draws[[".chain_info"]] = list(draws[,c(".chain", ".iteration", ".draw")])
nested_draws[[".chain_info"]] = list(draws[, draw_indices])
# 3. UNNEST
long_draws = unnest_legacy(nested_draws)
# NEST DIMENSIONS
long_draws = nest_dimensions_(long_draws, temp_dimension_names, nested_dimension_names)
long_draws = nest_dimensions_(long_draws, temp_dimension_names, nested_dimension_names, draw_indices)
# 1. SPREAD
long_draws = spread(long_draws, ".variable", ".value")
} else {
# no nested dimensions, so we can do the SPREAD then UNNEST then ADD CHAIN INFO
# 1. SPREAD
nested_draws = spread(nested_draws, ".variable", ".value")
# 2. ADD CHAIN INFO
nested_draws[[".chain_info"]] = list(draws[,c(".chain", ".iteration", ".draw")])
nested_draws[[".chain_info"]] = list(draws[, draw_indices])
# 3. UNNEST
long_draws = unnest_legacy(nested_draws)
}
Expand All @@ -429,7 +459,12 @@ spread_draws_long_ = function(draws, variable_names, dimension_names, regex = FA
## dimension_names: dimensions not used for nesting
## nested_dimension_names: dimensions to be nested
#' @importFrom dplyr filter summarise_at
nest_dimensions_ = function(long_draws, dimension_names, nested_dimension_names) {
nest_dimensions_ = function(
long_draws,
dimension_names,
nested_dimension_names,
draw_indices = c(".chain", ".iteration", ".draw")
) {
ragged = FALSE
value_name = ".value"
value = as.name(value_name)
Expand All @@ -443,7 +478,7 @@ nest_dimensions_ = function(long_draws, dimension_names, nested_dimension_names)
}

long_draws = group_by_at(long_draws,
c(".chain", ".iteration", ".draw", ".variable", dimension_names) %>%
c(draw_indices, ".variable", dimension_names) %>%
# nested dimension names must come at the end of the group list
# (minus the last nested dimension) so that we summarise in the
# correct order
Expand Down
3 changes: 1 addition & 2 deletions R/tidybayes-package.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#' Tidy Data and 'Geoms' for Bayesian Models
#'
#' @docType package
#' @name tidybayes-package
#' @aliases tidybayes
#'
Expand Down Expand Up @@ -34,4 +33,4 @@
#' Wickham, Hadley. (2014). Tidy data. _Journal of Statistical Software_,
#' 59(10), 1-23. \doi{10.18637/jss.v059.i10}.
#'
NULL
"_PACKAGE"
2 changes: 1 addition & 1 deletion R/ungather_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ globalVariables(c("..dimension_values"))
ungather_draws = function(
data, ..., variable = ".variable", value = ".value", draw_indices = c(".chain", ".iteration", ".draw"), drop_indices = FALSE
) {

draw_indices = intersect(draw_indices, names(data))
variable_specs = enquos(...)

if (length(variable_specs) == 0) {
Expand Down
6 changes: 2 additions & 4 deletions R/unspread_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@ globalVariables(c("..dimension_values"))
#' @param data A tidy data frame of draws, such as one output by `spread_draws` or `gather_draws`.
#' @param ... Expressions in the form of
#' `variable_name[dimension_1, dimension_2, ...]`. See [spread_draws()].
#' @param draw_indices Character vector of column names in `data` that
#' should be treated as indices of draws. The default is `c(".chain",".iteration",".draw")`,
#' which are the same names used for chain, iteration, and draw indices returned by
#' [spread_draws()] or [gather_draws()].
#' @template param-draw_indices
#' @param drop_indices Drop the columns specified by `draw_indices` from the resulting data frame. Default `FALSE`.
#' @param variable The name of the column in `data` that contains the names of variables from the model.
#' @param value The name of the column in `data` that contains draws from the variables.
Expand Down Expand Up @@ -62,6 +59,7 @@ globalVariables(c("..dimension_values"))
#' @rdname unspread_draws
#' @export
unspread_draws = function(data, ..., draw_indices = c(".chain", ".iteration", ".draw"), drop_indices = FALSE) {
draw_indices = intersect(draw_indices, names(data))
result =
lapply(enquos(...), function(variable_spec) {
unspread_draws_(data, variable_spec, draw_indices = draw_indices)
Expand Down
5 changes: 5 additions & 0 deletions man-roxygen/param-draw_indices.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#' @param draw_indices Character vector of column names that should be treated
#' as indices of draws. Operations are done within combinations of these values.
#' The default is `c(".chain", ".iteration", ".draw")`, which is the same names
#' used for chain, iteration, and draw indices returned by [tidy_draws()].
#' Names in `draw_indices` that are not found in the data are ignored.
14 changes: 5 additions & 9 deletions man/compare_levels.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 4f2eddd

Please sign in to comment.