Skip to content

Commit

Permalink
added exact form and optimization for W|Y (with tests)
Browse files Browse the repository at this point in the history
  • Loading branch information
jchiquet committed Jan 24, 2024
1 parent a6656e6 commit 5aea835
Show file tree
Hide file tree
Showing 13 changed files with 98 additions and 95 deletions.
12 changes: 6 additions & 6 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,12 @@ optim_zipln_zipar_covar <- function(R, init_B0, X0, configuration) {
.Call('_PLNmodels_optim_zipln_zipar_covar', PACKAGE = 'PLNmodels', R, init_B0, X0, configuration)
}

optim_zipln_R <- function(Y, X, O, M, S, Pi) {
.Call('_PLNmodels_optim_zipln_R', PACKAGE = 'PLNmodels', Y, X, O, M, S, Pi)
optim_zipln_R_var <- function(Y, X, O, M, S, Pi, B) {
.Call('_PLNmodels_optim_zipln_R_var', PACKAGE = 'PLNmodels', Y, X, O, M, S, Pi, B)
}

optim_zipln_R_exact <- function(Y, X, O, M, S, Pi, B) {
.Call('_PLNmodels_optim_zipln_R_exact', PACKAGE = 'PLNmodels', Y, X, O, M, S, Pi, B)
}

optim_zipln_M <- function(init_M, Y, X, O, R, S, B, Omega, configuration) {
Expand All @@ -85,7 +89,3 @@ cpp_test_packing <- function() {
.Call('_PLNmodels_cpp_test_packing', PACKAGE = 'PLNmodels')
}

# Register entry points for exported C++ functions
methods::setLoadAction(function(ns) {
.Call('_PLNmodels_RcppExport_registerCCallable', PACKAGE = 'PLNmodels')
})
7 changes: 5 additions & 2 deletions R/ZIPLN.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,12 @@ ZIPLN <- function(formula, data, subset, zi = c("single", "row", "col"), control
#' @return list of parameters used during the fit and post-processing steps
#'
#' @inherit PLN_param details
#' @details See [PLN_param()] for a full description of the generic optimization parameters. ZIPLN_param() also has two additional parameters controlling the optimization due
#' the inner-outer loop structure of the optimizer:
#' @details See [PLN_param()] for a full description of the generic optimization parameters. ZIPLN_param() also
#' has two additional parameters controlling the optimization due the inner-outer loop structure of the optimizer,
#' and additional parameter controlling the form of the variational approximation of the zero inflation:
#' * "ftol_out" outer solver stops when an optimization step changes the objective function by less than `ftol_out` multiplied by the absolute value of the parameter. Default is 1e-8
#' * "maxit_out" outer solver stops when the number of iteration exceeds `maxit_out`. Default is 100
#' * "approx_ZI" either use an exact or approximated conditional distribution for the zero inflantion. Default is FALSE
#'
#' @export
ZIPLN_param <- function(
Expand Down Expand Up @@ -113,6 +115,7 @@ ZIPLN_param <- function(
config_opt$trace <- trace
config_opt$ftol_out <- 1e-6
config_opt$maxit_out <- 100
config_opt$approx_ZI <- FALSE
config_opt[names(config_optim)] <- config_optim

structure(list(
Expand Down
10 changes: 5 additions & 5 deletions R/ZIPLNfit-class.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ ZIPLNfit <- R6Class(
"col" = function(R, ...) list(Pi = matrix(colMeans(R), nrow(R), p, byrow = TRUE), B0 = matrix(NA, d0, p)),
"covar" = optim_zipln_zipar_covar
)
private$optimizer$R <- ifelse(control$config_optim$approx_ZI, optim_zipln_R_var, optim_zipln_R_exact)
private$optimizer$Omega <- optim_zipln_Omega_full

},
Expand Down Expand Up @@ -179,9 +180,8 @@ ZIPLNfit <- R6Class(

### VE Step
# ZI part
new_R <- optim_zipln_R(
Y = data$Y, X = data$X, O = data$O, M = parameters$M, S = parameters$S, Pi = new_Pi
)
new_R <- private$optimizer$R(Y = data$Y, X = data$X, O = data$O, M = parameters$M, S = parameters$S, Pi = new_Pi, B = new_B)

# PLN part
new_M <- optim_zipln_M(
init_M = parameters$M,
Expand Down Expand Up @@ -300,8 +300,8 @@ ZIPLNfit <- R6Class(
)$Pi

# VE Step
new_R <- optim_zipln_R(
Y = data$Y, X = data$X, O = data$O, M = parameters$M, S = parameters$S, Pi = Pi
new_R <- private$optimizer$R(
Y = data$Y, X = data$X, O = data$O, M = parameters$M, S = parameters$S, Pi = Pi, B = B
)
new_M <- optim_zipln_M(
init_M = parameters$M,
Expand Down
2 changes: 1 addition & 1 deletion inst/case_studies/scRNA.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ data(scRNA)
# data subsample: only 500 random cell and the 200 most varying transcript
scRNA <- scRNA[sample.int(nrow(scRNA), 500), ]
scRNA$counts <- scRNA$counts[, 1:200]
myZIPLN <- ZIPLN(counts ~ 1 + offset(log(total_counts)), data = scRNA)
myZIPLN <- ZIPLN(counts ~ 1 + offset(log(total_counts)), zi = "col", data = scRNA)
myPLN <- PLN(counts ~ 1 + offset(log(total_counts)), data = scRNA)

data.frame(
Expand Down
9 changes: 0 additions & 9 deletions inst/include/PLNmodels.h

This file was deleted.

30 changes: 0 additions & 30 deletions inst/include/PLNmodels_RcppExports.h

This file was deleted.

4 changes: 2 additions & 2 deletions inst/simus_ZIPLN/essai_ZIPLN.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ p <- ggplot(res) + aes(x = factor(n), y = pred_Y, fill = factor(method)) + geom_
scale_y_log10() + ylim(c(0,2))
p

p <- ggplot(res) + aes(x = factor(n), y = rmse_B, fill = factor(method)) + geom_violin() + theme_bw() + scale_y_log10() + ylim(c(2.75,3))
p <- ggplot(res) + aes(x = factor(n), y = rmse_B, fill = factor(method)) + geom_violin() + theme_bw() + scale_y_log10() + ylim(c(2,5))
p

p <- ggplot(res) + aes(x = factor(n), y = rmse_Omega, fill = factor(method)) + geom_violin() + theme_bw() + scale_y_log10() + ylim(c(0,0.5))
p <- ggplot(res) + aes(x = factor(n), y = rmse_Omega, fill = factor(method)) + geom_violin() + theme_bw() + scale_y_log10() + ylim(c(0.1,.3))
p
6 changes: 4 additions & 2 deletions man/ZIPLN_param.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/ZIPLNfit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

47 changes: 24 additions & 23 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
// Generated by using Rcpp::compileAttributes() -> do not edit by hand
// Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

#include "../inst/include/PLNmodels.h"
#include <RcppArmadillo.h>
#include <Rcpp.h>
#include <string>
#include <set>

using namespace Rcpp;

Expand Down Expand Up @@ -253,9 +250,9 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// optim_zipln_R
arma::mat optim_zipln_R(const arma::mat& Y, const arma::mat& X, const arma::mat& O, const arma::mat& M, const arma::mat& S, const arma::mat& Pi);
RcppExport SEXP _PLNmodels_optim_zipln_R(SEXP YSEXP, SEXP XSEXP, SEXP OSEXP, SEXP MSEXP, SEXP SSEXP, SEXP PiSEXP) {
// optim_zipln_R_var
arma::mat optim_zipln_R_var(const arma::mat& Y, const arma::mat& X, const arma::mat& O, const arma::mat& M, const arma::mat& S, const arma::mat& Pi, const arma::mat& B);
RcppExport SEXP _PLNmodels_optim_zipln_R_var(SEXP YSEXP, SEXP XSEXP, SEXP OSEXP, SEXP MSEXP, SEXP SSEXP, SEXP PiSEXP, SEXP BSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Expand All @@ -265,7 +262,25 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< const arma::mat& >::type M(MSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type S(SSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type Pi(PiSEXP);
rcpp_result_gen = Rcpp::wrap(optim_zipln_R(Y, X, O, M, S, Pi));
Rcpp::traits::input_parameter< const arma::mat& >::type B(BSEXP);
rcpp_result_gen = Rcpp::wrap(optim_zipln_R_var(Y, X, O, M, S, Pi, B));
return rcpp_result_gen;
END_RCPP
}
// optim_zipln_R_exact
arma::mat optim_zipln_R_exact(const arma::mat& Y, const arma::mat& X, const arma::mat& O, const arma::mat& M, const arma::mat& S, const arma::mat& Pi, const arma::mat& B);
RcppExport SEXP _PLNmodels_optim_zipln_R_exact(SEXP YSEXP, SEXP XSEXP, SEXP OSEXP, SEXP MSEXP, SEXP SSEXP, SEXP PiSEXP, SEXP BSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const arma::mat& >::type Y(YSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type X(XSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type O(OSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type M(MSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type S(SSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type Pi(PiSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type B(BSEXP);
rcpp_result_gen = Rcpp::wrap(optim_zipln_R_exact(Y, X, O, M, S, Pi, B));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -316,20 +331,6 @@ BEGIN_RCPP
END_RCPP
}

// validate (ensure exported C++ functions exist before calling them)
static int _PLNmodels_RcppExport_validate(const char* sig) {
static std::set<std::string> signatures;
if (signatures.empty()) {
}
return signatures.find(sig) != signatures.end();
}

// registerCCallable (register entry points for exported C++ functions)
RcppExport SEXP _PLNmodels_RcppExport_registerCCallable() {
R_RegisterCCallable("PLNmodels", "_PLNmodels_RcppExport_validate", (DL_FUNC)_PLNmodels_RcppExport_validate);
return R_NilValue;
}

static const R_CallMethodDef CallEntries[] = {
{"_PLNmodels_cpp_test_nlopt", (DL_FUNC) &_PLNmodels_cpp_test_nlopt, 0},
{"_PLNmodels_nlopt_optimize_diagonal", (DL_FUNC) &_PLNmodels_nlopt_optimize_diagonal, 3},
Expand All @@ -348,11 +349,11 @@ static const R_CallMethodDef CallEntries[] = {
{"_PLNmodels_optim_zipln_Omega_diagonal", (DL_FUNC) &_PLNmodels_optim_zipln_Omega_diagonal, 4},
{"_PLNmodels_optim_zipln_B_dense", (DL_FUNC) &_PLNmodels_optim_zipln_B_dense, 2},
{"_PLNmodels_optim_zipln_zipar_covar", (DL_FUNC) &_PLNmodels_optim_zipln_zipar_covar, 4},
{"_PLNmodels_optim_zipln_R", (DL_FUNC) &_PLNmodels_optim_zipln_R, 6},
{"_PLNmodels_optim_zipln_R_var", (DL_FUNC) &_PLNmodels_optim_zipln_R_var, 7},
{"_PLNmodels_optim_zipln_R_exact", (DL_FUNC) &_PLNmodels_optim_zipln_R_exact, 7},
{"_PLNmodels_optim_zipln_M", (DL_FUNC) &_PLNmodels_optim_zipln_M, 9},
{"_PLNmodels_optim_zipln_S", (DL_FUNC) &_PLNmodels_optim_zipln_S, 7},
{"_PLNmodels_cpp_test_packing", (DL_FUNC) &_PLNmodels_cpp_test_packing, 0},
{"_PLNmodels_RcppExport_registerCCallable", (DL_FUNC) &_PLNmodels_RcppExport_registerCCallable, 0},
{NULL, NULL, 0}
};

Expand Down
12 changes: 1 addition & 11 deletions src/lambertW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,7 @@ Fritsch, F. N.; Shafer, R. E. & Crowley, W. P.
1973, 16, 123-124
*/

// [[Rcpp::depends(RcppParallel)]]
// [[Rcpp::interfaces(r, cpp)]]
#include <Rcpp.h>

#define _USE_MATH_DEFINES
#include <cmath>

using namespace Rcpp;

const double EPS = 2.2204460492503131e-16;
const double M_1_E = 1.0 / M_E;
#include "lambertW.h"

/* Fritsch Iteration
* W_{n+1} = W_n * (1 + e_n)
Expand Down
39 changes: 37 additions & 2 deletions src/optim_zi-pln.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "nlopt_wrapper.h"
#include "packing.h"
#include "utils.h"
#include "lambertW.h"

// [[Rcpp::export]]
arma::vec zipln_vloglik(
Expand Down Expand Up @@ -118,13 +119,14 @@ Rcpp::List optim_zipln_zipar_covar(
}

// [[Rcpp::export]]
arma::mat optim_zipln_R(
arma::mat optim_zipln_R_var(
const arma::mat & Y, // responses (n,p)
const arma::mat & X, // covariates (n,d)
const arma::mat & O, // offsets (n,p)
const arma::mat & M, // (n,p)
const arma::mat & S, // (n,p)
const arma::mat & Pi // (d,p)
const arma::mat & Pi, // (d,p)
const arma::mat & B // covariates (n,d)
) {
arma::mat A = exp(O + M + 0.5 * S % S);
arma::mat R = pow(1. + exp(- (A + logit(Pi))), -1);
Expand All @@ -144,6 +146,39 @@ arma::mat optim_zipln_R(
return R;
}

double phi (double mu, double sigma2) {
double W = lambertW0_CS(sigma2 * exp(mu)) ;
return(exp(-(pow(W, 2) + 2 * W) / (2 * sigma2)) / sqrt(1 + W)) ;
}

// [[Rcpp::export]]
arma::mat optim_zipln_R_exact (
const arma::mat & Y, // covariates (n,d)
const arma::mat & X, // covariates (n,d)
const arma::mat & O, // offsets (n,p)
const arma::mat & M, // (n,p)
const arma::mat & S, // (n,p)
const arma::mat & Pi, // (n,p)
const arma::mat & B // covariates (n,d)
) {

arma::mat XB = X * B;
arma::mat M_mu = M - XB;
arma::uword n = M.n_rows;
arma::uword p = M.n_cols;
arma::vec diag_Sigma = arma::diagvec((1./n) * (M_mu.t() * M_mu + diagmat(sum(S % S, 0)))) ;
arma::mat R = arma::zeros(n,p);
for(arma::uword i = 0; i < n; i += 1) {
for(arma::uword j = 0; j < p; j += 1) {
if(Y(i, j) < 0.5) {
double Phi = phi(O(i,j) + XB(i,j), diag_Sigma(j)) ;
R(i,j) = Pi(i,j) / (Phi * (1 - Pi(i,j)) + Pi(i,j)) ;
}
}
}
return R;
}

// [[Rcpp::export]]
Rcpp::List optim_zipln_M(
const arma::mat & init_M, // (n,p)
Expand Down
11 changes: 11 additions & 0 deletions tests/testthat/test-zipln.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ test_that("PLN is working with unnamed data matrix", {
expect_error(ZIPLN(Abundance ~ 1, data = trichoptera, control = ZIPLN_param(config_optim = list(algorithm = "nawak"))))
})

test_that("ZIPLN is working with exact and variational inference for the conditional distribution of the ZI component", {

approx <- ZIPLN(Abundance ~ 1, data = trichoptera, control = ZIPLN_param(config_optim = list(approx_ZI = TRUE)))
exact <- ZIPLN(Abundance ~ 1, data = trichoptera, control = ZIPLN_param(config_optim = list(approx_ZI = FALSE)))

expect_equal(approx$loglik, exact$loglik, tolerance = 1e-1) ## Almost equivalent
expect_equal(approx$model_par$B, exact$model_par$B, tolerance = 1e-1) ## Almost equivalent
expect_equal(approx$model_par$Sigma, exact$model_par$Sigma, tolerance = 1e-1) ## Almost equivalent

})

test_that("ZIPLN: Check that univariate ZIPLN models works, with matrix of numeric format", {
expect_no_error(uniZIPLN <- ZIPLN(Abundance[,1,drop=FALSE] ~ 1, data = trichoptera))
expect_no_error(uniZIPLN <- ZIPLN(Abundance[,1] ~ 1, data = trichoptera))
Expand Down

0 comments on commit 5aea835

Please sign in to comment.