Commit b5a00c7

Adjust tolerances

frostedoyster committed Jan 15, 2024
1 parent 8ba63af

Showing 3 changed files with 24 additions and 21 deletions.
27 changes: 19 additions & 8 deletions sphericart-jax/python/tests/test_autograd.py

@@ -7,28 +7,39 @@
 import sphericart.jax
 
 
-@pytest.fixture
-def xyz():
-    key = jax.random.PRNGKey(0)
-    return 6 * jax.random.normal(key, (100, 3))
+@pytest.mark.parametrize("normalized", [True, False])
+def test_autograd(normalized):
+    # here, 32-bit numerical gradients are very noisy, so we use 64-bit
+    jax.config.update("jax_enable_x64", True)
 
+    key = jax.random.PRNGKey(0)
+    xyz = 6 * jax.random.normal(key, (100, 3))
 
-@pytest.mark.parametrize("normalized", [True, False])
-def test_autograd(xyz, normalized):
-    # print(xyz.device_buffer.device())
     def compute(xyz):
         sph = sphericart.jax.spherical_harmonics(xyz=xyz, l_max=4, normalized=normalized)
         assert jnp.linalg.norm(sph) != 0.0
         return sph.sum()
 
     jtu.check_grads(compute, (xyz,), modes=["fwd", "bwd"], order=1)
 
+    # reset to 32-bit so this doesn't carry over to other tests
+    jax.config.update("jax_enable_x64", False)
+
 
 @pytest.mark.parametrize("normalized", [True, False])
-def test_autograd_second_derivatives(xyz, normalized):
+def test_autograd_second_derivatives(normalized):
+    # here, 32-bit numerical gradients are very noisy, so we use 64-bit
+    jax.config.update("jax_enable_x64", True)
+
+    key = jax.random.PRNGKey(0)
+    xyz = 6 * jax.random.normal(key, (100, 3))
+
     def compute(xyz):
         sph = sphericart.jax.spherical_harmonics(xyz=xyz, l_max=4, normalized=normalized)
         assert jnp.linalg.norm(sph) != 0.0
         return sph.sum()
 
     jtu.check_grads(compute, (xyz,), modes=["fwd", "bwd"], order=2)
+
+    # reset to 32-bit so this doesn't carry over to other tests
+    jax.config.update("jax_enable_x64", False)
10 changes: 5 additions & 5 deletions sphericart-jax/python/tests/test_pure_jax.py

@@ -17,10 +17,10 @@ def xyz():
 @pytest.mark.parametrize("l_max", [2, 7])
 def test_jit(xyz, l_max):
     jitted_sph = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=(1, 2))
-    pure_jax_jitted_sph = jax.jit(pure_jax_spherical_harmonics, static_argnums=(1, 2))
+    pure_jax_jitted_sph = jax.jit(pure_jax_spherical_harmonics, static_argnums=1)
     sph = jitted_sph(xyz=xyz, l_max=l_max, normalized=True)
     sph_pure_jax = pure_jax_jitted_sph(xyz, l_max)
-    assert jnp.allclose(sph, sph_pure_jax)
+    assert jnp.allclose(sph, sph_pure_jax, atol=1e-4, rtol=1e-4)
 
 
 @pytest.mark.parametrize("l_max", [2, 7])
@@ -29,7 +29,7 @@ def test_jacfwd(xyz, l_max):
     pure_jax_jacfwd_sph = jax.jacfwd(pure_jax_spherical_harmonics)
     sph = jacfwd_sph(xyz, l_max=l_max, normalized=True)
     sph_pure_jax = pure_jax_jacfwd_sph(xyz, l_max)
-    assert jnp.allclose(sph, sph_pure_jax)
+    assert jnp.allclose(sph, sph_pure_jax, atol=1e-4, rtol=1e-4)
 
 
 @pytest.mark.parametrize("l_max", [2, 7])
@@ -38,7 +38,7 @@ def test_jacrev(xyz, l_max):
     pure_jax_jacrev_sph = jax.jacrev(pure_jax_spherical_harmonics)
     sph = jacrev_sph(xyz, l_max=l_max, normalized=True)
     sph_pure_jax = pure_jax_jacrev_sph(xyz, l_max)
-    assert jnp.allclose(sph, sph_pure_jax)
+    assert jnp.allclose(sph, sph_pure_jax, atol=1e-4, rtol=1e-4)
 
 
 @pytest.mark.parametrize("l_max", [2, 7])
@@ -57,4 +57,4 @@ def test_gradgrad(xyz, l_max):
     pure_jax_gradgrad_sph = jax.grad(pure_jax_sum_grad_sph)
     sph = gradgrad_sph(xyz, l_max, normalized=True)
     sph_pure_jax = pure_jax_gradgrad_sph(xyz, l_max)
-    assert jnp.allclose(sph, sph_pure_jax)
+    assert jnp.allclose(sph, sph_pure_jax, atol=1e-4, rtol=1e-4)
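
The relaxed tolerances follow the usual allclose criterion: the arrays pass when abs(a - b) <= atol + rtol * abs(b) holds elementwise, presumably because the float32 rounding differences between sphericart and the pure-JAX reference at l_max = 7 exceed the defaults (rtol=1e-05, atol=1e-08). A rough sketch of that check, for illustration only and not code from this repository:

import jax.numpy as jnp


def roughly_equal(a, b, rtol=1e-4, atol=1e-4):
    # the same elementwise criterion that jnp.allclose(a, b, rtol=..., atol=...) applies
    return bool(jnp.all(jnp.abs(a - b) <= atol + rtol * jnp.abs(b)))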
8 changes: 0 additions & 8 deletions sphericart-jax/src/jax.cpp

@@ -44,7 +44,6 @@ _get_or_create_sph_cpu(CacheMapCPU<T> &sph_cache, std::mutex &cache_mutex,
 }
 
 template <typename T> void cpu_sph(void *out, const void **in) {
-    std::cout << "STILL OK 0" << std::endl;
     // Parse the inputs
     const T *xyz = reinterpret_cast<const T *>(in[0]);
     const size_t l_max = *reinterpret_cast<const int *>(in[1]);
@@ -61,15 +60,11 @@ template <typename T> void cpu_sph(void *out, const void **in) {
 
     auto &calculator =
         _get_or_create_sph_cpu(sph_cache, cache_mutex, l_max, normalized);
-
-    std::cout << "STILL OK" << std::endl;
     calculator->compute_array(xyz, xyz_length, sph, sph_len);
-    std::cout << "STILL OK 2" << std::endl;
 }
 
 template <typename T>
 void cpu_sph_with_gradients(void *out_tuple, const void **in) {
-    std::cout << "STILL OK 0" << std::endl;
     // Parse the inputs
     const T *xyz = reinterpret_cast<const T *>(in[0]);
     const size_t l_max = *reinterpret_cast<const int *>(in[1]);
@@ -89,14 +84,12 @@ void cpu_sph_with_gradients(void *out_tuple, const void **in) {
 
     auto &calculator =
         _get_or_create_sph_cpu(sph_cache, cache_mutex, l_max, normalized);
-    std::cout << "STILL OK" << std::endl;
     calculator->compute_array_with_gradients(xyz, xyz_length, sph, sph_len,
                                              dsph, dsph_len);
 }
 
 template <typename T>
 void cpu_sph_with_hessians(void *out_tuple, const void **in) {
-    std::cout << "STILL OK 0" << std::endl;
     // Parse the inputs
     const T *xyz = reinterpret_cast<const T *>(in[0]);
     const size_t l_max = *reinterpret_cast<const int *>(in[1]);
@@ -118,7 +111,6 @@ void cpu_sph_with_hessians(void *out_tuple, const void **in) {
 
     auto &calculator =
         _get_or_create_sph_cpu(sph_cache, cache_mutex, l_max, normalized);
-    std::cout << "STILL OK" << std::endl;
     calculator->compute_array_with_hessians(xyz, xyz_length, sph, sph_len, dsph,
                                             dsph_len, ddsph, ddsph_len);
 }