Skip to content

Commit

Permalink
Change JAX API
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Aug 20, 2024
1 parent e81d35f commit 08a86d1
Show file tree
Hide file tree
Showing 17 changed files with 279 additions and 188 deletions.
36 changes: 21 additions & 15 deletions sphericart-jax/include/sphericart/sphericart_jax_cuda.hpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
// This file is needed as a workaround for pybind11 not accepting cuda files.
// Note that all the templated functions are split into separate functions so
// that they can be compiled in the `.cu` file.

#ifndef _SPHERICART_JAX_CUDA_HPP_
#define _SPHERICART_JAX_CUDA_HPP_
Expand All @@ -8,32 +11,35 @@
struct SphDescriptor {
std::int64_t n_samples;
std::int64_t lmax;
bool normalize;
};

namespace sphericart_jax {

namespace cuda {

void apply_cuda_sph_f32(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len);
void cuda_spherical_f32(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len);

void apply_cuda_sph_f64(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len);
void cuda_spherical_f64(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len);

void apply_cuda_sph_with_gradients_f32(
cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len
);
void cuda_dspherical_f32(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len);

void apply_cuda_sph_with_gradients_f64(
cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len
);
void cuda_dspherical_f64(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len);

void apply_cuda_sph_with_hessians_f32(
cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len
);
void cuda_ddspherical_f32(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len);

void apply_cuda_sph_with_hessians_f64(
cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len
);
void cuda_ddspherical_f64(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len);

void cuda_solid_f32(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len);

void cuda_solid_f64(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len);

void cuda_dsolid_f32(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len);

void cuda_dsolid_f64(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len);

void cuda_ddsolid_f32(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len);

void cuda_ddsolid_f64(cudaStream_t stream, void** in, const char* opaque, std::size_t opaque_len);

} // namespace cuda
} // namespace sphericart_jax
Expand Down
2 changes: 1 addition & 1 deletion sphericart-jax/python/sphericart/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import jax
from .lib import sphericart_jax_cpu
from .spherical_harmonics import spherical_harmonics
from .spherical_harmonics import spherical_harmonics, solid_harmonics


# register the operations to xla
Expand Down
23 changes: 16 additions & 7 deletions sphericart-jax/python/sphericart/jax/ddsph.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,15 @@ def ddsph_lowering_cpu(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c):
ddsph_shape = xyz_shape[:-1] + [3, 3, sph_size]
n_samples = math.prod(xyz_shape[:-1])

op_name = "cpu_dd"
if normalized_c:
op_name += "spherical_"
else:
op_name += "solid_"
if dtype == ir.F32Type.get():
op_name = "cpu_ddsph_f32"
op_name += "f32"
elif dtype == ir.F64Type.get():
op_name = "cpu_ddsph_f64"
op_name += "f64"
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")

Expand All @@ -65,10 +70,9 @@ def ddsph_lowering_cpu(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c):
operands=[
xyz,
mlir.ir_constant(l_max_c),
mlir.ir_constant(normalized_c),
mlir.ir_constant(n_samples),
],
operand_layouts=default_layouts(xyz_shape, (), (), ()),
operand_layouts=default_layouts(xyz_shape, (), ()),
result_layouts=default_layouts(sph_shape, dsph_shape, ddsph_shape),
).results

Expand All @@ -86,14 +90,19 @@ def ddsph_lowering_cuda(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c):
ddsph_shape = xyz_shape[:-1] + [3, 3, sph_size]
n_samples = math.prod(xyz_shape[:-1])

op_name = "cuda_dd"
if normalized_c:
op_name += "spherical_"
else:
op_name += "solid_"
if dtype == ir.F32Type.get():
op_name = "cuda_ddsph_f32"
op_name += "f32"
elif dtype == ir.F64Type.get():
op_name = "cuda_ddsph_f64"
op_name += "f64"
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")

descriptor = build_sph_descriptor(n_samples, l_max_c, normalized_c)
descriptor = build_sph_descriptor(n_samples, l_max_c)

return custom_call(
op_name,
Expand Down
23 changes: 16 additions & 7 deletions sphericart-jax/python/sphericart/jax/dsph.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,15 @@ def dsph_lowering_cpu(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c):
dsph_shape = xyz_shape[:-1] + [3, sph_size]
n_samples = math.prod(xyz_shape[:-1])

op_name = "cpu_d"
if normalized_c:
op_name += "spherical_"
else:
op_name += "solid_"
if dtype == ir.F32Type.get():
op_name = "cpu_dsph_f32"
op_name += "f32"
elif dtype == ir.F64Type.get():
op_name = "cpu_dsph_f64"
op_name += "f64"
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")

Expand All @@ -62,10 +67,9 @@ def dsph_lowering_cpu(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c):
operands=[
xyz,
mlir.ir_constant(l_max_c),
mlir.ir_constant(normalized_c),
mlir.ir_constant(n_samples),
],
operand_layouts=default_layouts(xyz_shape, (), (), ()),
operand_layouts=default_layouts(xyz_shape, (), ()),
result_layouts=default_layouts(sph_shape, dsph_shape),
).results

Expand All @@ -82,14 +86,19 @@ def dsph_lowering_cuda(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c):
dsph_shape = xyz_shape[:-1] + [3, sph_size]
n_samples = math.prod(xyz_shape[:-1])

op_name = "cuda_d"
if normalized_c:
op_name += "spherical_"
else:
op_name += "solid_"
if dtype == ir.F32Type.get():
op_name = "cuda_dsph_f32"
op_name += "f32"
elif dtype == ir.F64Type.get():
op_name = "cuda_dsph_f64"
op_name += "f64"
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")

descriptor = build_sph_descriptor(n_samples, l_max_c, normalized_c)
descriptor = build_sph_descriptor(n_samples, l_max_c)

return custom_call(
op_name,
Expand Down
23 changes: 16 additions & 7 deletions sphericart-jax/python/sphericart/jax/sph.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,15 @@ def sph_lowering_cpu(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c):
n_samples = math.prod(xyz_shape[:-1])

# make sure we dispatch to the correct implementation
op_name = "cpu_"
if normalized_c:
op_name += "spherical_"
else:
op_name += "solid_"
if dtype == ir.F32Type.get():
op_name = "cpu_sph_f32"
op_name += "f32"
elif dtype == ir.F64Type.get():
op_name = "cpu_sph_f64"
op_name += "f64"
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")

Expand All @@ -67,11 +72,10 @@ def sph_lowering_cpu(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c):
operands=[
xyz,
mlir.ir_constant(l_max_c),
mlir.ir_constant(normalized_c),
mlir.ir_constant(n_samples),
],
# Layout specification:
operand_layouts=default_layouts(xyz_shape, (), (), ()),
operand_layouts=default_layouts(xyz_shape, (), ()),
result_layouts=default_layouts(out_shape),
).results

Expand All @@ -93,14 +97,19 @@ def sph_lowering_cuda(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c):
n_samples = math.prod(xyz_shape[:-1])

# make sure we dispatch to the correct implementation
op_name = "cuda_"
if normalized_c:
op_name += "spherical_"
else:
op_name += "solid_"
if dtype == ir.F32Type.get():
op_name = "cuda_sph_f32"
op_name += "f32"
elif dtype == ir.F64Type.get():
op_name = "cuda_sph_f64"
op_name += "f64"
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")

descriptor = build_sph_descriptor(n_samples, l_max_c, normalized_c)
descriptor = build_sph_descriptor(n_samples, l_max_c)

return custom_call(
op_name,
Expand Down
24 changes: 16 additions & 8 deletions sphericart-jax/python/sphericart/jax/spherical_harmonics.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,27 @@
from .sph import sph


def spherical_harmonics(xyz, l_max, normalized=False):
def spherical_harmonics(xyz, l_max):
"""Computes the Spherical harmonics and their derivatives within
the JAX framework.
This function supports ``jit``, ``vmap``, and up to two rounds of forward and/or
backward automatic differentiation (``grad``, ``jacfwd``, ``jacrev``, ``hessian``, ...).
For the moment, it does not support ``pmap``.
Note that the ``l_max`` and ``normalized`` arguments (positions 1 and 2 in the signature)
should be tagged as static when jit-ing the function:
Note that the ``l_max`` argument (position 1 in the signature) should be tagged as static
when jit-ing the function:
>>> import jax
>>> import sphericart.jax
>>> jitted_sph_function = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=(1, 2))
>>> jitted_sph_function = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=1)
Parameters
----------
xyz : jax array [..., 3]
single vector or set of vectors in 3D. All dimensions are optional except for the last
l_max : int
maximum order of the spherical harmonics (included)
normalized : bool
whether the function computes Cartesian solid harmonics (``normalized=False``, default)
or normalized spherical harmonicsi (``normalized=True``)
Returns
-------
Expand All @@ -34,5 +31,16 @@ def spherical_harmonics(xyz, l_max, normalized=False):
if xyz.shape[-1] != 3:
raise ValueError("the last axis of xyz must have size 3")
xyz = xyz.ravel().reshape(xyz.shape) # make contiguous (???)
output = sph(xyz, l_max, normalized)
output = sph(xyz, l_max, normalized=True)
return output


def solid_harmonics(xyz, l_max, normalized=False):
"""
Same as `spherical_harmonics`, but computes the solid harmonics instead.
"""
if xyz.shape[-1] != 3:
raise ValueError("the last axis of xyz must have size 3")
xyz = xyz.ravel().reshape(xyz.shape) # make contiguous (???)
output = sph(xyz, l_max, normalized=False)
return output
2 changes: 1 addition & 1 deletion sphericart-jax/python/sphericart/jax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ def default_layouts(*shapes):
try:
from .lib.sphericart_jax_cuda import build_sph_descriptor
except ImportError:
def build_sph_descriptor(a, b, c):
def build_sph_descriptor(a, b):
raise ValueError(
"Trying to use sphericart-jax on CUDA, "
"but sphericart-jax was installed without CUDA support. "
Expand Down
8 changes: 5 additions & 3 deletions sphericart-jax/python/tests/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax

import jax.numpy as jnp
import jax._src.test_util as jtu
import jax.test_util as jtu
import sphericart.jax


Expand All @@ -14,8 +14,9 @@ def test_autograd(normalized):
key = jax.random.PRNGKey(0)
xyz = 6 * jax.random.normal(key, (100, 3))

function = sphericart.jax.spherical_harmonics if normalized else sphericart.jax.solid_harmonics
def compute(xyz):
sph = sphericart.jax.spherical_harmonics(xyz, 4, normalized)
sph = function(xyz, 4)
return jnp.sum(sph)

jtu.check_grads(compute, (xyz,), modes=["fwd", "bwd"], order=1)
Expand All @@ -32,8 +33,9 @@ def test_autograd_second_derivatives(normalized):
key = jax.random.PRNGKey(0)
xyz = 6 * jax.random.normal(key, (100, 3))

function = sphericart.jax.spherical_harmonics if normalized else sphericart.jax.solid_harmonics
def compute(xyz):
sph = sphericart.jax.spherical_harmonics(xyz, 4, normalized)
sph = function(xyz, 4)
return jnp.sum(sph)

jtu.check_grads(compute, (xyz,), modes=["fwd", "bwd"], order=2)
Expand Down
11 changes: 8 additions & 3 deletions sphericart-jax/python/tests/test_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@ def xyz():
@pytest.mark.parametrize("normalized", [False, True])
@pytest.mark.parametrize("l_max", [4, 7, 10])
def test_consistency(xyz, l_max, normalized):
calculator = sphericart.SphericalHarmonics(l_max=l_max, normalized=normalized)
sph = sphericart.jax.spherical_harmonics(
l_max=l_max, normalized=normalized, xyz=xyz
if normalized:
calculator = sphericart.SphericalHarmonics(l_max=l_max)
else:
calculator = sphericart.SolidHarmonics(l_max=l_max)

function = sphericart.jax.spherical_harmonics if normalized else sphericart.jax.solid_harmonics
sph = function(
l_max=l_max, xyz=xyz
)

sph_ref = calculator.compute(np.asarray(xyz))
Expand Down
2 changes: 1 addition & 1 deletion sphericart-jax/python/tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(self):
pass

def __call__(self, xyz):
sph = sphericart.jax.spherical_harmonics(xyz, 4, True)
sph = sphericart.jax.spherical_harmonics(xyz, 4)
sum = jnp.sum(sph)
return sum

Expand Down
5 changes: 3 additions & 2 deletions sphericart-jax/python/tests/test_no_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
def test_no_points(l_max, normalized):
xyz = jnp.empty((0, 3))

sph = sphericart.jax.spherical_harmonics(
l_max=l_max, normalized=normalized, xyz=xyz
function = sphericart.jax.spherical_harmonics if normalized else sphericart.jax.solid_harmonics
sph = function(
l_max=l_max, xyz=xyz
)
assert sph.shape == (0, l_max * l_max + 2 * l_max + 1)
2 changes: 1 addition & 1 deletion sphericart-jax/python/tests/test_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_precision(xyz):
jax.config.update("jax_enable_x64", True)

def compute(xyz):
sph = sphericart.jax.spherical_harmonics(xyz, l_max=4, normalized=False)
sph = sphericart.jax.solid_harmonics(xyz, l_max=4)
return sph

xyz_64 = jnp.array(xyz, dtype=jnp.float64)
Expand Down
Loading

0 comments on commit 08a86d1

Please sign in to comment.