diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py
index 61be7f317..d0c945faa 100644
--- a/test/infer/test_autoguide.py
+++ b/test/infer/test_autoguide.py
@@ -1236,7 +1236,7 @@ def model():
         model, model, subsample_plate="N", use_global_dais_params=use_global_dais_params
     )
     svi = SVI(model, guide, optax.adam(0.02), Trace_ELBO())
-    svi_results = svi.run(random.PRNGKey(0), 3000)
+    svi_results = svi.run(random.PRNGKey(0), 5000)
     samples = guide.sample_posterior(
         random.PRNGKey(1), svi_results.params, sample_shape=(1000,)
     )
diff --git a/test/infer/test_gradient.py b/test/infer/test_gradient.py
index dec977909..b97fe67e9 100644
--- a/test/infer/test_gradient.py
+++ b/test/infer/test_gradient.py
@@ -460,8 +460,8 @@ def actual_loss_fn(params_raw):
 
     actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
 
-    assert_equal(actual_loss, expected_loss, prec=3e-3)
-    assert_equal(actual_grads, expected_grads, prec=4e-3)
+    assert_equal(actual_loss, expected_loss, prec=0.05)
+    assert_equal(actual_grads, expected_grads, prec=0.005)
 
 
 def test_analytic_kl_3():
@@ -555,8 +555,8 @@ def actual_loss_fn(params_raw):
 
     actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
 
-    assert_equal(actual_loss, expected_loss, prec=3e-3)
-    assert_equal(actual_grads, expected_grads, prec=4e-3)
+    assert_equal(actual_loss, expected_loss, prec=0.01)
+    assert_equal(actual_grads, expected_grads, prec=0.005)
 
 
 @pytest.mark.parametrize("scale1", [1, 10])
diff --git a/test/infer/test_hmc_gibbs.py b/test/infer/test_hmc_gibbs.py
index c4195da68..427692abd 100644
--- a/test/infer/test_hmc_gibbs.py
+++ b/test/infer/test_hmc_gibbs.py
@@ -194,7 +194,7 @@ def model():
     mcmc.run(random.PRNGKey(0))
     mcmc.print_summary()
     samples = mcmc.get_samples()
-    assert_allclose(jnp.mean(samples["x"], 0), 0.7 * jnp.ones(3), atol=0.01)
+    assert_allclose(jnp.mean(samples["x"], 0), 0.7 * jnp.ones(3), atol=0.05)
     assert_allclose(jnp.mean(samples["y"], 0), 0.3 * 10, atol=0.1)
 
 