From 4a2f3a5e939e29f750274c8ef6b09bc38e224882 Mon Sep 17 00:00:00 2001 From: Nima Hejazi Date: Fri, 6 Dec 2024 12:40:06 -0500 Subject: [PATCH] catch up on edits to cutpoints --- R/utils.R | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/R/utils.R b/R/utils.R index 3d3b44c..9f6abdf 100644 --- a/R/utils.R +++ b/R/utils.R @@ -50,38 +50,35 @@ format_long_hazards <- function(A, W, wts = rep(1, length(A)), grid_type <- match.arg(grid_type) # set grid along A and find interval membership of observations along grid - if (is.null(breaks) & !is.null(n_bins)) { + if (is.null(breaks) && !is.null(n_bins)) { if (grid_type == "equal_range") { bins <- ggplot2::cut_interval( x = A, n = n_bins, - right = FALSE, ordered_result = TRUE, dig.lab = 12 + right = FALSE, ordered_result = TRUE, dig.lab = 12L ) } else if (grid_type == "equal_mass") { bins <- ggplot2::cut_number( x = A, n = n_bins, - right = FALSE, ordered_result = TRUE, dig.lab = 12 + right = FALSE, ordered_result = TRUE, dig.lab = 12L ) } - } else if (!is.null(breaks) & is.null(n_bins)) { - # check that user-specified grid covers all of A - #assertthat::assert_that(min(breaks) <= min(A)) - #assertthat::assert_that(max(breaks) >= max(A)) + } else if (!is.null(breaks)) { + # augment grid to cover all of A + breaks <- unique(c(min(A), breaks, max(A))) # cut based on user-specified grid bins <- cut( - x = A, breaks = breaks, - right = FALSE, ordered_result = TRUE, dig.lab = 12 + x = A, breaks = breaks, include.lowest = TRUE, + right = FALSE, ordered_result = TRUE, dig.lab = 12L ) - } else { - stop("Invalid combination of `grid_type`, `n_bins`, and `breaks`.") } # see https://stackoverflow.com/questions/36581075/extract-the-breakpoints-from-cut breaks_left <- as.numeric(sub(".(.+),.+", "\\1", levels(bins))) breaks_right <- as.numeric(sub(".+,(.+).", "\\1", levels(bins))) - bin_length <- round(breaks_right - breaks_left, 3) + bin_length <- round(breaks_right - breaks_left, 3L) bin_id <- as.numeric(bins) - all_bins <- matrix(seq_len(max(bin_id)), ncol = 1) + all_bins <- matrix(seq_len(max(bin_id)), ncol = 1L) # loop over observations to create expanded set of records for each reformat_each_obs <- future.apply::future_lapply(seq_along(A), function(i) {