From f4800b0714c5085064115e3adcb83ed78fff9562 Mon Sep 17 00:00:00 2001 From: Sabine Dritz Date: Thu, 21 Apr 2022 12:20:58 -0700 Subject: [PATCH 1/2] trying to implement hierarchical priors --- model_pystan3_hierarchical_priors.stan | 65 ++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 model_pystan3_hierarchical_priors.stan diff --git a/model_pystan3_hierarchical_priors.stan b/model_pystan3_hierarchical_priors.stan new file mode 100644 index 0000000..f616cdc --- /dev/null +++ b/model_pystan3_hierarchical_priors.stan @@ -0,0 +1,65 @@ +data { + // Dimensions of the data matrix, and matrix itself. + int n_p; + int n_a; + array[n_p, n_a] int M; +} +transformed data { + // Pre-compute the marginals of M to save computation in the model loop. + array[n_p] int M_rows = rep_array(0, n_p); + array[n_a] int M_cols = rep_array(0, n_a); + int M_tot = 0; + for (i in 1:n_p) { + for (j in 1:n_a) { + M_rows[i] += M[i, j]; + M_cols[j] += M[i, j]; + M_tot += M[i, j]; + } + } +} +parameters { + real C; + real r; + simplex[n_p] sigma; + simplex[n_a] tau; + real rho; +} +model { + // Prior + r ~ exponential(0.01); + + // Global sums and parameters + target += M_tot * log(C) - C; + // Weighted marginals of the data matrix + for (i in 1:n_p) { + target += M_rows[i] * log(sigma[i]); + } + for (j in 1:n_a) { + target += M_cols[j] * log(tau[j]); + } + // Pairwise loop + for (i in 1:n_p) { + for (j in 1:n_a) { + real nu_ij_0 = log(1 - rho); + real nu_ij_1 = log(rho) + M[i,j] * log(1 + r) - C * r * sigma[i] * tau[j]; + if (nu_ij_0 > nu_ij_1) + target += nu_ij_0 + log1p_exp(nu_ij_1 - nu_ij_0); + else + target += nu_ij_1 + log1p_exp(nu_ij_0 - nu_ij_1); + } + } +} +generated quantities { + // Posterior edge probability matrix + array[n_p, n_a] real Q; + for (i in 1:n_p) { + for (j in 1:n_a) { + real nu_ij_0 = log(1 - rho); + real nu_ij_1 = log(rho) + M[i,j] * log(1 + r) - C * r * sigma[i] * tau[j]; + if (nu_ij_1 > 0) + Q[i, j] = 1 / (1+ exp(nu_ij_0 - nu_ij_1)); + else + Q[i, j] = exp(nu_ij_1) / (exp(nu_ij_0) + exp(nu_ij_1)); + } + } +} From 5f22ec95703326e9230ea627546cb13465cb1812 Mon Sep 17 00:00:00 2001 From: Sabine Dritz <70649535+sjdritz@users.noreply.github.com> Date: Thu, 21 Apr 2022 12:29:08 -0700 Subject: [PATCH 2/2] Update model_pystan3_hierarchical_priors.stan --- model_pystan3_hierarchical_priors.stan | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/model_pystan3_hierarchical_priors.stan b/model_pystan3_hierarchical_priors.stan index f616cdc..f81789e 100644 --- a/model_pystan3_hierarchical_priors.stan +++ b/model_pystan3_hierarchical_priors.stan @@ -20,28 +20,38 @@ transformed data { parameters { real C; real r; - simplex[n_p] sigma; - simplex[n_a] tau; + real mu_alpha_plants; + real mu_alpha_pols; + real sigma_alpha_plants; + real sigma_alpha_pols; + vector[n_p] alpha_plants; + vector[n_a] alpha_pols; + simplex[n_p] plant_abundances; + simplex[n_a] pol_abundances; real rho; } model { // Prior r ~ exponential(0.01); + alpha_plants ~ lognormal(mu_alpha_plants, sigma_alpha_plants); + alpha_pols ~ lognormal(mu_alpha_pols, sigma_alpha_pols); + plant_abundances ~ dirichlet(alpha_plants); + pol_abundances ~ dirichlet(alpha_pols); // Global sums and parameters target += M_tot * log(C) - C; // Weighted marginals of the data matrix for (i in 1:n_p) { - target += M_rows[i] * log(sigma[i]); + target += M_rows[i] * log(plant_abundances[i]); } for (j in 1:n_a) { - target += M_cols[j] * log(tau[j]); + target += M_cols[j] * log(pol_abundances[j]); } // Pairwise loop for (i in 1:n_p) { for (j in 1:n_a) { real nu_ij_0 = log(1 - rho); - real nu_ij_1 = log(rho) + M[i,j] * log(1 + r) - C * r * sigma[i] * tau[j]; + real nu_ij_1 = log(rho) + M[i,j] * log(1 + r) - C * r * plant_abundances[i] * pol_abundances[j]; if (nu_ij_0 > nu_ij_1) target += nu_ij_0 + log1p_exp(nu_ij_1 - nu_ij_0); else @@ -55,7 +65,7 @@ generated quantities { for (i in 1:n_p) { for (j in 1:n_a) { real nu_ij_0 = log(1 - rho); - real nu_ij_1 = log(rho) + M[i,j] * log(1 + r) - C * r * sigma[i] * tau[j]; + real nu_ij_1 = log(rho) + M[i,j] * log(1 + r) - C * r * plant_abundances[i] * pol_abundances[j]; if (nu_ij_1 > 0) Q[i, j] = 1 / (1+ exp(nu_ij_0 - nu_ij_1)); else