Skip to content

Commit

Permalink
test: increase iteration count and adjust precision tolerances in `in…
Browse files Browse the repository at this point in the history
…fer` tests
  • Loading branch information
Qazalbash committed Jan 31, 2025
1 parent f269393 commit eb86294
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,7 @@ def model():
model, model, subsample_plate="N", use_global_dais_params=use_global_dais_params
)
svi = SVI(model, guide, optax.adam(0.02), Trace_ELBO())
svi_results = svi.run(random.PRNGKey(0), 3000)
svi_results = svi.run(random.PRNGKey(0), 5000)
samples = guide.sample_posterior(
random.PRNGKey(1), svi_results.params, sample_shape=(1000,)
)
Expand Down
8 changes: 4 additions & 4 deletions test/infer/test_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,8 +460,8 @@ def actual_loss_fn(params_raw):

actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)

assert_equal(actual_loss, expected_loss, prec=3e-3)
assert_equal(actual_grads, expected_grads, prec=4e-3)
assert_equal(actual_loss, expected_loss, prec=0.05)
assert_equal(actual_grads, expected_grads, prec=0.005)


def test_analytic_kl_3():
Expand Down Expand Up @@ -555,8 +555,8 @@ def actual_loss_fn(params_raw):

actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)

assert_equal(actual_loss, expected_loss, prec=3e-3)
assert_equal(actual_grads, expected_grads, prec=4e-3)
assert_equal(actual_loss, expected_loss, prec=0.01)
assert_equal(actual_grads, expected_grads, prec=0.005)


@pytest.mark.parametrize("scale1", [1, 10])
Expand Down
2 changes: 1 addition & 1 deletion test/infer/test_hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def model():
mcmc.run(random.PRNGKey(0))
mcmc.print_summary()
samples = mcmc.get_samples()
assert_allclose(jnp.mean(samples["x"], 0), 0.7 * jnp.ones(3), atol=0.01)
assert_allclose(jnp.mean(samples["x"], 0), 0.7 * jnp.ones(3), atol=0.05)
assert_allclose(jnp.mean(samples["y"], 0), 0.3 * 10, atol=0.1)


Expand Down

0 comments on commit eb86294

Please sign in to comment.