From 0e56cd7d159694d5308fa76939c2dbc4fc56606c Mon Sep 17 00:00:00 2001 From: Cole Trapnell Date: Mon, 6 Feb 2023 08:38:46 -0800 Subject: [PATCH 01/11] Pass optimization configuration options to postTreatment, so that it can be used e.g. in bootstrap/jackknife --- R/PLN.R | 2 +- R/PLNLDA.R | 2 +- R/PLNLDAfit-class.R | 4 +-- R/PLNPCA.R | 2 +- R/PLNPCAfit-class.R | 4 +-- R/PLNfamily-class.R | 11 +++++---- R/PLNfit-class.R | 37 +++++++++++++++------------- R/PLNmixture.R | 2 +- R/PLNmixturefit-class.R | 5 ++-- R/PLNnetwork.R | 14 ++++++++--- tests/testthat/test-standard-error.R | 2 +- 11 files changed, 49 insertions(+), 36 deletions(-) diff --git a/R/PLN.R b/R/PLN.R index df8e6fd5..9a08bcec 100644 --- a/R/PLN.R +++ b/R/PLN.R @@ -45,7 +45,7 @@ PLN <- function(formula, data, subset, weights, control = PLN_param()) { ## post-treatment if (control$trace > 0) cat("\n Post-treatments...") - myPLN$postTreatment(args$Y, args$X, args$O, args$w, control$config_post) + myPLN$postTreatment(args$Y, args$X, args$O, args$w, control$config_post, control$config_optim) if (control$trace > 0) cat("\n DONE!\n") myPLN diff --git a/R/PLNLDA.R b/R/PLNLDA.R index 51c8bd63..f40d7b0e 100644 --- a/R/PLNLDA.R +++ b/R/PLNLDA.R @@ -64,7 +64,7 @@ PLNLDA <- function(formula, data, subset, weights, grouping, control = PLN_param myLDA$optimize(grouping, args$Y, args$X, args$O, args$w, control$config_optim) ## Post-treatment: prepare LDA visualization - myLDA$postTreatment(grouping, args$Y, args$X, args$O, control$config_post) + myLDA$postTreatment(grouping, args$Y, args$X, args$O, control$config_post, control$config_optim) if (control$trace > 0) cat("\n DONE!\n") myLDA diff --git a/R/PLNLDAfit-class.R b/R/PLNLDAfit-class.R index 540120bc..dc28f34e 100644 --- a/R/PLNLDAfit-class.R +++ b/R/PLNLDAfit-class.R @@ -85,9 +85,9 @@ PLNLDAfit <- R6Class( ## Post treatment -------------------- #' @description Update R2, fisher and std_err fields and visualization #' @param config list controlling the post-treatment - postTreatment = function(grouping, responses, covariates, offsets, config) { + postTreatment = function(grouping, responses, covariates, offsets, config_post, config_optim) { covariates <- cbind(covariates, model.matrix( ~ grouping + 0)) - super$postTreatment(responses, covariates, offsets, config = config) + super$postTreatment(responses, covariates, offsets, config_post = config_post, config_optim = config_optim) rownames(private$C) <- colnames(private$C) <- colnames(responses) colnames(private$S) <- 1:self$q if (config$trace > 1) cat("\n\tCompute LD scores for visualization...") diff --git a/R/PLNPCA.R b/R/PLNPCA.R index ca72e145..2d2b128e 100644 --- a/R/PLNPCA.R +++ b/R/PLNPCA.R @@ -52,7 +52,7 @@ PLNPCA <- function(formula, data, subset, weights, ranks = 1:5, control = PLNPCA ## Post-treatments: pseudo-R2, rearrange criteria and prepare PCA visualization if (control$trace > 0) cat("\n Post-treatments") config_post <- config_post_default_PLNPCA; config_post$trace <- control$trace - myPCA$postTreatment(config_post) + myPCA$postTreatment(config_post, control$config_optim) if (control$trace > 0) cat("\n DONE!\n") myPCA diff --git a/R/PLNPCAfit-class.R b/R/PLNPCAfit-class.R index dffb550f..a6ea44ac 100644 --- a/R/PLNPCAfit-class.R +++ b/R/PLNPCAfit-class.R @@ -169,8 +169,8 @@ PLNPCAfit <- R6Class( #' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE. #' * rsquared boolean indicating whether approximation of R2 based on deviance should be computed. Default is TRUE #' * trace integer for verbosity. should be > 1 to see output in post-treatments - postTreatment = function(responses, covariates, offsets, weights, config, nullModel) { - super$postTreatment(responses, covariates, offsets, weights, config, nullModel) + postTreatment = function(responses, covariates, offsets, weights, config_post, config_optim, nullModel) { + super$postTreatment(responses, covariates, offsets, weights, config_post, config_optim, nullModel) colnames(private$C) <- colnames(private$M) <- 1:self$q rownames(private$C) <- colnames(responses) self$setVisualization() diff --git a/R/PLNfamily-class.R b/R/PLNfamily-class.R index 4c686855..fda35a9c 100644 --- a/R/PLNfamily-class.R +++ b/R/PLNfamily-class.R @@ -64,17 +64,18 @@ PLNfamily <- ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ## Post treatment -------------------- #' @description Update fields after optimization - #' @param config a list for controlling the post-treatment. - postTreatment = function(config) { - nullModel <- nullModelPoisson(self$responses, self$covariates, self$offsets, self$weights) + #' @param config_post a list for controlling the post-treatment. + postTreatment = function(config_post, config_optim) { + #nullModel <- nullModelPoisson(self$responses, self$covariates, self$offsets, self$weights) for (model in self$models) model$postTreatment( self$responses, self$covariates, self$offsets, self$weights, - config, - nullModel = nullModel + config_post=config_post, + config_optim=config_optim, + nullModel = NULL ) }, diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index b65e2b1a..6b76d7f7 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -168,7 +168,7 @@ PLNfit <- R6Class( ## PRIVATE METHODS FOR VARIANCE OF THE ESTIMATORS ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - variance_variational = function(X) { + variance_variational = function(X, config = config_default_nlopt) { ## Variance of B for n data points fisher <- Matrix::bdiag(lapply(1:self$p, function(j) { crossprod(X, private$A[, j] * X) # t(X) %*% diag(A[, i]) %*% X @@ -375,7 +375,7 @@ PLNfit <- R6Class( #' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE. #' * rsquared boolean indicating whether approximation of R2 based on deviance should be computed. Default is TRUE #' * trace integer for verbosity. should be > 1 to see output in post-treatments - postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config, nullModel = NULL) { + postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config_post, config_optim, nullModel = NULL) { ## PARAMATERS DIMNAMES ## Set names according to those of the data matrices. If missing, use sensible defaults if (is.null(colnames(responses))) @@ -392,24 +392,27 @@ PLNfit <- R6Class( ## OPTIONAL POST-TREATMENT (potentially costly) ## 1. compute and store approximated R2 with Poisson-based deviance - if (config$rsquared) { - if(config$trace > 1) cat("\n\tComputing bootstrap estimator of the variance...") + if (config_post$rsquared) { + if(config_post$trace > 1) cat("\n\tComputing approximate R^2...") private$approx_r2(responses, covariates, offsets, weights, nullModel) } ## 2. compute and store matrix of standard variances for B and Omega with rough variational approximation - if (config$variational_var) { - if(config$trace > 1) cat("\n\tComputing variational estimator of the variance...") - private$variance_variational(covariates) + if (config_post$variational_var) { + if(config_post$trace > 1) cat("\n\tComputing variational estimator of the variance...") + private$variance_variational(covariates, config = config_optim) } ## 3. Jackknife estimation of bias and variance - if (config$jackknife) { - if(config$trace > 1) cat("\n\tComputing jackknife estimator of the variance...") - private$variance_jackknife(responses, covariates, offsets, weights) + if (config_post$jackknife) { + if(config_post$trace > 1) cat("\n\tComputing jackknife estimator of the variance...") + private$variance_jackknife(responses, covariates, offsets, weights, config = config_optim) } ## 4. Bootstrap estimation of variance - if (config$bootstrap > 0) { - if(config$trace > 1) cat("\n\tComputing bootstrap estimator of the variance...") - private$variance_bootstrap(responses, covariates, offsets, weights, config$bootstrap) + if (config_post$bootstrap > 0) { + if(config_post$trace > 1) { + cat("\n\tComputing bootstrap estimator of the variance...") + print (str(config_optim)) + } + private$variance_bootstrap(responses, covariates, offsets, weights, n_resamples=config_post$bootstrap, config = config_optim) } }, @@ -804,11 +807,11 @@ PLNfit_fixedcov <- R6Class( #' * bootstrap integer indicating the number of bootstrap resamples generated to evaluate the variance of the model parameters. Default is 0 (inactivated). #' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE. #' * rsquared boolean indicating whether approximation of R2 based on deviance should be computed. Default is TRUE - postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config, nullModel = NULL) { - super$postTreatment(responses, covariates, offsets, weights, config, nullModel) + postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config_post, config_optim, nullModel = NULL) { + super$postTreatment(responses, covariates, offsets, weights, config_post, config_optim, nullModel) ## 6. compute and store matrix of standard variances for B with sandwich correction approximation - if (config$sandwich_var) { - if(config$trace > 1) cat("\n\tComputing sandwich estimator of the variance...") + if (config_post$sandwich_var) { + if(config_post$trace > 1) cat("\n\tComputing sandwich estimator of the variance...") private$vcov_sandwich_B(responses, covariates) } } diff --git a/R/PLNmixture.R b/R/PLNmixture.R index dea6e9af..4e0252bb 100644 --- a/R/PLNmixture.R +++ b/R/PLNmixture.R @@ -60,7 +60,7 @@ PLNmixture <- function(formula, data, subset, clusters = 1:5, control = PLNmixt ## Post-treatments: Compute pseudo-R2, rearrange criteria and the visualization for PCA if (control$trace > 0) cat("\n Post-treatments") config_post <- config_post_default_PLNmixture; config_post$trace <- control$trace - myPLN$postTreatment(config_post) + myPLN$postTreatment(config_post, control$config_optim) if (control$trace > 0) cat("\n DONE!\n") myPLN diff --git a/R/PLNmixturefit-class.R b/R/PLNmixturefit-class.R index 23363380..12eca6a9 100644 --- a/R/PLNmixturefit-class.R +++ b/R/PLNmixturefit-class.R @@ -281,7 +281,7 @@ PLNmixturefit <- ## Post treatment -------------------- #' @description Update fields after optimization #' @param config a list for controlling the post-treatment - postTreatment = function(responses, covariates, offsets, weights, config, nullModel) { + postTreatment = function(responses, covariates, offsets, weights, config_post, config_optim, nullModel) { ## restoring the full design matrix (group means + covariates) mu_k <- matrix(1, self$n, ncol = 1); colnames(mu_k) <- 'Intercept' @@ -292,7 +292,8 @@ PLNmixturefit <- mu_k, offsets, private$tau[,k_], - config, + config_post, + config_optim, nullModel = nullModel ) }, diff --git a/R/PLNnetwork.R b/R/PLNnetwork.R index 32be12b7..93924f46 100644 --- a/R/PLNnetwork.R +++ b/R/PLNnetwork.R @@ -41,8 +41,9 @@ PLNnetwork <- function(formula, data, subset, weights, penalties = NULL, control ## Post-treatments if (control$trace > 0) cat("\n Post-treatments") - config_post <- config_post_default_PLNnetwork; config_post$trace <- control$trace - myPLN$postTreatment(config_post) + #config_post <- config_post_default_PLNnetwork; + #config_post$trace <- control$trace + myPLN$postTreatment(control$config_post, control$config_optim) if (control$trace > 0) cat("\n DONE!\n") myPLN @@ -85,18 +86,24 @@ PLNnetwork <- function(formula, data, subset, weights, penalties = NULL, control #' #' @export PLNnetwork_param <- function( - backend = "nlopt", + backend = c("nlopt", "torch"), trace = 1 , n_penalties = 30 , min_ratio = 0.1 , penalize_diagonal = TRUE , penalty_weights = NULL , + config_post = list(), config_optim = list(), inception = NULL ) { if (!is.null(inception)) stopifnot(isPLNfit(inception)) + ## post-treatment config + config_pst <- config_post_default_PLN + config_pst[names(config_post)] <- config_post + config_pst$trace <- trace + ## optimization config backend <- match.arg(backend) stopifnot(backend %in% c("nlopt", "torch")) @@ -123,6 +130,7 @@ PLNnetwork_param <- function( jackknife = FALSE , bootstrap = 0 , variance = TRUE , + config_post = config_pst , config_optim = config_opt , inception = inception ), class = "PLNmodels_param") } diff --git a/tests/testthat/test-standard-error.R b/tests/testthat/test-standard-error.R index 9bd20803..43768350 100644 --- a/tests/testthat/test-standard-error.R +++ b/tests/testthat/test-standard-error.R @@ -95,7 +95,7 @@ test_that("Check that variance estimation are coherent in PLNfit", { trace = 2 ) - myPLN$postTreatment(Y, X, exp(log_O), config = config_post) + myPLN$postTreatment(Y, X, exp(log_O), config_post = config_post) tr_variational <- sum(standard_error(myPLN, "variational")^2) tr_bootstrap <- sum(standard_error(myPLN, "bootstrap")^2) From 4dd745106e5789c95e42267987b7d4ff7f5a4379 Mon Sep 17 00:00:00 2001 From: Cole Trapnell Date: Mon, 6 Feb 2023 08:40:02 -0800 Subject: [PATCH 02/11] Various improvements for torch optimizer needed to run it on the GPU --- R/PLN.R | 2 +- R/PLNfit-class.R | 68 ++++++++++++++++++++++++++++++++++-------------- R/utils.R | 8 +++++- 3 files changed, 57 insertions(+), 21 deletions(-) diff --git a/R/PLN.R b/R/PLN.R index 9a08bcec..d4148f83 100644 --- a/R/PLN.R +++ b/R/PLN.R @@ -95,7 +95,7 @@ PLN_param <- function( Omega = NULL, config_post = list(), config_optim = list(), - inception = NULL # pretrained PLNfit used as initialization + inception = NULL # pretrained PLNfit used as initialization, ) { covariance <- match.arg(covariance) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index 6b76d7f7..2fb320e0 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -80,21 +80,27 @@ PLNfit <- R6Class( torch_vloglik = function(data, params) { S2 <- torch_square(params$S) - Ji <- .5 * self$p - rowSums(.logfactorial(as.matrix(data$Y))) + as.numeric( - .5 * torch_logdet(params$Omega) + - torch_sum(data$Y * params$Z - params$A + .5 * torch_log(S2), dim = 2) - - .5 * torch_sum(torch_mm(params$M, params$Omega) * params$M + S2 * torch_diag(params$Omega), dim = 2) - ) - attr(Ji, "weights") <- as.numeric(data$w) + + Ji_tmp = .5 * torch_logdet(params$Omega) + + torch_sum(data$Y * params$Z - params$A + .5 * torch_log(S2), dim = 2) - + .5 * torch_sum(torch_mm(params$M, params$Omega) * params$M + S2 * torch_diag(params$Omega), dim = 2) + Ji_tmp = Ji_tmp$cpu() + Ji_tmp = as.numeric(Ji_tmp) + Ji <- .5 * self$p - rowSums(.logfactorial(as.matrix(data$Y$cpu()))) + Ji_tmp + + attr(Ji, "weights") <- as.numeric(data$w$cpu()) Ji }, #' @import torch torch_optimize = function(data, params, config) { + #config$device = "mps" + if (config$trace > 1) + message (paste("optimizing with device: ", config$device)) ## Conversion of data and parameters to torch tensors (pointers) - data <- lapply(data, torch_tensor) # list with Y, X, O, w - params <- lapply(params, torch_tensor, requires_grad = TRUE) # list with B, M, S + data <- lapply(data, torch_tensor, dtype = torch_float32(), device = config$device) # list with Y, X, O, w + params <- lapply(params, torch_tensor, dtype = torch_float32(), requires_grad = TRUE, device = config$device) # list with B, M, S ## Initialize optimizer optimizer <- switch(config$algorithm, @@ -111,11 +117,14 @@ PLNfit <- R6Class( batch_size <- floor(self$n/num_batch) objective <- double(length = config$num_epoch + 1) + #B_old = optimizer$param_groups[[1]]$params$B$clone() for (iterate in 1:num_epoch) { - B_old <- as.numeric(optimizer$param_groups[[1]]$params$B) - + #B_old <- as.numeric(optimizer$param_groups[[1]]$params$B) + B_old = optimizer$param_groups[[1]]$params$B$clone() # rearrange the data each epoch - permute <- torch::torch_randperm(self$n) + 1L + #permute <- torch::torch_randperm(self$n, device = "cpu") + 1L + permute = torch::torch_tensor(sample.int(self$n), dtype = torch_long(), device=config$device) + for (batch_idx in 1:num_batch) { # here index is a vector of the indices in the batch index <- permute[(batch_size*(batch_idx - 1) + 1):(batch_idx*batch_size)] @@ -129,14 +138,21 @@ PLNfit <- R6Class( ## assess convergence objective[iterate + 1] <- loss$item() - B_new <- as.numeric(optimizer$param_groups[[1]]$params$B) + B_new <- optimizer$param_groups[[1]]$params$B delta_f <- abs(objective[iterate] - objective[iterate + 1]) / abs(objective[iterate + 1]) - delta_x <- sum(abs(B_old - B_new))/sum(abs(B_new)) + delta_x <- torch::torch_sum(torch::torch_abs(B_old - B_new))/torch::torch_sum(torch::torch_abs(B_new)) + + #print (delta_f) + #print (delta_x) + delta_x = delta_x$cpu() + #print (delta_x) + delta_x = as.matrix(delta_x) + #print (delta_x) ## display progress if (config$trace > 1 && (iterate %% 50 == 0)) cat('\niteration: ', iterate, 'objective', objective[iterate + 1], - 'delta_f' , round(delta_f, 6), 'delta_x', ro% map("B") %>% reduce(`+`) / self$n var_jack <- jacks %>% map("B") %>% map(~( (. - B_jack)^2)) %>% reduce(`+`) %>% @@ -228,17 +247,28 @@ PLNfit <- R6Class( variance_bootstrap = function(Y, X, O, w, n_resamples = 100, config = config_default_nlopt) { resamples <- replicate(n_resamples, sample.int(self$n, replace = TRUE), simplify = FALSE) - boots <- future.apply::future_lapply(resamples, function(resample) { + boots <- lapply(resamples, function(resample) { data <- list(Y = Y[resample, , drop = FALSE], X = X[resample, , drop = FALSE], O = O[resample, , drop = FALSE], w = w[resample]) + #print (config$torch_device) + #print (config) + if (config$algorithm %in% c("RPROP", "RMSPROP", "ADAM", "ADAGRAD")) # hack, to know if we're doing torch or not + data <- lapply(data, torch_tensor, device = config$device) # list with Y, X, O, w + + #print (data$Y$device) + args <- list(data = data, params = list(B = private$B, M = matrix(0,self$n,self$p), S = private$S[resample, ]), config = config) + if (config$algorithm %in% c("RPROP", "RMSPROP", "ADAM", "ADAGRAD")) # hack, to know if we're doing torch or not + args$params <- lapply(args$params, torch_tensor, requires_grad = TRUE, device = config$device) # list with B, M, S + optim_out <- do.call(private$optimizer$main, args) + #print (optim_out) optim_out[c("B", "Omega", "monitoring")] - }, future.seed = TRUE) + }) B_boots <- boots %>% map("B") %>% reduce(`+`) / n_resamples attr(private$B, "variance_bootstrap") <- diff --git a/R/utils.R b/R/utils.R index 2d6e52ad..f6d3f9db 100644 --- a/R/utils.R +++ b/R/utils.R @@ -26,7 +26,8 @@ config_default_torch <- step_sizes = c(1e-3, 50), etas = c(0.5, 1.2), centered = FALSE, - trace = 1 + trace = 1, + device = "cpu" ) config_post_default_PLN <- @@ -107,6 +108,11 @@ trace <- function(x) sum(diag(x)) x } +.logfactorial_torch <- function(n){ + n[n == 0] <- 1 ## 0! = 1! + n*torch_log(n) - n + torch_log(8*torch_pow(n,3) + 4*torch_pow(n,2) + n + 1/30)/6 + log(pi)/2 +} + .logfactorial <- function(n) { # Ramanujan's formula n[n == 0] <- 1 ## 0! = 1! n*log(n) - n + log(8*n^3 + 4*n^2 + n + 1/30)/6 + log(pi)/2 From 635dd223b9a1429611b5b4cd1f9141606c0dd89f Mon Sep 17 00:00:00 2001 From: Cole Trapnell Date: Fri, 10 Feb 2023 12:48:47 -0800 Subject: [PATCH 03/11] Compute vcov on parameters when using jackknife or bootstrap --- R/PLNfit-class.R | 61 ++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index 2fb320e0..0e901507 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -60,6 +60,8 @@ PLNfit <- R6Class( ## PRIVATE TORCH METHODS FOR OPTIMIZATION ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% torch_elbo = function(data, params, index=torch_tensor(1:self$n)) { + #print (index) + #print (params$S) S2 <- torch_square(params$S[index]) Z <- data$O[index] + params$M[index] + torch_mm(data$X[index], params$B) res <- .5 * sum(data$w[index]) * torch_logdet(private$torch_Sigma(data, params, index)) + @@ -140,11 +142,12 @@ PLNfit <- R6Class( objective[iterate + 1] <- loss$item() B_new <- optimizer$param_groups[[1]]$params$B delta_f <- abs(objective[iterate] - objective[iterate + 1]) / abs(objective[iterate + 1]) + #delta_x = 0 delta_x <- torch::torch_sum(torch::torch_abs(B_old - B_new))/torch::torch_sum(torch::torch_abs(B_new)) + delta_x = delta_x$cpu() #print (delta_f) #print (delta_x) - delta_x = delta_x$cpu() #print (delta_x) delta_x = as.matrix(delta_x) #print (delta_x) @@ -156,7 +159,7 @@ PLNfit <- R6Class( ## Check for convergence if (delta_f < config$ftol_rel) status <- 3 - if (delta_x < config$xtol_rel) status <- 4 + #if (delta_x < config$xtol_rel) status <- 4 if (status %in% c(3,4)) { objective <- objective[1:iterate + 1] break @@ -217,6 +220,54 @@ PLNfit <- R6Class( invisible(list(var_B = var_B, var_Omega = var_Omega)) }, + compute_vcov_from_resamples = function(resamples){ + # compute the covariance of the parameters + get_cov_mat = function(data, cell_group) { + + cov_matrix = cov(data) + rownames(cov_matrix) = paste0(cell_group, "_", rownames(cov_matrix)) + colnames(cov_matrix) = paste0(cell_group, "_", colnames(cov_matrix)) + return(cov_matrix) + } + + + B_list = resamples %>% map("B") + #print (B_list) + vcov_B = lapply(seq(1, ncol(private$B)), function(B_col){ + param_ests_for_col = B_list %>% map(~.x[, B_col]) + param_ests_for_col = do.call(rbind, param_ests_for_col) + print (param_ests_for_col) + row_vcov = cov(param_ests_for_col) + }) + #print ("vcov blocks") + #print (vcov_B) + + #B_vcov <- resamples %>% map("B") %>% map(~( . )) %>% reduce(cov) + + #var_jack <- jacks %>% map("B") %>% map(~( (. - B_jack)^2)) %>% reduce(`+`) %>% + # `dimnames<-`(dimnames(private$B)) + #B_hat <- private$B[,] ## strips attributes while preserving names + + vcov_B = Matrix::bdiag(vcov_B) %>% as.matrix() + + rownames(vcov_B) <- colnames(vcov_B) <- + expand.grid(covariates = rownames(private$B), + responses = colnames(private$B)) %>% rev() %>% + ## Hack to make sure that species is first and varies slowest + apply(1, paste0, collapse = "_") + + #print (pheatmap::pheatmap(vcov_B, cluster_rows=FALSE, cluster_cols=FALSE)) + + + #names = lapply(bootstrapped_df$cov_mat, function(m){ colnames(m)}) %>% unlist() + #rownames(bootstrapped_vhat) = names + #colnames(bootstrapped_vhat) = names + + vcov_B = methods::as(vcov_B, "dgCMatrix") + + return(vcov_B) + }, + variance_jackknife = function(Y, X, O, w, config = config_default_nlopt) { jacks <- lapply(seq_len(self$n), function(i) { data <- list(Y = Y[-i, , drop = FALSE], @@ -237,6 +288,9 @@ PLNfit <- R6Class( attr(private$B, "bias") <- (self$n - 1) * (B_jack - B_hat) attr(private$B, "variance_jackknife") <- (self$n - 1) / self$n * var_jack + vcov_boots = private$compute_vcov_from_resamples(boots) + attr(private$B, "vcov_jackknife") <- vcov_boots + Omega_jack <- jacks %>% map("Omega") %>% reduce(`+`) / self$n var_jack <- jacks %>% map("Omega") %>% map(~( (. - Omega_jack)^2)) %>% reduce(`+`) %>% `dimnames<-`(dimnames(private$Omega)) @@ -275,6 +329,9 @@ PLNfit <- R6Class( boots %>% map("B") %>% map(~( (. - B_boots)^2)) %>% reduce(`+`) %>% `dimnames<-`(dimnames(private$B)) / n_resamples + vcov_boots = private$compute_vcov_from_resamples(boots) + attr(private$B, "vcov_bootstrap") <- vcov_boots + Omega_boots <- boots %>% map("Omega") %>% reduce(`+`) / n_resamples attr(private$Omega, "variance_bootstrap") <- boots %>% map("Omega") %>% map(~( (. - Omega_boots)^2)) %>% reduce(`+`) %>% From 0d5b9a75f6dc665234e40ede776adf5b670af286 Mon Sep 17 00:00:00 2001 From: Cole Trapnell Date: Tue, 14 Feb 2023 11:49:06 -0800 Subject: [PATCH 04/11] Remove stray print statement --- R/PLNfit-class.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index 0e901507..ce0c08cf 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -236,7 +236,7 @@ PLNfit <- R6Class( vcov_B = lapply(seq(1, ncol(private$B)), function(B_col){ param_ests_for_col = B_list %>% map(~.x[, B_col]) param_ests_for_col = do.call(rbind, param_ests_for_col) - print (param_ests_for_col) + #print (param_ests_for_col) row_vcov = cov(param_ests_for_col) }) #print ("vcov blocks") @@ -497,7 +497,7 @@ PLNfit <- R6Class( if (config_post$bootstrap > 0) { if(config_post$trace > 1) { cat("\n\tComputing bootstrap estimator of the variance...") - print (str(config_optim)) + #print (str(config_optim)) } private$variance_bootstrap(responses, covariates, offsets, weights, n_resamples=config_post$bootstrap, config = config_optim) } From 6a745c5f28fe7b6d31a94efe6804313b30fc2d3d Mon Sep 17 00:00:00 2001 From: maddyduran Date: Tue, 14 Feb 2023 13:06:38 -0800 Subject: [PATCH 05/11] pass covariance type to PLNnetwork --- R/PLNnetwork.R | 3 +++ R/PLNnetworkfamily-class.R | 7 ++++++- R/PLNnetworkfit-class.R | 2 +- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/R/PLNnetwork.R b/R/PLNnetwork.R index 93924f46..4735b7fc 100644 --- a/R/PLNnetwork.R +++ b/R/PLNnetwork.R @@ -87,6 +87,7 @@ PLNnetwork <- function(formula, data, subset, weights, penalties = NULL, control #' @export PLNnetwork_param <- function( backend = c("nlopt", "torch"), + covariance = c("fixed", "spherical", "diagonal"), trace = 1 , n_penalties = 30 , min_ratio = 0.1 , @@ -115,6 +116,7 @@ PLNnetwork_param <- function( stopifnot(config_optim$algorithm %in% available_algorithms_torch) config_opt <- config_default_torch } + covariance <- match.arg(covariance) config_opt$trace <- trace config_opt$ftol_out <- 1e-5 config_opt$maxit_out <- 20 @@ -123,6 +125,7 @@ PLNnetwork_param <- function( structure(list( backend = backend , trace = trace , + covariance = covariance , n_penalties = n_penalties , min_ratio = min_ratio , penalize_diagonal = penalize_diagonal, diff --git a/R/PLNnetworkfamily-class.R b/R/PLNnetworkfamily-class.R index 1f9b9f3b..e6ce0a7f 100644 --- a/R/PLNnetworkfamily-class.R +++ b/R/PLNnetworkfamily-class.R @@ -45,7 +45,12 @@ PLNnetworkfamily <- R6Class( ## A basic model for inception, useless one is defined by the user ### TODO check if it is useful if (is.null(control$inception)) { - myPLN <- PLNfit$new(responses, covariates, offsets, weights, formula, control) + + myPLN <- switch(control$covariance, + "spherical" = PLNfit_spherical$new(responses, covariates, offsets, weights, formula, control), + "diagonal" = PLNfit_diagonal$new(responses, covariates, offsets, weights, formula, control), + PLNfit$new(responses, covariates, offsets, weights, formula, control)) # defaults to fixed + # myPLN <- PLNfit$new(responses, covariates, offsets, weights, formula, control) myPLN$optimize(responses, covariates, offsets, weights, control$config_optim) control$inception <- myPLN } diff --git a/R/PLNnetworkfit-class.R b/R/PLNnetworkfit-class.R index 1dff42da..6a24bc32 100644 --- a/R/PLNnetworkfit-class.R +++ b/R/PLNnetworkfit-class.R @@ -33,7 +33,7 @@ #' @seealso The function [PLNnetwork()], the class [`PLNnetworkfamily`] PLNnetworkfit <- R6Class( classname = "PLNnetworkfit", - inherit = PLNfit_fixedcov, + inherit = PLNfit_spherical, #_fixedcov, ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ## PUBLIC MEMBERS ---- ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% From 057fd3bdfcba4fa7d971c918a0c07e5c5f663026 Mon Sep 17 00:00:00 2001 From: maddyduran Date: Tue, 14 Feb 2023 14:23:22 -0800 Subject: [PATCH 06/11] fix jackknife bug --- R/PLNfit-class.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index ce0c08cf..7a0b7550 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -288,8 +288,8 @@ PLNfit <- R6Class( attr(private$B, "bias") <- (self$n - 1) * (B_jack - B_hat) attr(private$B, "variance_jackknife") <- (self$n - 1) / self$n * var_jack - vcov_boots = private$compute_vcov_from_resamples(boots) - attr(private$B, "vcov_jackknife") <- vcov_boots + vcov_jacks = private$compute_vcov_from_resamples(jacks) + attr(private$B, "vcov_jackknife") <- vcov_jacks Omega_jack <- jacks %>% map("Omega") %>% reduce(`+`) / self$n var_jack <- jacks %>% map("Omega") %>% map(~( (. - Omega_jack)^2)) %>% reduce(`+`) %>% From 423775bd4f58c383efa2dfa5538d27c6529d943b Mon Sep 17 00:00:00 2001 From: Cole Trapnell Date: Sat, 18 Feb 2023 10:59:05 -0800 Subject: [PATCH 07/11] cast Sigma to dense when fetching the max penalty in order to avoid warnings about ineffiecient access --- R/PLNnetworkfamily-class.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/PLNnetworkfamily-class.R b/R/PLNnetworkfamily-class.R index e6ce0a7f..315cea40 100644 --- a/R/PLNnetworkfamily-class.R +++ b/R/PLNnetworkfamily-class.R @@ -47,7 +47,7 @@ PLNnetworkfamily <- R6Class( if (is.null(control$inception)) { myPLN <- switch(control$covariance, - "spherical" = PLNfit_spherical$new(responses, covariates, offsets, weights, formula, control), + "spherical" = PLNfit_spherical$new(responses, covariates, offsets, weights, formula, control), "diagonal" = PLNfit_diagonal$new(responses, covariates, offsets, weights, formula, control), PLNfit$new(responses, covariates, offsets, weights, formula, control)) # defaults to fixed # myPLN <- PLNfit$new(responses, covariates, offsets, weights, formula, control) @@ -74,7 +74,7 @@ PLNnetworkfamily <- R6Class( if (is.null(penalties)) { if (control$trace > 1) cat("\n Recovering an appropriate grid of penalties.") max_pen <- list_penalty_weights %>% - map(~ myPLN$model_par$Sigma / .x) %>% + map(~ as.matrix(myPLN$model_par$Sigma) / .x) %>% map_dbl(~ max(abs(.x[upper.tri(.x, diag = control$penalize_diagonal)]))) %>% max() penalties <- 10^seq(log10(max_pen), log10(max_pen*control$min_ratio), len = control$n_penalties) From f801ca6a886988c7f08eec594c57df1177aec948 Mon Sep 17 00:00:00 2001 From: Cole Trapnell Date: Sat, 18 Feb 2023 10:59:31 -0800 Subject: [PATCH 08/11] Revert back to using fixed covariance instead of spherical as base class for PLNnetworkfit --- R/PLNnetworkfit-class.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/PLNnetworkfit-class.R b/R/PLNnetworkfit-class.R index 6a24bc32..1dff42da 100644 --- a/R/PLNnetworkfit-class.R +++ b/R/PLNnetworkfit-class.R @@ -33,7 +33,7 @@ #' @seealso The function [PLNnetwork()], the class [`PLNnetworkfamily`] PLNnetworkfit <- R6Class( classname = "PLNnetworkfit", - inherit = PLNfit_spherical, #_fixedcov, + inherit = PLNfit_fixedcov, ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ## PUBLIC MEMBERS ---- ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% From 31b6153035ebb6074087ed118a2eeff25cb9aa0d Mon Sep 17 00:00:00 2001 From: maddyduran Date: Mon, 14 Aug 2023 16:05:53 -0700 Subject: [PATCH 09/11] changing line 85 to match PLNmodels/master --- R/PLNnetworkfamily-class.R | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/R/PLNnetworkfamily-class.R b/R/PLNnetworkfamily-class.R index ca629e39..e8f0b273 100644 --- a/R/PLNnetworkfamily-class.R +++ b/R/PLNnetworkfamily-class.R @@ -79,9 +79,10 @@ PLNnetworkfamily <- R6Class( # CHECK_ME_TORCH_GPU # This appears to be in torch_gpu only. The commented out line below is # in both PLNmodels/master and PLNmodels/dev. + # changed it to other one max_pen <- list_penalty_weights %>% - map(~ as.matrix(myPLN$model_par$Sigma) / .x) %>% - # map(~ control$inception$model_par$Sigma / .x) %>% + # map(~ as.matrix(myPLN$model_par$Sigma) / .x) %>% + map(~ control$inception$model_par$Sigma / .x) %>% map_dbl(~ max(abs(.x[upper.tri(.x, diag = control$penalize_diagonal)]))) %>% max() penalties <- 10^seq(log10(max_pen), log10(max_pen*control$min_ratio), len = control$n_penalties) From ed7c1811fdbcd62c53a9f7fa6940f094261aedb0 Mon Sep 17 00:00:00 2001 From: maddyduran Date: Fri, 25 Aug 2023 09:20:08 -0700 Subject: [PATCH 10/11] actually putting line 85 back --- R/PLNnetworkfamily-class.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/PLNnetworkfamily-class.R b/R/PLNnetworkfamily-class.R index e8f0b273..efe58309 100644 --- a/R/PLNnetworkfamily-class.R +++ b/R/PLNnetworkfamily-class.R @@ -81,8 +81,8 @@ PLNnetworkfamily <- R6Class( # in both PLNmodels/master and PLNmodels/dev. # changed it to other one max_pen <- list_penalty_weights %>% - # map(~ as.matrix(myPLN$model_par$Sigma) / .x) %>% - map(~ control$inception$model_par$Sigma / .x) %>% + map(~ as.matrix(myPLN$model_par$Sigma) / .x) %>% + # map(~ control$inception$model_par$Sigma / .x) %>% map_dbl(~ max(abs(.x[upper.tri(.x, diag = control$penalize_diagonal)]))) %>% max() penalties <- 10^seq(log10(max_pen), log10(max_pen*control$min_ratio), len = control$n_penalties) From 02e3501047a3a071bf9f7ad1b754f0dd86fca8a5 Mon Sep 17 00:00:00 2001 From: Cole Trapnell Date: Tue, 17 Oct 2023 10:48:24 -0700 Subject: [PATCH 11/11] slight code cleanup in the torch optimizer --- R/PLNfit-class.R | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index 476d0d8e..ac1dd363 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -60,12 +60,11 @@ PLNfit <- R6Class( ## PRIVATE TORCH METHODS FOR OPTIMIZATION ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% torch_elbo = function(data, params, index=torch_tensor(1:self$n)) { - #print (index) - #print (params$S) S2 <- torch_square(params$S[index]) Z <- data$O[index] + params$M[index] + torch_mm(data$X[index], params$B) + A <- torch_exp(Z + .5 * S2) res <- .5 * sum(data$w[index]) * torch_logdet(private$torch_Sigma(data, params, index)) + - sum(data$w[index,NULL] * (torch_exp(Z + .5 * S2) - data$Y[index] * Z - .5 * torch_log(S2))) + sum(data$w[index,NULL] * (A - data$Y[index] * Z - .5 * torch_log(S2))) res }, @@ -122,11 +121,11 @@ PLNfit <- R6Class( #B_old = optimizer$param_groups[[1]]$params$B$clone() for (iterate in 1:num_epoch) { #B_old <- as.numeric(optimizer$param_groups[[1]]$params$B) - B_old = optimizer$param_groups[[1]]$params$B$clone() # rearrange the data each epoch #permute <- torch::torch_randperm(self$n, device = "cpu") + 1L permute = torch::torch_tensor(sample.int(self$n), dtype = torch_long(), device=config$device) + #print (paste("num batches", num_batch)) for (batch_idx in 1:num_batch) { # here index is a vector of the indices in the batch index <- permute[(batch_size*(batch_idx - 1) + 1):(batch_idx*batch_size)] @@ -140,24 +139,15 @@ PLNfit <- R6Class( ## assess convergence objective[iterate + 1] <- loss$item() - B_new <- optimizer$param_groups[[1]]$params$B delta_f <- abs(objective[iterate] - objective[iterate + 1]) / abs(objective[iterate + 1]) - #delta_x = 0 - delta_x <- torch::torch_sum(torch::torch_abs(B_old - B_new))/torch::torch_sum(torch::torch_abs(B_new)) - delta_x = delta_x$cpu() - - #print (delta_f) - #print (delta_x) - #print (delta_x) - delta_x = as.matrix(delta_x) - #print (delta_x) ## display progress - if (config$trace > 1 && (iterate %% 50 == 0)) + if (config$trace > 1 && (iterate %% 50 == 1)) cat('\niteration: ', iterate, 'objective', objective[iterate + 1], - 'delta_f' , round(delta_f, 6), 'delta_x', round(delta_x, 6)) + 'delta_f' , round(delta_f, 6)) ## Check for convergence + #print (delta_f) if (delta_f < config$ftol_rel) status <- 3 #if (delta_x < config$xtol_rel) status <- 4 if (status %in% c(3,4)) {