Skip to content

Commit

Permalink
Refactor nb_param for ZIPLNfit
Browse files Browse the repository at this point in the history
  • Loading branch information
mahendra-mariadassou committed Feb 20, 2024
1 parent bbc5737 commit c8eb7c3
Showing 1 changed file with 26 additions and 45 deletions.
71 changes: 26 additions & 45 deletions R/ZIPLNfit-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -501,16 +501,21 @@ ZIPLNfit <- R6Class(
d = function() {nrow(private$B)},
#' @field d0 number of covariates in the ZI part
d0 = function() {nrow(private$B0)},
#' @field nb_param number of parameters in the current PLN model
#' @field nb_param_zi number of parameters in the ZI part of the model
nb_param_zi = function() {
as.integer(switch(private$ziparam,
"single" = 1L,
"row" = self$n,
"col" = self$p,
"covar" = self$p * self$d))
},
#' @field nb_param_pln number of parameters in the PLN part of the model
nb_param_pln = function() {
as.integer(self$p * self$d + self$p * (self$p + 1L) / 2L)
},
#' @field nb_param number of parameters in the ZIPLN model
nb_param = function() {
as.integer(
self$p * self$d + self$p * (self$p + 1L)/2L +
switch(private$ziparam,
"single" = 1L,
"row" = self$n,
"col" = self$p,
"covar" = self$p * self$d)
)
self$nb_param_zi + self$nb_param_pln
},
#' @field model_par a list with the matrices of parameters found in the model (B, Sigma, plus some others depending on the variant)
model_par = function() {list(B = private$B, B0 = private$B0, Pi = private$Pi, Omega = private$Omega, Sigma = private$Sigma)},
Expand Down Expand Up @@ -583,15 +588,9 @@ ZIPLNfit_diagonal <- R6Class(
}
),
active = list(
#' @field nb_param number of parameters in the current PLN model
nb_param = function() {
res <- self$p * self$d + self$p +
switch(private$ziparam,
"single" = 1L,
"row" = self$n,
"col" = self$p,
"covar" = self$p * self$d)
as.integer(res)
#' @field nb_param_pln number of parameters in the PLN part of the current model
nb_param_pln = function() {
as.integer(self$p * self$d + self$p)
},
#' @field vcov_model character: the model used for the residual covariance
vcov_model = function() {"diagonal"}
Expand Down Expand Up @@ -632,15 +631,9 @@ ZIPLNfit_spherical <- R6Class(
}
),
active = list(
#' @field nb_param number of parameters in the current PLN model
nb_param = function() {
res <- self$p * self$d + 1L +
switch(private$ziparam,
"single" = 1L,
"row" = self$n,
"col" = self$p,
"covar" = self$p * self$d)
as.integer(res)
#' @field nb_param_pln number of parameters in the PLN part of the current model
nb_param_pln = function() {
as.integer(self$p * self$d + 1L)
},
#' @field vcov_model character: the model used for the residual covariance
vcov_model = function() {"spherical"}
Expand Down Expand Up @@ -686,15 +679,9 @@ ZIPLNfit_fixed <- R6Class(
}
),
active = list(
#' @field nb_param number of parameters in the current PLN model
nb_param = function() {
res <- self$p * self$d +
switch(private$ziparam,
"single" = 1L,
"row" = self$n,
"col" = self$p,
"covar" = self$p * self$d)
as.integer(res)
#' @field nb_param_pln number of parameters in the PLN part of the current model
nb_param_pln = function() {
as.integer(self$p * self$d + 0L)
},
#' @field vcov_model character: the model used for the residual covariance
vcov_model = function() {"fixed"}
Expand Down Expand Up @@ -813,15 +800,9 @@ ZIPLNfit_sparse <- R6Class(
penalty_weights = function() {private$rho},
#' @field n_edges number of edges if the network (non null coefficient of the sparse precision matrix)
n_edges = function() {sum(private$Omega[upper.tri(private$Omega, diag = FALSE)] != 0)},
#' @field nb_param number of parameters in the current PLN model
nb_param = function() {
res <- self$p * self$d + self$n_edges +
switch(private$ziparam,
"single" = 1L,
"row" = self$n,
"col" = self$p,
"covar" = self$p * self$d)
as.integer(res)
#' @field nb_param_pln number of parameters in the PLN part of the current model
nb_param_pln = function() {
as.integer(self$p * self$d + self$n_edges)
},
#' @field vcov_model character: the model used for the residual covariance
vcov_model = function() {"sparse"},
Expand Down

0 comments on commit c8eb7c3

Please sign in to comment.