Skip to content

Commit

Permalink
added fixed and sparse covariance for ZIPLN
Browse files Browse the repository at this point in the history
  • Loading branch information
jchiquet committed Jan 16, 2024
1 parent 718e747 commit 16bd6b2
Show file tree
Hide file tree
Showing 12 changed files with 291 additions and 112 deletions.
29 changes: 17 additions & 12 deletions R/ZIPLN.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
#' @rdname ZIPLN
#' @include ZIPLNfit-class.R
#' @examples
#' data(scRNA)
#' # data subsample: only 100 random cell and the 50 most varying transcript
#' scRNA <- scRNA[sample.int(nrow(scRNA), 100), ]
#' scRNA$counts <- scRNA$counts[, 1:50]
#' myPLN_full <- ZIPLN(counts ~ 1 + cell_line + offset(log(total_counts)), data = scRNA)
#' myPLN_sparse <- ZIPLN(counts ~ 1 + offset(log(total_counts)), rho = .5, data = scRNA)
#' data(trichoptera)
#' trichoptera <- prepare_data(trichoptera$Abundance, trichoptera$Covariate)
#' myPLN <- PLN(Abundance ~ 1, data = trichoptera)
#' myZIPLN_1 <- ZIPLN(Abundance ~ 1, data = trichoptera, zi = "single")
#' myZIPLN_2 <- ZIPLN(Abundance ~ 1, data = trichoptera, zi = "row")
#' myZIPLN_3 <- ZIPLN(Abundance ~ 1, data = trichoptera, zi = "col")
#' myZIPLN_3 <- ZIPLN(Abundance ~ 1, data = trichoptera, zi = "col")
#' myPLN_full$criteria # better BIC with sparse version
#' myPLN_sparse$criteria
#' @seealso The class [`ZIPLNfit`]
Expand All @@ -31,7 +32,6 @@ ZIPLN <- function(formula, data, subset, zi = c("single", "row", "col"), control
args <- extract_model_zi(match.call(expand.dots = FALSE), parent.frame())

## define default control parameters for optim and eventually overwrite them by user-defined parameters
control$lambda <- 0
control$rho <- 0
control$ziparam <- ifelse((args$zicovar), "covar", match.arg(zi))
control$penalize_intercept <- FALSE
Expand All @@ -46,7 +46,8 @@ ZIPLN <- function(formula, data, subset, zi = c("single", "row", "col"), control
myPLN <- switch(control$covariance,
"diagonal" = ZIPLNfit_diagonal$new(args$Y , list(PLN = args$X, ZI = args$X0), args$O, args$w, args$formula, control),
"spherical" = ZIPLNfit_spherical$new(args$Y, list(PLN = args$X, ZI = args$X0), args$O, args$w, args$formula, control),
"fixed" = ZIPLNfit_fixedcov$new(args$Y , list(PLN = args$X, ZI = args$X0), args$O, args$w, args$formula, control),
"fixed" = ZIPLNfit_fixed$new(args$Y , list(PLN = args$X, ZI = args$X0), args$O, args$w, args$formula, control),
"sparse" = ZIPLNfit_sparse$new(args$Y , list(PLN = args$X, ZI = args$X0), args$O, args$w, args$formula, control),
ZIPLNfit$new(args$Y, list(PLN = args$X, ZI = args$X0), args$O, args$w, args$formula, control)) # default: full covariance

## optimization
Expand All @@ -62,13 +63,12 @@ ZIPLN <- function(formula, data, subset, zi = c("single", "row", "col"), control
## -----------------------------------------------------------------
## Series of setter to default parameters for user's main functions

available_algorithms <- c("MMA", "CCSAQ", "LBFGS", "VAR1", "VAR2", "TNEWTON", "TNEWTON_PRECOND", "TNEWTON_PRECOND_RESTART")

#' Control of a PLN fit
#'
#' Helper to define list of parameters to control the PLN fit. All arguments have defaults.
#'
#' @inheritParams PLN_param
#' @param penalty a user defined penalty for sparsifying the residual covariance. Default is 0 (no sparsity).
#' @return list of parameters configuring the fit.
#'
#' @inherit PLN_param details
Expand All @@ -81,15 +81,19 @@ available_algorithms <- c("MMA", "CCSAQ", "LBFGS", "VAR1", "VAR2", "TNEWTON", "T
ZIPLN_param <- function(
backend = c("nlopt"),
trace = 1,
covariance = c("full", "diagonal", "spherical", "fixed"),
covariance = c("full", "diagonal", "spherical", "fixed", "sparse"),
Omega = NULL,
penalty = 0,
config_post = list(),
config_optim = list(),
inception = NULL # pretrained ZIPLNfit used as initialization
) {

covariance <- match.arg(covariance)
if (covariance == "fixed") stopifnot(inherits(Omega, "matrix") | inherits(Omega, "Matrix"))
if (covariance == "fixed") stopifnot("Omega must be provied for fixed covariance" = inherits(Omega, "matrix") | inherits(Omega, "Matrix")) |> try()
if (inherits(Omega, "matrix") | inherits(Omega, "Matrix")) covariance <- "fixed"
if (covariance == "sparse") stopifnot("You should provide a positive penalty when chosing 'sparse' covariance" = penalty > 0) |> try()
if (penalty > 0) covariance <- "sparse"
if (!is.null(inception)) stopifnot(isZIPLNfit(inception))

## post-treatment config
Expand All @@ -111,6 +115,7 @@ ZIPLN_param <- function(
trace = trace ,
covariance = covariance,
Omega = Omega ,
penalty = penalty ,
config_post = config_pst,
config_optim = config_opt,
inception = inception), class = "PLNmodels_param")
Expand Down
112 changes: 104 additions & 8 deletions R/ZIPLNfit-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#' @param responses the matrix of responses (called Y in the model). Will usually be extracted from the corresponding field in PLNfamily-class
#' @param covariates design matrix (called X in the model). Will usually be extracted from the corresponding field in PLNfamily-class
#' @param offsets offset matrix (called O in the model). Will usually be extracted from the corresponding field in PLNfamily-class
#' @param weights an optional vector of observation weights to be used in the fitting process.
#' @param formula model formula used for fitting, extracted from the formula in the upper-level call
#' @param control a list for controlling the optimization. See details.
#'
Expand Down Expand Up @@ -437,7 +438,7 @@ ZIPLNfit_diagonal <- R6Class(
#' \dontrun{
#' data(trichoptera)
#' trichoptera <- prepare_data(trichoptera$Abundance, trichoptera$Covariate)
#' myPLN <- ZIPLN(Abundance ~ 1, data = trichoptera, control = ZIPLN_param(covariance = "spherical))
#' myPLN <- ZIPLN(Abundance ~ 1, data = trichoptera, control = ZIPLN_param(covariance = "spherical"))
#' class(myPLN)
#' print(myPLN)
#' }
Expand Down Expand Up @@ -471,7 +472,7 @@ ZIPLNfit_spherical <- R6Class(
)

## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
## CLASS ZIPLNfit_fixedcov #############################
## CLASS ZIPLNfit_fixed #############################
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

#' An R6 Class to represent a ZIPLNfit in a standard, general framework, with fixed (inverse) residual covariance
Expand All @@ -486,19 +487,19 @@ ZIPLNfit_spherical <- R6Class(
#' @param control a list for controlling the optimization. See details.
#' @param config part of the \code{control} argument which configures the optimizer
#'
#' @rdname ZIPLNfit_fixedcov
#' @rdname ZIPLNfit_fixed
#' @importFrom R6 R6Class
#'
#' @examples
#' \dontrun{
#' data(trichoptera)
#' trichoptera <- prepare_data(trichoptera$Abundance, trichoptera$Covariate)
#' myPLN <- ZIPLN(Abundance ~ 1, data = trichoptera)
#' myPLN <- ZIPLN(Abundance ~ 1, data = trichoptera, contro = ZIPLN_param(Omega = diag(ncol(trichoptera$Abundance))))
#' class(myPLN)
#' print(myPLN)
#' }
ZIPLNfit_fixedcov <- R6Class(
classname = "ZIPLNfit_fixedcov",
ZIPLNfit_fixed <- R6Class(
classname = "ZIPLNfit_fixed",
inherit = ZIPLNfit,
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
## PUBLIC MEMBERS ----
Expand All @@ -508,7 +509,7 @@ ZIPLNfit_fixedcov <- R6Class(
initialize = function(responses, covariates, offsets, weights, formula, control) {
super$initialize(responses, covariates, offsets, weights, formula, control)
private$Omega <- control$Omega
### TODO handled fixed cov
private$optimizer$Omega <- function(M, X, B, S) {private$Omega}
}
),
active = list(
Expand All @@ -526,6 +527,101 @@ ZIPLNfit_fixedcov <- R6Class(
vcov_model = function() {"fixed"}
)
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
## END OF THE CLASS ZIPLNfit_fixedcov
## END OF THE CLASS ZIPLNfit_fixed
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
)

## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
## CLASS ZIPLNfit_sparse #############################
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

#' An R6 Class to represent a ZIPLNfit in a standard, general framework, with sparse inverse residual covariance
#'
#' @param responses the matrix of responses (called Y in the model). Will usually be extracted from the corresponding field in PLNfamily-class
#' @param covariates design matrix (called X in the model). Will usually be extracted from the corresponding field in PLNfamily-class
#' @param offsets offset matrix (called O in the model). Will usually be extracted from the corresponding field in PLNfamily-class
#' @param data an optional data frame, list or environment (or object coercible by as.data.frame to a data frame) containing the variables in the model. If not found in data, the variables are taken from environment(formula), typically the environment from which PLN is called.
#' @param weights an optional vector of observation weights to be used in the fitting process.
#' @param nullModel null model used for approximate R2 computations. Defaults to a GLM model with same design matrix but not latent variable.
#' @param formula model formula used for fitting, extracted from the formula in the upper-level call
#' @param control a list for controlling the optimization. See details.
#' @param config part of the \code{control} argument which configures the optimizer
#'
#' @rdname ZIPLNfit_fixedcov
#' @importFrom R6 R6Class
#'
#' @examples
#' \dontrun{
#' data(trichoptera)
#' trichoptera <- prepare_data(trichoptera$Abundance, trichoptera$Covariate)
#' myPLN <- ZIPLN(Abundance ~ 1, data = trichoptera, control= ZIPLN_param(penalty = 0.2))
#' class(myPLN)
#' print(myPLN)
#' }
ZIPLNfit_sparse <- R6Class(
classname = "ZIPLNfit_sparse",
inherit = ZIPLNfit,
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
## PUBLIC MEMBERS ----
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
public = list(
#' @description Initialize a [`ZIPLNfit_fixedcov`] model
#' @importFrom glassoFast glassoFast
initialize = function(responses, covariates, offsets, weights, formula, control) {
super$initialize(responses, covariates, offsets, weights, formula, control)
private$optimizer$Omega <-
function(M, X, B, S) {
glassoFast( crossprod(M - X %*% B)/self$n + diag(colMeans(S * S), self$p, self$p), rho = control$penalty )$wi
}
}
),
active = list(
#' @field nb_param number of parameters in the current PLN model
nb_param = function() {
res <- self$p * self$d + (sum(private$Omega != 0) - self$p)/2L +
switch(private$ziparam,
"single" = 1,
"row" = self$n,
"col" = self$p,
"covar" = self$p * self$d)
as.integer(res)
},
#' @field vcov_model character: the model used for the residual covariance
vcov_model = function() {"sparse"}
)
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
## END OF THE CLASS ZIPLNfit_sparse
## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
)


# Test convergence for a named list of parameters
# oldp, newp: named list of parameters
# xtol_rel: double ; negative or NULL = disabled
# xtol_abs: double ; negative or NULL = disabled
# Returns boolean
parameter_list_converged <- function(oldp, newp, xtol_abs = NULL, xtol_rel = NULL) {
# Strategy is to compare each pair of list elements with matching names.
stopifnot(is.list(oldp), is.list(newp))
oldp <- oldp[order(names(oldp))]
newp <- newp[order(names(newp))]
stopifnot(all(names(oldp) == names(newp)))

# Check convergence with xtol_rel if enabled
if(is.double(xtol_rel) && xtol_rel > 0) {
if(all(mapply(function(o, n) { all(abs(n - o) <= xtol_rel * abs(o)) }, oldp, newp))) {
return(TRUE)
}
}

# Check convergence with xtol_abs (homogeneous) if enabled
if(is.double(xtol_abs) && xtol_abs > 0) {
if(all(mapply(function(o, n) { all(abs(n - o) <= xtol_abs) }, oldp, newp))) {
return(TRUE)
}
}

# If no criteria has triggered, indicate no convergence
FALSE
}

68 changes: 8 additions & 60 deletions R/optim-zipln.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
optimize_zi <- function(init_parameters, Y, X, O, configuration) {

n <- nrow(Y); p <- ncol(Y); d <- ncol(X)

# Link to the approximate function to optimize Omega ,depending on the target structure
optim_zipln_Omega <- switch(
configuration$covariance,
Expand All @@ -36,7 +36,7 @@ optimize_zi <- function(init_parameters, Y, X, O, configuration) {
"col" = function(init_B0, X, R, config) list(Pi = matrix(colMeans(R), n, p, byrow = TRUE), B0 = matrix(NA, d, p)),
"covar" = optim_zipln_zipar_covar
)

maxit_out <- if("maxit_out" %in% names(configuration)) { configuration$maxit_out } else { 50 }

# Main loop
Expand All @@ -63,7 +63,7 @@ optimize_zi <- function(init_parameters, Y, X, O, configuration) {
new_B <- optim_zipln_B(
M = parameters$M, X = X, Omega = new_Omega, configuration
)

optim_new_zipar <- optim_zipln_zipar(
init_B0 = parameters$B0, X = X, R = parameters$R, config = configuration
)
Expand Down Expand Up @@ -125,11 +125,11 @@ optimize_zi <- function(init_parameters, Y, X, O, configuration) {
}
}

#' @importFrom glassoFast glassoFast
optim_zipln_Omega_sparse <- function(M, X, B, S, rho) {
n <- nrow(M); p <- ncol(M)
glassoFast::glassoFast( crossprod(M - X %*% B)/n + diag(colMeans(S * S), p, p), rho = rho )$wi
}
#' #' @importFrom glassoFast glassoFast
#' optim_zipln_Omega_sparse <- function(M, X, B, S, rho) {
#' n <- nrow(M); p <- ncol(M)
#' glassoFast::glassoFast( crossprod(M - X %*% B)/n + diag(colMeans(S * S), p, p), rho = rho )$wi
#' }

#' @importFrom glmnet glmnet
optim_zipln_B <- function(M, X, Omega, config) {
Expand Down Expand Up @@ -163,55 +163,3 @@ optim_zipln_B <- function(M, X, Omega, config) {
B
}

# Test convergence for a named list of parameters
# oldp, newp: named list of parameters
# xtol_rel: double ; negative or NULL = disabled
# xtol_abs: double ; negative or NULL = disabled
# Returns boolean
parameter_list_converged <- function(oldp, newp, xtol_abs = NULL, xtol_rel = NULL) {
# Strategy is to compare each pair of list elements with matching names.
# Named lists are just vectors (T,str) using order of insertion.
# mapply() is handy to do the pair tests, but it works on the underlying vector order (ignoring names).
# So reorder lists by their names to use mapply.
stopifnot(is.list(oldp), is.list(newp))
oldp <- oldp[order(names(oldp))]
newp <- newp[order(names(newp))]
stopifnot(all(names(oldp) == names(newp)))

# Check convergence with xtol_rel if enabled
if(is.double(xtol_rel) && xtol_rel > 0) {
if(all(mapply(function(o, n) { all(abs(n - o) <= xtol_rel * abs(o)) }, oldp, newp))) {
return(TRUE)
}
}

# Check convergence with xtol_abs (homogeneous) if enabled
if(is.double(xtol_abs) && xtol_abs > 0) {
if(all(mapply(function(o, n) { all(abs(n - o) <= xtol_abs) }, oldp, newp))) {
return(TRUE)
}
}

# Check convergence with xtol_abs as list(xtol_abs for each param_name)
if(is.list(xtol_abs)) {
xtol_abs <- xtol_abs[order(names(xtol_abs))]
stopifnot(all(names(oldp) == names(xtol_abs)))
# Due to the possible presence of NULLs, mapply may return a list. unlist allows all() to operate anyway.
if(all(unlist(mapply(
function(o, n, tol) {
if((is.double(tol) && tol > 0) || is.matrix(tol)) {
all(abs(n - o) <= tol)
} else {
NULL # Ignore comparison in outer all()
}
},
oldp, newp, xtol_abs
)))) {
return(TRUE)
}
}

# If no criteria has triggered, indicate no convergence
FALSE
}

2 changes: 1 addition & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
available_algorithms_nlopt <- c("MMA", "CCSAQ", "LBFGS", "LBFGS_NOCEDAL", "VAR1", "VAR2")
available_algorithms_nlopt <- c("MMA", "CCSAQ", "LBFGS", "LBFGS_NOCEDAL", "VAR1", "VAR2") #"TNEWTON", "TNEWTON_PRECOND", "TNEWTON_PRECOND_RESTART"#
available_algorithms_torch <- c("RPROP", "RMSPROP", "ADAM", "ADAGRAD")

config_default_nlopt <-
Expand Down
13 changes: 7 additions & 6 deletions man/ZIPLN.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/ZIPLN_param.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 16bd6b2

Please sign in to comment.