From b5a00c7bc4691310d0c42c9cd7340a92a7ed2d57 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Mon, 15 Jan 2024 19:59:23 +0100 Subject: [PATCH] Adjust tolerances --- sphericart-jax/python/tests/test_autograd.py | 27 ++++++++++++++------ sphericart-jax/python/tests/test_pure_jax.py | 10 ++++---- sphericart-jax/src/jax.cpp | 8 ------ 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/sphericart-jax/python/tests/test_autograd.py b/sphericart-jax/python/tests/test_autograd.py index 794bcd208..cf208e427 100644 --- a/sphericart-jax/python/tests/test_autograd.py +++ b/sphericart-jax/python/tests/test_autograd.py @@ -7,15 +7,14 @@ import sphericart.jax -@pytest.fixture -def xyz(): - key = jax.random.PRNGKey(0) - return 6 * jax.random.normal(key, (100, 3)) +@pytest.mark.parametrize("normalized", [True, False]) +def test_autograd(normalized): + # here, 32-bit numerical gradients are very noisy, so we use 64-bit + jax.config.update("jax_enable_x64", True) + key = jax.random.PRNGKey(0) + xyz = 6 * jax.random.normal(key, (100, 3)) -@pytest.mark.parametrize("normalized", [True, False]) -def test_autograd(xyz, normalized): - # print(xyz.device_buffer.device()) def compute(xyz): sph = sphericart.jax.spherical_harmonics(xyz=xyz, l_max=4, normalized=normalized) assert jnp.linalg.norm(sph) != 0.0 @@ -23,12 +22,24 @@ def compute(xyz): jtu.check_grads(compute, (xyz,), modes=["fwd", "bwd"], order=1) + # reset to 32-bit so this doesn't carry over to other tests + jax.config.update("jax_enable_x64", False) + @pytest.mark.parametrize("normalized", [True, False]) -def test_autograd_second_derivatives(xyz, normalized): +def test_autograd_second_derivatives(normalized): + # here, 32-bit numerical gradients are very noisy, so we use 64-bit + jax.config.update("jax_enable_x64", True) + + key = jax.random.PRNGKey(0) + xyz = 6 * jax.random.normal(key, (100, 3)) + def compute(xyz): sph = sphericart.jax.spherical_harmonics(xyz=xyz, l_max=4, normalized=normalized) assert jnp.linalg.norm(sph) != 0.0 return sph.sum() jtu.check_grads(compute, (xyz,), modes=["fwd", "bwd"], order=2) + + # reset to 32-bit so this doesn't carry over to other tests + jax.config.update("jax_enable_x64", False) diff --git a/sphericart-jax/python/tests/test_pure_jax.py b/sphericart-jax/python/tests/test_pure_jax.py index 94daaf18a..fd7439886 100644 --- a/sphericart-jax/python/tests/test_pure_jax.py +++ b/sphericart-jax/python/tests/test_pure_jax.py @@ -17,10 +17,10 @@ def xyz(): @pytest.mark.parametrize("l_max", [2, 7]) def test_jit(xyz, l_max): jitted_sph = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=(1, 2)) - pure_jax_jitted_sph = jax.jit(pure_jax_spherical_harmonics, static_argnums=(1, 2)) + pure_jax_jitted_sph = jax.jit(pure_jax_spherical_harmonics, static_argnums=1) sph = jitted_sph(xyz=xyz, l_max=l_max, normalized=True) sph_pure_jax = pure_jax_jitted_sph(xyz, l_max) - assert jnp.allclose(sph, sph_pure_jax) + assert jnp.allclose(sph, sph_pure_jax, atol=1e-4, rtol=1e-4) @pytest.mark.parametrize("l_max", [2, 7]) @@ -29,7 +29,7 @@ def test_jacfwd(xyz, l_max): pure_jax_jacfwd_sph = jax.jacfwd(pure_jax_spherical_harmonics) sph = jacfwd_sph(xyz, l_max=l_max, normalized=True) sph_pure_jax = pure_jax_jacfwd_sph(xyz, l_max) - assert jnp.allclose(sph, sph_pure_jax) + assert jnp.allclose(sph, sph_pure_jax, atol=1e-4, rtol=1e-4) @pytest.mark.parametrize("l_max", [2, 7]) @@ -38,7 +38,7 @@ def test_jacrev(xyz, l_max): pure_jax_jacrev_sph = jax.jacrev(pure_jax_spherical_harmonics) sph = jacrev_sph(xyz, l_max=l_max, normalized=True) sph_pure_jax = pure_jax_jacrev_sph(xyz, l_max) - assert jnp.allclose(sph, sph_pure_jax) + assert jnp.allclose(sph, sph_pure_jax, atol=1e-4, rtol=1e-4) @pytest.mark.parametrize("l_max", [2, 7]) @@ -57,4 +57,4 @@ def test_gradgrad(xyz, l_max): pure_jax_gradgrad_sph = jax.grad(pure_jax_sum_grad_sph) sph = gradgrad_sph(xyz, l_max, normalized=True) sph_pure_jax = pure_jax_gradgrad_sph(xyz, l_max) - assert jnp.allclose(sph, sph_pure_jax) + assert jnp.allclose(sph, sph_pure_jax, atol=1e-4, rtol=1e-4) diff --git a/sphericart-jax/src/jax.cpp b/sphericart-jax/src/jax.cpp index 8282431bc..43273e57b 100644 --- a/sphericart-jax/src/jax.cpp +++ b/sphericart-jax/src/jax.cpp @@ -44,7 +44,6 @@ _get_or_create_sph_cpu(CacheMapCPU &sph_cache, std::mutex &cache_mutex, } template void cpu_sph(void *out, const void **in) { - std::cout << "STILL OK 0" << std::endl; // Parse the inputs const T *xyz = reinterpret_cast(in[0]); const size_t l_max = *reinterpret_cast(in[1]); @@ -61,15 +60,11 @@ template void cpu_sph(void *out, const void **in) { auto &calculator = _get_or_create_sph_cpu(sph_cache, cache_mutex, l_max, normalized); - - std::cout << "STILL OK" << std::endl; calculator->compute_array(xyz, xyz_length, sph, sph_len); - std::cout << "STILL OK 2" << std::endl; } template void cpu_sph_with_gradients(void *out_tuple, const void **in) { - std::cout << "STILL OK 0" << std::endl; // Parse the inputs const T *xyz = reinterpret_cast(in[0]); const size_t l_max = *reinterpret_cast(in[1]); @@ -89,14 +84,12 @@ void cpu_sph_with_gradients(void *out_tuple, const void **in) { auto &calculator = _get_or_create_sph_cpu(sph_cache, cache_mutex, l_max, normalized); - std::cout << "STILL OK" << std::endl; calculator->compute_array_with_gradients(xyz, xyz_length, sph, sph_len, dsph, dsph_len); } template void cpu_sph_with_hessians(void *out_tuple, const void **in) { - std::cout << "STILL OK 0" << std::endl; // Parse the inputs const T *xyz = reinterpret_cast(in[0]); const size_t l_max = *reinterpret_cast(in[1]); @@ -118,7 +111,6 @@ void cpu_sph_with_hessians(void *out_tuple, const void **in) { auto &calculator = _get_or_create_sph_cpu(sph_cache, cache_mutex, l_max, normalized); - std::cout << "STILL OK" << std::endl; calculator->compute_array_with_hessians(xyz, xyz_length, sph, sph_len, dsph, dsph_len, ddsph, ddsph_len); }