Skip to content

Commit

Permalink
Produce better error message in grassmann_lm
Browse files Browse the repository at this point in the history
  • Loading branch information
const-ae committed Aug 16, 2024
1 parent af6efb1 commit 44df5ef
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
9 changes: 5 additions & 4 deletions R/geodesic_regression.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,13 @@ grassmann_lm <- function(data, design, base_point, tangent_regression = FALSE){

# Initialize with tangent regression
mm_groups <- get_groups(design)
if(any(table(mm_groups) < n_emb)){
stop("Too few datapoints in some design matrix group.\n",
"This error could be removed, but this feature hasn't been implemented yet.")
}
groups <- unique(mm_groups)
reduced_design <- mply_dbl(groups, \(gr) design[which(mm_groups == gr)[1],], ncol = ncol(design))
if(any(table(mm_groups) < n_emb)){
problematic_mat <- cbind(n_occurrences = c(table(mm_groups)), reduced_design)
stop("Too few datapoints in some design matrix group.\n\n", glmGamPoi:::format_matrix(problematic_mat),
"\nEach row must occurr at least n_embedding=", n_emb, " times.\n")
}
group_planes <- lapply(groups, \(gr) pca(data[,mm_groups == gr,drop=FALSE], n = n_emb, center = FALSE)$coordsystem)
group_sizes <- vapply(groups, \(gr) sum(mm_groups == gr), FUN.VALUE = 0L)
coef <- grassmann_geodesic_regression(group_planes, design = reduced_design, base_point = base_point, weights = group_sizes, tangent_regression = TRUE)
Expand Down
12 changes: 12 additions & 0 deletions tests/testthat/test-geodesic_regression.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ test_that("grassmann_lm works", {
expect_equal(fit[,,"xb"], grassmann_log(base_point, plane_b) - grassmann_log(base_point, plane_a))
})

test_that("grassmann_lm throws a helpful error message", {
n_obs <- 100
data <- randn(5, n_obs)
col_data <- data.frame(x = sample(letters[1:3], size = n_obs, replace = TRUE))
col_data$x[1] <- "new_element"
des <- model.matrix(~ x, col_data)
base_point <- qr.Q(qr(randn(5, 2)))
expect_error({
fit <- grassmann_lm(data, des, base_point)
})
})


test_that("get_groups works and is fast", {
df <- data.frame(let = sample(letters[1:2], size = 100, replace = TRUE),
Expand Down

0 comments on commit 44df5ef

Please sign in to comment.