diff --git a/R/PLN.R b/R/PLN.R
index 9a1e321e..16131e9b 100644
--- a/R/PLN.R
+++ b/R/PLN.R
@@ -45,7 +45,7 @@ PLN <- function(formula, data, subset, weights, control = PLN_param()) {

   ## post-treatment
   if (control$trace > 0) cat("\n Post-treatments...")
-  myPLN$postTreatment(args$Y, args$X, args$O, args$w, control$config_post)
+  myPLN$postTreatment(args$Y, args$X, args$O, args$w, control$config_post, control$config_optim)
   if (control$trace > 0) cat("\n DONE!\n")
   myPLN
@@ -105,7 +105,7 @@ PLN_param <- function(
   Omega        = NULL,
   config_post  = list(),
   config_optim = list(),
-  inception    = NULL # pretrained PLNfit used as initialization
+  inception    = NULL # pretrained PLNfit used as initialization,
 ) {
   covariance <- match.arg(covariance)
diff --git a/R/PLNLDA.R b/R/PLNLDA.R
index a56036cd..80549d8c 100644
--- a/R/PLNLDA.R
+++ b/R/PLNLDA.R
@@ -64,7 +64,7 @@ PLNLDA <- function(formula, data, subset, weights, grouping, control = PLN_param
   myLDA$optimize(grouping, args$Y, args$X, args$O, args$w, control$config_optim)

   ## Post-treatment: prepare LDA visualization
-  myLDA$postTreatment(grouping, args$Y, args$X, args$O, control$config_post)
+  myLDA$postTreatment(grouping, args$Y, args$X, args$O, control$config_post, control$config_optim)
   if (control$trace > 0) cat("\n DONE!\n")
   myLDA
diff --git a/R/PLNLDAfit-class.R b/R/PLNLDAfit-class.R
index f4a8762d..feb6c5c5 100644
--- a/R/PLNLDAfit-class.R
+++ b/R/PLNLDAfit-class.R
@@ -85,9 +85,10 @@ PLNLDAfit <- R6Class(
     ## Post treatment --------------------
     #' @description Update R2, fisher and std_err fields and visualization
-    #' @param config list controlling the post-treatment
-    postTreatment = function(grouping, responses, covariates, offsets, config) {
+    #' @param config_post list controlling the post-treatment
+    #' @param config_optim list controlling the optimization, used when post-treatments refit the model
+    postTreatment = function(grouping, responses, covariates, offsets, config_post, config_optim) {
       covariates <- cbind(covariates, model.matrix( ~ grouping + 0))
-      super$postTreatment(responses, covariates, offsets, config = config)
+      super$postTreatment(responses, covariates, offsets, config_post = config_post, config_optim = config_optim)
       rownames(private$C) <- colnames(private$C) <- colnames(responses)
       colnames(private$S) <- 1:self$q
-      if (config$trace > 1) cat("\n\tCompute LD scores for visualization...")
+      if (config_post$trace > 1) cat("\n\tCompute LD scores for visualization...")
diff --git a/R/PLNPCA.R b/R/PLNPCA.R
index c5c26a95..4606550b 100644
--- a/R/PLNPCA.R
+++ b/R/PLNPCA.R
@@ -52,7 +52,7 @@ PLNPCA <- function(formula, data, subset, weights, ranks = 1:5, control = PLNPCA

   ## Post-treatments: pseudo-R2, rearrange criteria and prepare PCA visualization
   if (control$trace > 0) cat("\n Post-treatments")
   config_post <- config_post_default_PLNPCA; config_post$trace <- control$trace
-  myPCA$postTreatment(config_post)
+  myPCA$postTreatment(config_post, control$config_optim)
   if (control$trace > 0) cat("\n DONE!\n")
   myPCA
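
Review note: all user-facing fitting functions (PLN, PLNLDA, PLNPCA, and PLNmixture/PLNnetwork below) now forward control$config_optim to postTreatment(), so that resampling-based post-treatments (jackknife, bootstrap) refit the model with the user's optimizer settings instead of hard-coded defaults. A minimal sketch of how this surfaces at the user level, assuming the trichoptera example data shipped with the package:

    library(PLNmodels)
    data(trichoptera)
    trichoptera <- prepare_data(trichoptera$Abundance, trichoptera$Covariate)
    ## bootstrap refits now reuse the optimizer settings of the main fit
    ctrl  <- PLN_param(backend      = "nlopt",
                       config_post  = list(bootstrap = 20),
                       config_optim = list(maxeval = 10000))
    myPLN <- PLN(Abundance ~ 1, data = trichoptera, control = ctrl)
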
diff --git a/R/PLNPCAfit-class.R b/R/PLNPCAfit-class.R
index 85f799ca..9b55e5a9 100644
--- a/R/PLNPCAfit-class.R
+++ b/R/PLNPCAfit-class.R
@@ -169,8 +169,8 @@ PLNPCAfit <- R6Class(
     #' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE.
     #' * rsquared boolean indicating whether approximation of R2 based on deviance should be computed. Default is TRUE
     #' * trace integer for verbosity. should be > 1 to see output in post-treatments
-    postTreatment = function(responses, covariates, offsets, weights, config, nullModel) {
-      super$postTreatment(responses, covariates, offsets, weights, config, nullModel)
+    postTreatment = function(responses, covariates, offsets, weights, config_post, config_optim, nullModel) {
+      super$postTreatment(responses, covariates, offsets, weights, config_post, config_optim, nullModel)
       colnames(private$C) <- colnames(private$M) <- 1:self$q
       rownames(private$C) <- colnames(responses)
       self$setVisualization()
diff --git a/R/PLNfamily-class.R b/R/PLNfamily-class.R
index 4c686855..fda35a9c 100644
--- a/R/PLNfamily-class.R
+++ b/R/PLNfamily-class.R
@@ -64,17 +64,18 @@ PLNfamily <-
 ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 ## Post treatment --------------------
     #' @description Update fields after optimization
-    #' @param config a list for controlling the post-treatment.
-    postTreatment = function(config) {
-      nullModel <- nullModelPoisson(self$responses, self$covariates, self$offsets, self$weights)
+    #' @param config_post a list for controlling the post-treatment.
+    #' @param config_optim a list for controlling the optimization, used when post-treatments refit the model.
+    postTreatment = function(config_post, config_optim) {
       for (model in self$models)
         model$postTreatment(
           self$responses,
           self$covariates,
           self$offsets,
           self$weights,
-          config,
-          nullModel = nullModel
+          config_post  = config_post,
+          config_optim = config_optim,
+          nullModel = NULL
         )
     },
diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R
index 715d0499..065137ce 100644
--- a/R/PLNfit-class.R
+++ b/R/PLNfit-class.R
@@ -62,8 +62,9 @@ PLNfit <- R6Class(
     torch_elbo = function(data, params, index=torch_tensor(1:self$n)) {
       S2 <- torch_square(params$S[index])
       Z  <- data$O[index] + params$M[index] + torch_mm(data$X[index], params$B)
+      A  <- torch_exp(Z + .5 * S2)
       res <- .5 * sum(data$w[index]) * torch_logdet(private$torch_Sigma(data, params, index)) +
-        sum(data$w[index,NULL] * (torch_exp(Z + .5 * S2) - data$Y[index] * Z - .5 * torch_log(S2)))
+        sum(data$w[index,NULL] * (A - data$Y[index] * Z - .5 * torch_log(S2)))
       res
     },
@@ -80,21 +81,24 @@ PLNfit <- R6Class(
     torch_vloglik = function(data, params) {
       S2 <- torch_square(params$S)
-      Ji <- .5 * self$p - rowSums(.logfactorial(as.matrix(data$Y))) + as.numeric(
-        .5 * torch_logdet(params$Omega) +
-        torch_sum(data$Y * params$Z - params$A + .5 * torch_log(S2), dim = 2) -
-        .5 * torch_sum(torch_mm(params$M, params$Omega) * params$M + S2 * torch_diag(params$Omega), dim = 2)
-      )
-      attr(Ji, "weights") <- as.numeric(data$w)
+      Ji_tmp <- .5 * torch_logdet(params$Omega) +
+        torch_sum(data$Y * params$Z - params$A + .5 * torch_log(S2), dim = 2) -
+        .5 * torch_sum(torch_mm(params$M, params$Omega) * params$M + S2 * torch_diag(params$Omega), dim = 2)
+      ## bring the tensors back to the CPU before converting to R vectors
+      Ji_tmp <- as.numeric(Ji_tmp$cpu())
+      Ji <- .5 * self$p - rowSums(.logfactorial(as.matrix(data$Y$cpu()))) + Ji_tmp
+      attr(Ji, "weights") <- as.numeric(data$w$cpu())
       Ji
     },

     #' @import torch
     torch_optimize = function(data, params, config) {
+      if (config$trace > 1)
+        message("optimizing with device: ", config$device)
       ## Conversion of data and parameters to torch tensors (pointers)
-      data   <- lapply(data, torch_tensor)                          # list with Y, X, O, w
-      params <- lapply(params, torch_tensor, requires_grad = TRUE)  # list with B, M, S
+      data   <- lapply(data, torch_tensor, dtype = torch_float32(), device = config$device) # list with Y, X, O, w
+      params <- lapply(params, torch_tensor, dtype = torch_float32(), requires_grad = TRUE, device = config$device) # list with B, M, S

       ## Initialize optimizer
       optimizer <- switch(config$algorithm,
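
Review note: tensors are now created directly on config$device (default "cpu", see the change to config_default_torch in R/utils.R below) and in single precision (torch_float32()), which most GPUs favor; results may therefore differ slightly from the double-precision nlopt backend. A hedged sketch of a GPU fit, assuming a CUDA-capable torch installation ("mps" should behave analogously on Apple silicon):

    library(torch)
    dev  <- if (cuda_is_available()) "cuda" else "cpu"
    ctrl <- PLN_param(backend = "torch",
                      config_optim = list(device = dev, num_epoch = 300))
    myPLN <- PLN(Abundance ~ 1, data = trichoptera, control = ctrl)
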
@@ -111,11 +115,10 @@ PLNfit <- R6Class(
       batch_size <- floor(self$n/num_batch)

       objective <- double(length = config$num_epoch + 1)
       for (iterate in 1:num_epoch) {
-        B_old <- as.numeric(optimizer$param_groups[[1]]$params$B)
-
         # rearrange the data each epoch
-        permute <- torch::torch_randperm(self$n) + 1L
+        permute <- torch::torch_tensor(sample.int(self$n), dtype = torch_long(), device = config$device) # R-side permutation, moved to the device
+
         for (batch_idx in 1:num_batch) {
           # here index is a vector of the indices in the batch
           index <- permute[(batch_size*(batch_idx - 1) + 1):(batch_idx*batch_size)]
@@ -129,9 +132,7 @@ PLNfit <- R6Class(
         ## assess convergence
         objective[iterate + 1] <- loss$item()
-        B_new <- as.numeric(optimizer$param_groups[[1]]$params$B)
         delta_f <- abs(objective[iterate] - objective[iterate + 1]) / abs(objective[iterate + 1])
-        delta_x <- sum(abs(B_old - B_new))/sum(abs(B_new))

         ## Error message if objective diverges
         if (!is.finite(loss$item())) {
@@ -140,13 +141,12 @@ PLNfit <- R6Class(
         }

         ## display progress
-        if (config$trace > 1 && (iterate %% 50 == 0))
+        if (config$trace > 1 && (iterate %% 50 == 1))
           cat('\niteration: ', iterate, 'objective', objective[iterate + 1],
-              'delta_f' , round(delta_f, 6), 'delta_x', round(delta_x, 6))
+              'delta_f' , round(delta_f, 6))

         ## Check for convergence
         if (delta_f < config$ftol_rel) status <- 3
-        if (delta_x < config$xtol_rel) status <- 4
         if (status %in% c(3,4)) {
           objective <- objective[1:iterate + 1]
           break
@@ -158,7 +158,7 @@ PLNfit <- R6Class(
       params$Z <- data$O + params$M + torch_matmul(data$X, params$B)
       params$A <- torch_exp(params$Z + torch_pow(params$S, 2)/2)
-      out <- lapply(params, as.matrix)
+      out <- lapply(params, function(x) as.matrix(x$cpu()))
       out$Ji <- private$torch_vloglik(data, params)
       out$monitoring <- list(
         objective = objective,
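
Review note: with batch_size <- floor(self$n/num_batch), the index arithmetic above never visits the last self$n %% num_batch observations of an epoch when n is not a multiple of num_batch. A sketch of a partition that covers every observation (a suggestion, not part of this patch):

    n <- 103; num_batch <- 10
    ## chunk a fresh permutation into num_batch batches whose sizes differ by at most one
    batches <- split(sample.int(n), rep(seq_len(num_batch), length.out = n))
    stopifnot(sum(lengths(batches)) == n)  # every observation is visited once per epoch
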
@@ -174,7 +174,7 @@ PLNfit <- R6Class(
   ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
   ## PRIVATE METHODS FOR VARIANCE OF THE ESTIMATORS
   ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
-    variance_variational = function(X) {
+    variance_variational = function(X, config = config_default_nlopt) {
       ## Variance of B for n data points
       fisher <- Matrix::bdiag(lapply(1:self$p, function(j) {
         crossprod(X, private$A[, j] * X) # t(X) %*% diag(A[, i]) %*% X
@@ -204,8 +204,24 @@ PLNfit <- R6Class(
       invisible(list(var_B = var_B, var_Omega = var_Omega))
     },

+    compute_vcov_from_resamples = function(resamples){
+      ## covariance of the coefficients across resamples, block-diagonal over the columns (responses) of B
+      B_list <- resamples %>% map("B")
+      vcov_B <- lapply(seq_len(ncol(private$B)), function(B_col){
+        param_ests_for_col <- B_list %>% map(~ .x[, B_col]) %>% do.call(rbind, .)
+        cov(param_ests_for_col)
+      })
+      vcov_B <- Matrix::bdiag(vcov_B) %>% as.matrix()
+      rownames(vcov_B) <- colnames(vcov_B) <-
+        expand.grid(covariates = rownames(private$B),
+                    responses  = colnames(private$B)) %>% rev() %>%
+        ## make sure that responses come first and vary slowest
+        apply(1, paste0, collapse = "_")
+      methods::as(vcov_B, "dgCMatrix")
+    },
+
     variance_jackknife = function(Y, X, O, w, config = config_default_nlopt) {
-      jacks <- future.apply::future_lapply(seq_len(self$n), function(i) {
+      jacks <- lapply(seq_len(self$n), function(i) {
         data <- list(Y = Y[-i, , drop = FALSE],
                      X = X[-i, , drop = FALSE],
                      O = O[-i, , drop = FALSE],
@@ -215,7 +231,7 @@ PLNfit <- R6Class(
                      config = config)
         optim_out <- do.call(private$optimizer$main, args)
         optim_out[c("B", "Omega")]
-      }, future.seed = TRUE)
+      })

       B_jack   <- jacks %>% map("B") %>% reduce(`+`) / self$n
       var_jack <- jacks %>% map("B") %>% map(~( (. - B_jack)^2)) %>% reduce(`+`) %>%
@@ -224,6 +240,9 @@ PLNfit <- R6Class(
       B_hat <- private$B[,] ## strips attributes while preserving names
       attr(private$B, "bias")               <- (self$n - 1) * (B_jack - B_hat)
       attr(private$B, "variance_jackknife") <- (self$n - 1) / self$n * var_jack

+      vcov_jacks <- private$compute_vcov_from_resamples(jacks)
+      attr(private$B, "vcov_jackknife") <- vcov_jacks
+
       Omega_jack <- jacks %>% map("Omega") %>% reduce(`+`) / self$n
       var_jack   <- jacks %>% map("Omega") %>% map(~( (. - Omega_jack)^2)) %>% reduce(`+`) %>%
         `dimnames<-`(dimnames(private$Omega))
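
Review note: besides the element-wise variance_jackknife attribute, the jackknife now also stores a covariance matrix of the regression coefficients, block-diagonal across responses, built by compute_vcov_from_resamples(). A sketch of how the result can be inspected after fitting with config_post = list(jackknife = TRUE) (attribute names are those set above; d covariates, p responses):

    B <- myPLN$model_par$B
    se_jack   <- sqrt(attr(B, "variance_jackknife"))  # per-coefficient standard errors
    vcov_jack <- attr(B, "vcov_jackknife")            # sparse dgCMatrix, one d x d block per response
    dim(vcov_jack)                                    # (d*p) x (d*p)
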
@@ -234,23 +253,36 @@ PLNfit <- R6Class(
     variance_bootstrap = function(Y, X, O, w, n_resamples = 100, config = config_default_nlopt) {
       resamples <- replicate(n_resamples, sample.int(self$n, replace = TRUE), simplify = FALSE)
-      boots <- future.apply::future_lapply(resamples, function(resample) {
+      ## the torch backend is recognized by its algorithm names; data and
+      ## parameters must then live on the configured device
+      use_torch <- config$algorithm %in% c("RPROP", "RMSPROP", "ADAM", "ADAGRAD")
+      boots <- lapply(resamples, function(resample) {
         data <- list(Y = Y[resample, , drop = FALSE],
                      X = X[resample, , drop = FALSE],
                      O = O[resample, , drop = FALSE],
                      w = w[resample])
+        if (use_torch)
+          data <- lapply(data, torch_tensor, device = config$device) # list with Y, X, O, w
+
         args <- list(data = data,
                      params = list(B = private$B, M = matrix(0,self$n,self$p), S = private$S[resample, ]),
                      config = config)
+        if (use_torch)
+          args$params <- lapply(args$params, torch_tensor, requires_grad = TRUE, device = config$device) # list with B, M, S
+
         optim_out <- do.call(private$optimizer$main, args)
         optim_out[c("B", "Omega", "monitoring")]
-      }, future.seed = TRUE)
+      })

       B_boots <- boots %>% map("B") %>% reduce(`+`) / n_resamples
       attr(private$B, "variance_bootstrap") <-
         boots %>% map("B") %>% map(~( (. - B_boots)^2)) %>% reduce(`+`) %>%
         `dimnames<-`(dimnames(private$B)) / n_resamples
+
+      vcov_boots <- private$compute_vcov_from_resamples(boots)
+      attr(private$B, "vcov_bootstrap") <- vcov_boots
+
       Omega_boots <- boots %>% map("Omega") %>% reduce(`+`) / n_resamples
       attr(private$Omega, "variance_bootstrap") <-
         boots %>% map("Omega") %>% map(~( (. - Omega_boots)^2)) %>% reduce(`+`) %>%
@@ -386,7 +418,7 @@ PLNfit <- R6Class(
     #' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE.
     #' * rsquared boolean indicating whether approximation of R2 based on deviance should be computed. Default is TRUE
     #' * trace integer for verbosity. should be > 1 to see output in post-treatments
-    postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config, nullModel = NULL) {
+    postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config_post, config_optim, nullModel = NULL) {
       ## PARAMATERS DIMNAMES
       ## Set names according to those of the data matrices. If missing, use sensible defaults
       if (is.null(colnames(responses)))
@@ -403,24 +435,24 @@ PLNfit <- R6Class(

       ## OPTIONAL POST-TREATMENT (potentially costly)
       ## 1. compute and store approximated R2 with Poisson-based deviance
-      if (config$rsquared) {
-        if(config$trace > 1) cat("\n\tComputing bootstrap estimator of the variance...")
+      if (config_post$rsquared) {
+        if(config_post$trace > 1) cat("\n\tComputing approximate R^2...")
         private$approx_r2(responses, covariates, offsets, weights, nullModel)
       }
       ## 2. compute and store matrix of standard variances for B and Omega with rough variational approximation
-      if (config$variational_var) {
-        if(config$trace > 1) cat("\n\tComputing variational estimator of the variance...")
-        private$variance_variational(covariates)
+      if (config_post$variational_var) {
+        if(config_post$trace > 1) cat("\n\tComputing variational estimator of the variance...")
+        private$variance_variational(covariates, config = config_optim)
       }
       ## 3. Jackknife estimation of bias and variance
-      if (config$jackknife) {
-        if(config$trace > 1) cat("\n\tComputing jackknife estimator of the variance...")
-        private$variance_jackknife(responses, covariates, offsets, weights)
+      if (config_post$jackknife) {
+        if(config_post$trace > 1) cat("\n\tComputing jackknife estimator of the variance...")
+        private$variance_jackknife(responses, covariates, offsets, weights, config = config_optim)
       }
       ## 4. Bootstrap estimation of variance
-      if (config$bootstrap > 0) {
-        if(config$trace > 1) cat("\n\tComputing bootstrap estimator of the variance...")
-        private$variance_bootstrap(responses, covariates, offsets, weights, config$bootstrap)
+      if (config_post$bootstrap > 0) {
+        if(config_post$trace > 1) cat("\n\tComputing bootstrap estimator of the variance...")
+        private$variance_bootstrap(responses, covariates, offsets, weights, n_resamples = config_post$bootstrap, config = config_optim)
       }
     },
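
Review note: postTreatment() now takes the two configuration lists separately: config_post toggles the optional (potentially costly) steps, while config_optim is only consumed by the steps that refit the model. A sketch of a fully enabled post-treatment configuration (field names as used above; defaults live in config_post_default_PLN):

    config_post <- list(
      rsquared        = TRUE,  # 1. Poisson-deviance based pseudo-R2
      variational_var = TRUE,  # 2. variational Fisher information (underestimates the variance)
      jackknife       = TRUE,  # 3. n leave-one-out refits
      bootstrap       = 100,   # 4. number of bootstrap refits (0 disables)
      trace           = 2
    )
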
@@ -815,11 +850,11 @@ PLNfit_fixedcov <- R6Class(
     #' * bootstrap integer indicating the number of bootstrap resamples generated to evaluate the variance of the model parameters. Default is 0 (inactivated).
     #' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE.
     #' * rsquared boolean indicating whether approximation of R2 based on deviance should be computed. Default is TRUE
-    postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config, nullModel = NULL) {
-      super$postTreatment(responses, covariates, offsets, weights, config, nullModel)
+    postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config_post, config_optim, nullModel = NULL) {
+      super$postTreatment(responses, covariates, offsets, weights, config_post, config_optim, nullModel)
       ## 6. compute and store matrix of standard variances for B with sandwich correction approximation
-      if (config$sandwich_var) {
-        if(config$trace > 1) cat("\n\tComputing sandwich estimator of the variance...")
+      if (config_post$sandwich_var) {
+        if(config_post$trace > 1) cat("\n\tComputing sandwich estimator of the variance...")
         private$vcov_sandwich_B(responses, covariates)
       }
     }
diff --git a/R/PLNmixture.R b/R/PLNmixture.R
index 7231fe1d..b723c6b2 100644
--- a/R/PLNmixture.R
+++ b/R/PLNmixture.R
@@ -60,7 +60,7 @@ PLNmixture <- function(formula, data, subset, clusters = 1:5, control = PLNmixt
   ## Post-treatments: Compute pseudo-R2, rearrange criteria and the visualization for PCA
   if (control$trace > 0) cat("\n Post-treatments")
   config_post <- config_post_default_PLNmixture; config_post$trace <- control$trace
-  myPLN$postTreatment(config_post)
+  myPLN$postTreatment(config_post, control$config_optim)
   if (control$trace > 0) cat("\n DONE!\n")
   myPLN
diff --git a/R/PLNmixturefit-class.R b/R/PLNmixturefit-class.R
index 23363380..12eca6a9 100644
--- a/R/PLNmixturefit-class.R
+++ b/R/PLNmixturefit-class.R
@@ -281,7 +281,8 @@ PLNmixturefit <-
   ## Post treatment --------------------
     #' @description Update fields after optimization
-    #' @param config a list for controlling the post-treatment
-    postTreatment = function(responses, covariates, offsets, weights, config, nullModel) {
+    #' @param config_post a list for controlling the post-treatment
+    #' @param config_optim a list for controlling the optimization, used when post-treatments refit the model
+    postTreatment = function(responses, covariates, offsets, weights, config_post, config_optim, nullModel) {
       ## restoring the full design matrix (group means + covariates)
       mu_k <- matrix(1, self$n, ncol = 1); colnames(mu_k) <- 'Intercept'
@@ -292,7 +293,8 @@ PLNmixturefit <-
           mu_k,
           offsets,
           private$tau[,k_],
-          config,
+          config_post,
+          config_optim,
           nullModel = nullModel
         )
       },
diff --git a/R/PLNnetwork.R b/R/PLNnetwork.R
index 4d644a69..197bdf91 100644
--- a/R/PLNnetwork.R
+++ b/R/PLNnetwork.R
@@ -41,8 +42,7 @@ PLNnetwork <- function(formula, data, subset, weights, penalties = NULL, control

   ## Post-treatments
   if (control$trace > 0) cat("\n Post-treatments")
-  config_post <- config_post_default_PLNnetwork; config_post$trace <- control$trace
-  myPLN$postTreatment(config_post)
+  myPLN$postTreatment(control$config_post, control$config_optim)
   if (control$trace > 0) cat("\n DONE!\n")
   myPLN
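
Review note: PLNnetwork() no longer overwrites the post-treatment configuration with package defaults; it honours whatever was passed through PLNnetwork_param(), which gains config_post and covariance arguments below. A sketch of the resulting interface:

    ctrl <- PLNnetwork_param(backend      = "torch",
                             covariance   = "diagonal",  # family used for the inception fit
                             config_post  = list(rsquared = TRUE),
                             config_optim = list(device = "cpu"))
    myNets <- PLNnetwork(Abundance ~ 1, data = trichoptera, control = ctrl)
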
c("nlopt", "torch")) @@ -95,6 +103,7 @@ PLNnetwork_param <- function( stopifnot(config_optim$algorithm %in% available_algorithms_torch) config_opt <- config_default_torch } + covariance <- match.arg(covariance) config_opt$trace <- trace config_opt$ftol_out <- 1e-5 config_opt$maxit_out <- 20 @@ -103,6 +112,7 @@ PLNnetwork_param <- function( structure(list( backend = backend , trace = trace , + covariance = covariance , n_penalties = n_penalties , min_ratio = min_ratio , penalize_diagonal = penalize_diagonal, @@ -110,6 +120,7 @@ PLNnetwork_param <- function( jackknife = FALSE , bootstrap = 0 , variance = TRUE , + config_post = config_pst , config_optim = config_opt , inception = inception ), class = "PLNmodels_param") } diff --git a/R/PLNnetworkfamily-class.R b/R/PLNnetworkfamily-class.R index d62d2053..efe58309 100644 --- a/R/PLNnetworkfamily-class.R +++ b/R/PLNnetworkfamily-class.R @@ -45,7 +45,15 @@ PLNnetworkfamily <- R6Class( ## A basic model for inception, useless one is defined by the user ### TODO check if it is useful if (is.null(control$inception)) { - myPLN <- PLNfit$new(responses, covariates, offsets, weights, formula, control) + + # CHECK_ME_TORCH_GPU + # This appears to be in torch_gpu only. The commented out line below is + # in both PLNmodels/master and PLNmodels/dev. + myPLN <- switch(control$covariance, + "spherical" = PLNfit_spherical$new(responses, covariates, offsets, weights, formula, control), + "diagonal" = PLNfit_diagonal$new(responses, covariates, offsets, weights, formula, control), + PLNfit$new(responses, covariates, offsets, weights, formula, control)) # defaults to fixed + # myPLN <- PLNfit$new(responses, covariates, offsets, weights, formula, control) myPLN$optimize(responses, covariates, offsets, weights, control$config_optim) control$inception <- myPLN } @@ -68,8 +76,13 @@ PLNnetworkfamily <- R6Class( ## Get an appropriate grid of penalties if (is.null(penalties)) { if (control$trace > 1) cat("\n Recovering an appropriate grid of penalties.") + # CHECK_ME_TORCH_GPU + # This appears to be in torch_gpu only. The commented out line below is + # in both PLNmodels/master and PLNmodels/dev. + # changed it to other one max_pen <- list_penalty_weights %>% - map(~ control$inception$model_par$Sigma / .x) %>% + map(~ as.matrix(myPLN$model_par$Sigma) / .x) %>% + # map(~ control$inception$model_par$Sigma / .x) %>% map_dbl(~ max(abs(.x[upper.tri(.x, diag = control$penalize_diagonal)]))) %>% max() penalties <- 10^seq(log10(max_pen), log10(max_pen*control$min_ratio), len = control$n_penalties) diff --git a/R/utils.R b/R/utils.R index 9bc10dfe..60351772 100644 --- a/R/utils.R +++ b/R/utils.R @@ -26,7 +26,8 @@ config_default_torch <- step_sizes = c(1e-3, 50), etas = c(0.5, 1.2), centered = FALSE, - trace = 1 + trace = 1, + device = "cpu" ) config_post_default_PLN <- @@ -107,6 +108,11 @@ trace <- function(x) sum(diag(x)) x } +.logfactorial_torch <- function(n){ + n[n == 0] <- 1 ## 0! = 1! + n*torch_log(n) - n + torch_log(8*torch_pow(n,3) + 4*torch_pow(n,2) + n + 1/30)/6 + log(pi)/2 +} + .logfactorial <- function(n) { # Ramanujan's formula n[n == 0] <- 1 ## 0! = 1! 
diff --git a/R/utils.R b/R/utils.R
index 9bc10dfe..60351772 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -26,7 +26,8 @@ config_default_torch <-
     step_sizes = c(1e-3, 50),
     etas       = c(0.5, 1.2),
     centered   = FALSE,
-    trace      = 1
+    trace      = 1,
+    device     = "cpu"
   )

 config_post_default_PLN <-
@@ -107,6 +108,12 @@ trace <- function(x) sum(diag(x))
   x
 }

+## Ramanujan's formula, on torch tensors
+.logfactorial_torch <- function(n){
+  n[n == 0] <- 1 ## 0! = 1!
+  n*torch_log(n) - n + torch_log(8*torch_pow(n,3) + 4*torch_pow(n,2) + n + 1/30)/6 + log(pi)/2
+}
+
 .logfactorial <- function(n) { # Ramanujan's formula
   n[n == 0] <- 1 ## 0! = 1!
   n*log(n) - n + log(8*n^3 + 4*n^2 + n + 1/30)/6 + log(pi)/2
diff --git a/tests/testthat/test-standard-error.R b/tests/testthat/test-standard-error.R
index 9bd20803..43768350 100644
--- a/tests/testthat/test-standard-error.R
+++ b/tests/testthat/test-standard-error.R
@@ -95,7 +95,7 @@ test_that("Check that variance estimation are coherent in PLNfit", {
     trace = 2
   )

-  myPLN$postTreatment(Y, X, exp(log_O), config = config_post)
+  myPLN$postTreatment(Y, X, exp(log_O), config_post = config_post, config_optim = PLNmodels:::config_default_nlopt) # config_optim no longer has a default
   tr_variational <- sum(standard_error(myPLN, "variational")^2)
   tr_bootstrap   <- sum(standard_error(myPLN, "bootstrap")^2)
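
Review note: .logfactorial_torch() is a tensor twin of .logfactorial() (Ramanujan's approximation of log n!), added alongside the device support so the log-factorial correction can eventually be computed without leaving the device. A quick sanity check of the formula against lgamma(n + 1), assuming the torch package is available:

    library(torch)
    n_r <- c(1, 2, 5, 10, 100)
    n_t <- torch_tensor(n_r)
    approx <- n_t*torch_log(n_t) - n_t +
      torch_log(8*torch_pow(n_t, 3) + 4*torch_pow(n_t, 2) + n_t + 1/30)/6 + log(pi)/2
    max(abs(as.numeric(approx) - lgamma(n_r + 1)))  # about 3e-4 at n = 1, shrinking fast with n
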