Skip to content

Commit

Permalink
Methionine demo works up to creating idata
Browse files Browse the repository at this point in the history
  • Loading branch information
teddygroves committed Feb 21, 2025
1 parent 57cc6d7 commit 440c623
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 267 deletions.
61 changes: 33 additions & 28 deletions scripts/mcmc_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from enzax.kinetic_model import get_conc
from enzax.mcmc import get_idata, run_nuts
from enzax.steady_state import get_kinetic_model_steady_state
from enzax.statistical_modelling import enzax_log_density
from enzax.statistical_modelling import enzax_log_density, prior_from_truth

import equinox as eqx

Expand All @@ -24,6 +24,17 @@
jax.config.update("jax_enable_x64", True)


def simulate(key, truth, error):
key_conc, key_enz, key_flux = jax.random.split(key, num=3)
true_conc, true_log_enz, true_flux = truth
conc_err, enz_err, flux_err = error
return (
jnp.exp(jnp.log(true_conc) + jax.random.normal(key_conc) * conc_err),
jnp.exp(true_log_enz + jax.random.normal(key_enz) * enz_err),
true_flux + jax.random.normal(key_flux) * flux_err,
)


def main():
"""Demonstrate How to make a Bayesian kinetic model with enzax."""
true_parameters = methionine.parameters
Expand All @@ -38,51 +49,45 @@ def get_free_params(params):
params["dgf"],
)

false_tree = jax.tree.map(lambda _: False, true_parameters)
freespec = eqx.tree_at(
is_free = eqx.tree_at(
get_free_params,
false_tree,
jax.tree.map(lambda _: False, true_parameters),
replace_fn=lambda _: True,
)
free_params, fixed_params = eqx.partition(true_parameters, freespec)
prior_mean = free_params
prior_sd = jax.tree.map(lambda arr: jnp.full_like(arr, 0.1), prior_mean)
prior = jax.tree.transpose(
outer_treedef=jax.tree.structure(("*", "*")),
inner_treedef=jax.tree.structure(prior_mean),
pytree_to_transpose=[prior_mean, prior_sd],
free_params, fixed_params = eqx.partition(true_parameters, is_free)
is_mv = eqx.tree_at(
lambda params: params["dgf"],
jax.tree.map(lambda _: False, free_params),
replace=True,
)
# get true concentration
prior = prior_from_truth(free_params, sd=0.1, is_multivariate=is_mv)
# get true concentration, flux and log enzyme
true_conc = get_conc(
true_steady,
true_parameters["log_conc_unbalanced"],
methionine.structure,
)
# get true flux
true_flux = true_model.flux(true_steady)
true_log_enz_flat, _ = ravel_pytree(true_parameters["log_enzyme"])
# simulate observations
error_conc = 0.03
error_flux = 0.05
error_enzyme = 0.03
conc_err = jnp.full_like(true_conc, 0.03)
flux_err = jnp.full_like(true_flux, 0.05)
enz_err = jnp.full_like(true_log_enz_flat, 0.03)
key = jax.random.key(SEED)
true_log_enz_flat, _ = ravel_pytree(true_parameters["log_enzyme"])
key_conc, key_enz, key_flux, key_nuts = jax.random.split(key, num=4)
obs_conc = jnp.exp(
jnp.log(true_conc) + jax.random.normal(key_conc) * error_conc
)
obs_enzyme = jnp.exp(
true_log_enz_flat + jax.random.normal(key_enz) * error_enzyme
key_sim, key_nuts = jax.random.split(key, num=2)
measurement_errors = (conc_err, enz_err, flux_err)
measurement_values = simulate(
key=key_sim,
truth=(true_conc, true_log_enz_flat, true_flux),
error=measurement_errors,
)
obs_flux = true_flux + jax.random.normal(key_flux) * error_flux
print(obs_conc)
print(obs_enzyme)
print(obs_flux)
measurements = tuple(zip(measurement_values, measurement_errors))
posterior_log_density = jax.jit(
functools.partial(
enzax_log_density,
structure=true_model.structure,
fixed_parameters=fixed_params,
observations=[],
measurements=measurements,
prior=prior,
guess=default_guess,
)
Expand Down
Loading

0 comments on commit 440c623

Please sign in to comment.