Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add batching of evidence estimation inputs and update examples #307

Merged
merged 15 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 63 additions & 79 deletions examples/gaussian_nondiagcov.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax
import jax.numpy as jnp
import emcee
import logging


def ln_analytic_evidence(ndim, cov):
Expand Down Expand Up @@ -99,29 +100,33 @@ def run_example(
inv_cov = jnp.linalg.inv(cov)
training_proportion = 0.5
if flow_type == "RealNVP":
epochs_num = 5
epochs_num = 10 #5
elif flow_type == "RQSpline":
epochs_num = 3
#epochs_num = 5
epochs_num = 110

# Beginning of path where plots will be saved
save_name_start = "examples/plots/" + flow_type

temperature = 0.8
temperature = 0.9
standardize = True
verbose = True

# Spline params
n_layers = 5
n_bins = 5
n_layers = 3
n_bins = 128
hidden_size = [32, 32]
spline_range = (-10.0, 10.0)

if flow_type == "RQSpline":
save_name_start += "_" + str(n_layers) + "l_" + str(n_bins) + "b_" + str(epochs_num) + "e_" + str(int(training_proportion * 100)) + "perc_" + str(temperature) + "T" + "_emcee"

# Start timer.
clock = time.process_time()

# Run multiple realisations.
n_realisations = 1
evidence_inv_summary = np.zeros((n_realisations, 3))
n_realisations = 100
ln_evidence_inv_summary = np.zeros((n_realisations, 5))
for i_realisation in range(n_realisations):
if n_realisations > 0:
hm.logs.info_log(
Expand All @@ -130,7 +135,7 @@ def run_example(
# Define the number of dimensions and the mean of the Gaussian
num_samples = nchains * samples_per_chain
# Initialize a PRNG key (you can use any valid key)
key = jax.random.PRNGKey(0)
key = jax.random.PRNGKey(i_realisation)
mean = jnp.zeros(ndim)

# Generate random samples from the 2D Gaussian distribution
Expand All @@ -139,7 +144,7 @@ def run_example(
samples = jnp.reshape(samples, (nchains, -1, ndim))
lnprob = jnp.reshape(lnprob, (nchains, -1))

MCMC = False
MCMC = True
if MCMC:
nburn = 500
# Set up and run sampler.
Expand All @@ -151,7 +156,7 @@ def run_example(
rstate = np.random.get_state() # Set random state to repeatable
# across calls.
(pos, prob, state) = sampler.run_mcmc(
pos, samples_per_chain, rstate0=rstate
pos, samples_per_chain, rstate0=rstate, progress=True
)
samples = np.ascontiguousarray(sampler.chain[:, nburn:, :])
lnprob = np.ascontiguousarray(sampler.lnprobability[:, nburn:])
Expand Down Expand Up @@ -191,92 +196,68 @@ def run_example(
ev = hm.Evidence(chains_test.nchains, model)
# ev.set_mean_shift(0.0)
ev.add_chains(chains_test)
ln_evidence, ln_evidence_std = ev.compute_ln_evidence()
err_ln_inv_evidence = ev.compute_ln_inv_evidence_errors()

# Compute analytic evidence.
if i_realisation == 0:
ln_evidence_analytic = ln_analytic_evidence(ndim, cov)

# ======================================================================
# Display evidence computation results.
# ======================================================================
hm.logs.info_log("---------------------------------")
hm.logs.info_log("The inverse evidence in log space is:")
hm.logs.info_log(
"Evidence: analytic = {}, estimated = {}".format(
np.exp(ln_evidence_analytic), np.exp(ln_evidence)
"ln_inv_evidence = {} +/- {}".format(
ev.ln_evidence_inv, err_ln_inv_evidence
)
)
hm.logs.info_log(
"Evidence: std = {}, std / estimate = {}".format(
np.exp(ln_evidence_std), np.exp(ln_evidence_std - ln_evidence)
"ln evidence = {} +/- {} {}".format(
-ev.ln_evidence_inv, -err_ln_inv_evidence[1], -err_ln_inv_evidence[0]
)
)
diff = np.log(np.abs(np.exp(ln_evidence_analytic) - np.exp(ln_evidence)))
hm.logs.info_log("Analytic ln evidence is {}".format(ln_evidence_analytic))
delta = -ln_evidence_analytic - ev.ln_evidence_inv
hm.logs.info_log(
"Evidence: |analytic - estimate| / estimate = {}".format(
np.exp(diff - ln_evidence)
)
)
# ======================================================================
# Display inverse evidence computation results.
# ======================================================================
hm.logs.debug_log("---------------------------------")
hm.logs.debug_log(
"Inv Evidence: analytic = {}, estimate = {}".format(
np.exp(-ln_evidence_analytic), ev.evidence_inv
)
)
hm.logs.debug_log(
"Inv Evidence: std = {}, std / estimate = {}".format(
np.sqrt(ev.evidence_inv_var),
np.sqrt(ev.evidence_inv_var) / ev.evidence_inv,
)
)
hm.logs.debug_log(
"Inv Evidence: kurtosis = {}, sqrt( 2 / ( n_eff - 1 ) ) = {}".format(
ev.kurtosis, np.sqrt(2.0 / (ev.n_eff - 1))
)
)
hm.logs.debug_log(
"Inv Evidence: sqrt( var(var) ) / var = {}".format(
np.sqrt(ev.evidence_inv_var_var) / ev.evidence_inv_var
"Difference between analytic and harmonic is {} +- {} {}".format(
delta, err_ln_inv_evidence[0], err_ln_inv_evidence[1]
)
)

hm.logs.info_log("kurtosis = {}".format(ev.kurtosis))
hm.logs.info_log(" Aim for ~3.")
check = np.exp(0.5 * ev.ln_evidence_inv_var_var - ev.ln_evidence_inv_var)
hm.logs.info_log("sqrt( var(var) ) / var = {}".format(check))
hm.logs.info_log(
"Inv Evidence: |analytic - estimate| / estimate = {}".format(
np.abs(np.exp(-ln_evidence_analytic) - ev.evidence_inv)
/ ev.evidence_inv
)
" Aim for sqrt( 2/(n_eff-1) ) = {}".format(np.sqrt(2.0 / (ev.n_eff - 1)))
)

# ===========================================================================
# Display more technical details
# ===========================================================================
hm.logs.debug_log("---------------------------------")
hm.logs.debug_log("Technical Details")
hm.logs.debug_log("---------------------------------")
hm.logs.debug_log(
hm.logs.info_log("---------------------------------")
hm.logs.info_log("Technical Details")
hm.logs.info_log("---------------------------------")
hm.logs.info_log(
"lnargmax = {}, lnargmin = {}".format(ev.lnargmax, ev.lnargmin)
)
hm.logs.debug_log(
hm.logs.info_log(
"lnprobmax = {}, lnprobmin = {}".format(ev.lnprobmax, ev.lnprobmin)
)
hm.logs.debug_log(
hm.logs.info_log(
"lnpredictmax = {}, lnpredictmin = {}".format(
ev.lnpredictmax, ev.lnpredictmin
)
)
hm.logs.debug_log("---------------------------------")
hm.logs.debug_log(
hm.logs.info_log("---------------------------------")
hm.logs.info_log(
"shift = {}, shift setting = {}".format(ev.shift_value, ev.shift)
)
hm.logs.debug_log("running sum total = {}".format(sum(ev.running_sum)))
hm.logs.debug_log("running sum = \n{}".format(ev.running_sum))
hm.logs.debug_log("nsamples per chain = \n{}".format(ev.nsamples_per_chain))
hm.logs.debug_log(
hm.logs.info_log("running sum total = {}".format(sum(ev.running_sum)))
hm.logs.info_log("running sum = \n{}".format(ev.running_sum))
hm.logs.info_log("nsamples per chain = \n{}".format(ev.nsamples_per_chain))
hm.logs.info_log(
"nsamples eff per chain = \n{}".format(ev.nsamples_eff_per_chain)
)
hm.logs.debug_log("===============================")
hm.logs.info_log("===============================")

# ======================================================================
# Create corner/triangle plot.
Expand Down Expand Up @@ -314,28 +295,31 @@ def run_example(

plt.show()

evidence_inv_summary[i_realisation, 0] = ev.evidence_inv
evidence_inv_summary[i_realisation, 1] = ev.evidence_inv_var
evidence_inv_summary[i_realisation, 2] = ev.evidence_inv_var_var
# Save out realisations for violin plot.
ln_evidence_inv_summary[i_realisation, 0] = ev.ln_evidence_inv
ln_evidence_inv_summary[i_realisation, 1] = err_ln_inv_evidence[0]
ln_evidence_inv_summary[i_realisation, 2] = err_ln_inv_evidence[1]
ln_evidence_inv_summary[i_realisation, 3] = ev.ln_evidence_inv_var
ln_evidence_inv_summary[i_realisation, 4] = ev.ln_evidence_inv_var_var

clock = time.process_time() - clock
hm.logs.info_log("Execution_time = {}s".format(clock))

if n_realisations > 1:
save_name = (
save_name_start
+ "_gaussian_nondiagcov_evidence_inv"
+ "_gaussian_nondiagcov_ln_evidence_inv"
+ "_realisations_{}D.dat".format(ndim)
)
np.savetxt(save_name, evidence_inv_summary)
evidence_inv_analytic_summary = np.zeros(1)
evidence_inv_analytic_summary[0] = np.exp(-ln_evidence_analytic)
np.savetxt(save_name, ln_evidence_inv_summary)
ln_evidence_inv_analytic_summary = np.zeros(1)
ln_evidence_inv_analytic_summary[0] = -ln_evidence_analytic
save_name = (
save_name_start
+ "_gaussian_nondiagcov_evidence_inv"
+ "_gaussian_nondiagcov_ln_evidence_inv"
+ "_analytic_{}D.dat".format(ndim)
)
np.savetxt(save_name, evidence_inv_analytic_summary)
np.savetxt(save_name, ln_evidence_inv_analytic_summary)

created_plots = True
if created_plots:
Expand All @@ -344,14 +328,14 @@ def run_example(

if __name__ == "__main__":
# Setup logging config.
hm.logs.setup_logging()
hm.logs.setup_logging(default_level=logging.DEBUG)

# Define parameters.
ndim = 5
nchains = 100
ndim = 21
nchains = 80
samples_per_chain = 5000
flow_str = "RealNVP"
# flow_str = "RQSpline"
#flow_str = "RealNVP"
flow_str = "RQSpline"
np.random.seed(10) # used for initializing covariance matrix

hm.logs.info_log("Non-diagonal Covariance Gaussian example")
Expand All @@ -365,4 +349,4 @@ def run_example(
hm.logs.debug_log("-------------------------")

# Run example.
run_example(flow_str, ndim, nchains, samples_per_chain, plot_corner=False)
run_example(flow_str, ndim, nchains, samples_per_chain, plot_corner=True)
Loading
Loading