From a975ab3803f6dfaf9be71590636eb54bdd75b0ab Mon Sep 17 00:00:00 2001 From: frans Date: Wed, 6 Dec 2023 14:41:09 +0100 Subject: [PATCH] test --- test/contrib/test_control_flow.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/test/contrib/test_control_flow.py b/test/contrib/test_control_flow.py index 67273d02c..f75686daf 100644 --- a/test/contrib/test_control_flow.py +++ b/test/contrib/test_control_flow.py @@ -12,7 +12,9 @@ import numpyro.distributions as dist from numpyro.handlers import mask, seed, substitute, trace from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO +from numpyro.infer.autoguide import AutoNormal from numpyro.infer.util import log_density, potential_energy +from numpyro.optim import Adam def test_scan(): @@ -241,3 +243,32 @@ def transition(carry, y_curr): assert model_density assert model_trace["x"]["fn"].batch_shape == (12, 10) assert model_trace["x"]["fn"].event_shape == (3,) + + +def test_scan_svi(): + T = 3 + N = 5 + + def gaussian_hmm(y=None, T=T, N=N): + def transition(x_prev, y_curr): + with numpyro.plate("data", N): + x_curr = numpyro.sample("x", dist.Normal(x_prev, 1.5)) + y_curr = numpyro.sample("y", dist.Normal(x_curr, 0.1), obs=y_curr) + return x_curr, (x_curr, y_curr) + + with numpyro.plate("data", N): + x0 = numpyro.sample("x_0", dist.Normal(jnp.zeros(N), 5.0)) + _, (x, y) = scan(transition, x0, y, length=T) + return (x, y) + + with numpyro.handlers.seed(rng_seed=0): + x, y = gaussian_hmm() + with numpyro.handlers.seed(rng_seed=0): + tr = numpyro.handlers.trace(gaussian_hmm).get_trace(y=y, T=T, N=N) + + guide = AutoNormal(gaussian_hmm) + svi = SVI(gaussian_hmm, guide, Adam(0.1), Trace_ELBO(), y=y, T=T, N=N) + results = svi.run(random.PRNGKey(0), 10**3) + + xhat = results.params["x_auto_loc"] + assert_allclose(xhat, tr["x"]["value"], rtol=0.1)