Numerical test test_chees_adaptation fails on aarch64-linux #668

Open
GaetanLepage opened this issue May 11, 2024 · 7 comments

@GaetanLepage
Contributor

Describe the issue as clearly as possible:

Using jax/jaxlib 0.4.28 and jaxopt 0.8.3, the following test fails on aarch64-linux:

FAILED tests/adaptation/test_adaptation.py::test_chees_adaptation - AssertionError:

Steps/code to reproduce the bug:

pytest

Expected result:

Tests pass.

Error message:

=================================== FAILURES ===================================
____________________________ test_chees_adaptation _____________________________
[gw1] linux -- Python 3.11.9 /nix/store/33752yykc8r75jxvpcvpcynm22il4ch7-python3-3.11.9/bin/python3.11

    def test_chees_adaptation():
        logprob_fn = lambda x: jax.scipy.stats.norm.logpdf(
            x, loc=0.0, scale=jnp.array([1.0, 10.0])
        ).sum()
    
        num_burnin_steps = 1000
        num_results = 500
        num_chains = 16
        step_size = 0.1
    
        init_key, warmup_key, inference_key = jax.random.split(jax.random.key(346), 3)
    
        warmup = blackjax.chees_adaptation(
            logprob_fn, num_chains=num_chains, target_acceptance_rate=0.75
        )
    
        initial_positions = jax.random.normal(init_key, (num_chains, 2))
        (last_states, parameters), warmup_info = warmup.run(
            warmup_key,
            initial_positions,
            step_size=step_size,
            optim=optax.adamw(learning_rate=0.5),
            num_steps=num_burnin_steps,
        )
        algorithm = blackjax.dynamic_hmc(logprob_fn, **parameters)
    
        chain_keys = jax.random.split(inference_key, num_chains)
        _, _, infos = jax.vmap(
            lambda key, state: run_inference_algorithm(key, state, algorithm, num_results)
        )(chain_keys, last_states)
    
        harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate)
>       np.testing.assert_allclose(harmonic_mean, 0.75, rtol=1e-1)

tests/adaptation/test_adaptation.py:69: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (<function assert_allclose.<locals>.compare at 0xfffee0976ca0>, array(0.6619941, dtype=float32), array(0.75))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0.1, atol=0', 'verbose': True}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=0.1, atol=0
E           
E           Mismatched elements: 1 / 1 (100%)
E           Max absolute difference: 0.0880059
E           Max relative difference: 0.1173412
E            x: array(0.661994, dtype=float32)
E            y: array(0.75)

/nix/store/33752yykc8r75jxvpcvpcynm22il4ch7-python3-3.11.9/lib/python3.11/contextlib.py:81: AssertionError
=========================== short test summary info ============================
FAILED tests/adaptation/test_adaptation.py::test_chees_adaptation - AssertionError: 
============= 1 failed, 442 passed, 1 skipped in 139.54s (0:02:19) =============
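
For context on the failing assertion: the test compares the harmonic mean of the per-step acceptance rates against the 0.75 target. The sketch below uses NumPy in place of `jnp` and a made-up set of acceptance rates (the `rates` array is an illustrative assumption, not data from this run) to show that a few low acceptance rates pull the harmonic mean well below the arithmetic mean.

```python
import numpy as np


def harmonic_mean(rates):
    """Harmonic mean, written the same way as in the test: 1 / mean(1 / x)."""
    rates = np.asarray(rates, dtype=np.float64)
    return 1.0 / np.mean(1.0 / rates)


# Hypothetical acceptance rates: three high values and one low value.
rates = np.array([0.9, 0.9, 0.9, 0.3])

arithmetic = rates.mean()        # plain average of the rates
harmonic = harmonic_mean(rates)  # dominated by the smallest rate

print(f"arithmetic mean: {arithmetic:.3f}")
print(f"harmonic mean:   {harmonic:.3f}")
```

With these illustrative numbers the arithmetic mean sits on the target while the harmonic mean does not, which is one way a run with a handful of poorly accepted steps could land below the tolerance band.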

Blackjax/JAX/jaxlib/Python version information:

BlackJAX 1.2.1
Python 3.11.9 (main, Apr  2 2024, 08:25:04) [GCC 13.2.0]
Jax 0.4.28
Jaxlib 0.4.28

Context for the issue:

No response

@junpenglao
Member

@albcab could you take a look to make this test more robust?

@albcab
Member

albcab commented May 13, 2024

I can't reproduce the failing test on my machine, so it's hard to debug. Making atol=1e-2 and rtol=0 would obviously fix it. Since the harmonic mean is random, increasing chains/steps/learning rates might not necessarily avoid the failing test.

Just for the sake of passing the test on aarch64-linux, I would change rtol to atol. @junpenglao, thoughts?
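
To make the tolerance discussion concrete, here is a small check of the values reported in the trace (0.6619941 against 0.75) under a few `np.testing.assert_allclose` settings. The `passes` helper is a hypothetical name introduced only for this sketch. Note that with these reported numbers the absolute difference is about 0.088, so an absolute tolerance of 1e-2 would still reject the value, whereas 1e-1 would accept it.

```python
import numpy as np

# Values copied from the assertion error in the trace.
observed = np.float32(0.6619941)
target = 0.75


def passes(rtol, atol):
    """Return True if assert_allclose accepts the reported value (hypothetical helper)."""
    try:
        np.testing.assert_allclose(observed, target, rtol=rtol, atol=atol)
    except AssertionError:
        return False
    return True


abs_diff = abs(float(observed) - target)  # matches "Max absolute difference"
rel_diff = abs_diff / abs(target)         # matches "Max relative difference"

# assert_allclose accepts when |actual - desired| <= atol + rtol * |desired|.
results = {
    "rtol=1e-1, atol=0": passes(1e-1, 0.0),  # current test setting
    "rtol=0, atol=1e-2": passes(0.0, 1e-2),
    "rtol=0, atol=1e-1": passes(0.0, 1e-1),
}

for setting, ok in results.items():
    print(f"{setting}: {'pass' if ok else 'fail'}")
```

Because the target is 0.75, `rtol=1e-1` allows a deviation of only 0.075, while `atol=1e-1` allows 0.1; that difference is what separates a failing from a passing check for this particular value.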

@junpenglao
Member

Yeah sure.

@GaetanLepage
Contributor Author

I can already confirm that the test is fixed thanks to @albcab's patch.
Maybe we can wait until the next release to mark this issue as closed.

@GaetanLepage
Contributor Author

Things might have changed since then, because several tests fail on 1.2.4.
Skipping test_chees_adaptation is enough to make the test suite succeed.
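
For anyone who wants to reproduce the skip outside a packaging setup, one option is pytest's `--deselect` flag, which takes a test node ID. The sketch below only builds the argument list; `deselect_args` is a hypothetical helper, and the node ID is the one shown in the failure summary of the original report.

```python
def deselect_args(node_ids):
    """Build pytest arguments that deselect the given test node IDs (hypothetical helper)."""
    args = []
    for node_id in node_ids:
        args.extend(["--deselect", node_id])
    return args


# Node ID taken from the failure summary in the original report.
skipped = ["tests/adaptation/test_adaptation.py::test_chees_adaptation"]
pytest_args = deselect_args(skipped)

print(" ".join(["pytest", *pytest_args]))
```

The resulting list can be appended to a `pytest` command line or passed to `pytest.main(pytest_args)` from within the repository checkout.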

@junpenglao
Member

Thanks for the feedback. This is a bit difficult for us to replicate; could you paste the full trace for all the failing tests?
Not much has changed on the BlackJAX side regarding ChEES, so I suspect something upstream in JAX is causing this.

@GaetanLepage
Contributor Author

Sure, here it is: https://paste.glepage.com/upload/spider-bee-bison
