Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
frans committed Dec 6, 2023
1 parent 70bc8f3 commit a975ab3
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions test/contrib/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import numpyro.distributions as dist
from numpyro.handlers import mask, seed, substitute, trace
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.util import log_density, potential_energy
from numpyro.optim import Adam


def test_scan():
Expand Down Expand Up @@ -241,3 +243,32 @@ def transition(carry, y_curr):
assert model_density
assert model_trace["x"]["fn"].batch_shape == (12, 10)
assert model_trace["x"]["fn"].event_shape == (3,)


def test_scan_svi():
T = 3
N = 5

def gaussian_hmm(y=None, T=T, N=N):
def transition(x_prev, y_curr):
with numpyro.plate("data", N):
x_curr = numpyro.sample("x", dist.Normal(x_prev, 1.5))
y_curr = numpyro.sample("y", dist.Normal(x_curr, 0.1), obs=y_curr)
return x_curr, (x_curr, y_curr)

with numpyro.plate("data", N):
x0 = numpyro.sample("x_0", dist.Normal(jnp.zeros(N), 5.0))
_, (x, y) = scan(transition, x0, y, length=T)
return (x, y)

with numpyro.handlers.seed(rng_seed=0):
x, y = gaussian_hmm()
with numpyro.handlers.seed(rng_seed=0):
tr = numpyro.handlers.trace(gaussian_hmm).get_trace(y=y, T=T, N=N)

guide = AutoNormal(gaussian_hmm)
svi = SVI(gaussian_hmm, guide, Adam(0.1), Trace_ELBO(), y=y, T=T, N=N)
results = svi.run(random.PRNGKey(0), 10**3)

xhat = results.params["x_auto_loc"]
assert_allclose(xhat, tr["x"]["value"], rtol=0.1)

0 comments on commit a975ab3

Please sign in to comment.