Skip to content

Commit

Permalink
Merge branch 'main' into new-api
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Aug 26, 2024
2 parents 5ff6b54 + 6b51e53 commit c6a73ef
Show file tree
Hide file tree
Showing 18 changed files with 129 additions and 86 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,6 @@ will build the documentation in a CPU-only environment.

Although sphericart natively calculates real solid and spherical harmonics from
Cartesian positions, it is easy to manipulate its output it to calculate complex
spherical harmonics and/or to accept spherical coordinates as inputs. You can see an
example [here](https://sphericart.readthedocs.io/en/latest/spherical-complex.html)/.
spherical harmonics and/or to accept spherical coordinates as inputs. You can see
examples [here](https://sphericart.readthedocs.io/en/latest/spherical-complex.html).

4 changes: 2 additions & 2 deletions docs/src/spherical-complex.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ Spherical coordinates and/or complex harmonics

The algorithms implemented in ``sphericart`` are designed to work with Cartesian
input positions and real spherical (or solid) harmonics. However, depending on the use
case, it might be more convenient to harmonics from spherical coordinates and/or work
with complex harmonics.
case, it might be more convenient to compute harmonics from spherical coordinates
and/or work with complex harmonics.

Below, we provide a series of Python examples that illustrate how to use the
``sphericart`` library in such cases. The examples can be easily adapted to the
Expand Down
5 changes: 3 additions & 2 deletions sphericart-jax/python/sphericart/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import jax
from .lib import sphericart_jax_cpu
from .spherical_harmonics import spherical_harmonics, solid_harmonics
from .spherical_harmonics import spherical_harmonics, solid_harmonics # noqa: F401


# register the operations to xla
Expand All @@ -9,9 +9,10 @@

try:
from .lib import sphericart_jax_cuda

# register the operations to xla
for _name, _value in sphericart_jax_cuda.registrations().items():
jax.lib.xla_client.register_custom_call_target(_name, _value, platform="gpu")

except ImportError:
pass
14 changes: 9 additions & 5 deletions sphericart-jax/python/sphericart/jax/ddsph.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import jax
import math
from functools import partial

import jax
from jax import core
from jax.core import ShapedArray
from jax.interpreters import mlir, xla
from jax.interpreters.mlir import ir, custom_call
from .utils import default_layouts, build_sph_descriptor
from jax.interpreters.mlir import custom_call, ir

from .utils import build_sph_descriptor, default_layouts


# This file registers the _ddsph_p primitive and defines its implementation,
Expand All @@ -18,7 +20,9 @@


def ddsph(xyz, l_max, normalized):
sph, dsph, ddsph = _ddsph_p.bind(xyz, l_max, normalized, l_max_c=l_max, normalized_c=normalized)
sph, dsph, ddsph = _ddsph_p.bind(
xyz, l_max, normalized, l_max_c=l_max, normalized_c=normalized
)
return sph, dsph, ddsph


Expand Down Expand Up @@ -114,7 +118,7 @@ def ddsph_lowering_cuda(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c):
operands=[xyz],
operand_layouts=default_layouts(xyz_shape),
result_layouts=default_layouts(sph_shape, dsph_shape, ddsph_shape),
backend_config=descriptor
backend_config=descriptor,
).results


Expand Down
18 changes: 10 additions & 8 deletions sphericart-jax/python/sphericart/jax/dsph.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import jax
import jax.numpy as jnp
import math
from functools import partial

import jax
import jax.numpy as jnp
from jax import core
from jax.core import ShapedArray
from jax.interpreters import mlir, xla
from jax.interpreters.mlir import ir, custom_call
from jax.interpreters import ad
from jax.interpreters import ad, mlir, xla
from jax.interpreters.mlir import custom_call, ir

from .utils import default_layouts, build_sph_descriptor
from .ddsph import ddsph
from .utils import build_sph_descriptor, default_layouts


# This file registers the _dsph_p primitive and defines its implementation,
Expand All @@ -22,7 +22,9 @@


def dsph(xyz, l_max, normalized):
sph, dsph = _dsph_p.bind(xyz, l_max, normalized, l_max_c=l_max, normalized_c=normalized)
sph, dsph = _dsph_p.bind(
xyz, l_max, normalized, l_max_c=l_max, normalized_c=normalized
)
return sph, dsph


Expand Down Expand Up @@ -109,7 +111,7 @@ def dsph_lowering_cuda(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c):
operands=[xyz],
operand_layouts=default_layouts(xyz_shape),
result_layouts=default_layouts(sph_shape, dsph_shape),
backend_config=descriptor
backend_config=descriptor,
).results


Expand Down
16 changes: 8 additions & 8 deletions sphericart-jax/python/sphericart/jax/sph.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import jax
import jax.numpy as jnp
import math
from functools import partial

import jax
import jax.numpy as jnp
from jax import core
from jax.core import ShapedArray
from jax.interpreters import mlir, xla
from jax.interpreters.mlir import ir, custom_call
from jax.interpreters import ad
from jax.interpreters import ad, mlir, xla
from jax.interpreters.mlir import custom_call, ir

from .dsph import dsph
from .utils import default_layouts, build_sph_descriptor
from .utils import build_sph_descriptor, default_layouts


# register the sph primitive
Expand Down Expand Up @@ -110,7 +110,7 @@ def sph_lowering_cuda(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c):
raise NotImplementedError(f"Unsupported dtype {dtype}")

descriptor = build_sph_descriptor(n_samples, l_max_c)

return custom_call(
op_name,
# Output types
Expand All @@ -122,7 +122,7 @@ def sph_lowering_cuda(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c):
# Layout specification:
operand_layouts=default_layouts(xyz_shape),
result_layouts=default_layouts(out_shape),
backend_config=descriptor
backend_config=descriptor,
).results


Expand Down
13 changes: 5 additions & 8 deletions sphericart-jax/python/sphericart/jax/spherical_harmonics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,19 @@
def spherical_harmonics(xyz, l_max):
"""Computes the Spherical harmonics and their derivatives within
the JAX framework.
This function supports ``jit``, ``vmap``, and up to two rounds of forward and/or
backward automatic differentiation (``grad``, ``jacfwd``, ``jacrev``, ``hessian``, ...).
For the moment, it does not support ``pmap``.
Note that the ``l_max`` argument (position 1 in the signature) should be tagged as static
when jit-ing the function:
Note that the ``l_max`` argument (position 1 in the signature) should be tagged as
static when jit-ing the function:
>>> import jax
>>> import sphericart.jax
>>> jitted_sph_function = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=1)
>>> jitted_sph_fn = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=1)
Parameters
----------
xyz : jax array [..., 3]
single vector or set of vectors in 3D. All dimensions are optional except for the last
single vector or set of vectors in 3D. All dimensions are optional except for
the last
l_max : int
maximum order of the spherical harmonics (included)
Expand Down
2 changes: 2 additions & 0 deletions sphericart-jax/python/sphericart/jax/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
def default_layouts(*shapes):
return [range(len(shape) - 1, -1, -1) for shape in shapes]


try:
from .lib.sphericart_jax_cuda import build_sph_descriptor
except ImportError:

def build_sph_descriptor(a, b):
raise ValueError(
"Trying to use sphericart-jax on CUDA, "
Expand Down
5 changes: 3 additions & 2 deletions sphericart-jax/python/tests/pure_jax_sph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import partial

import jax
import jax.numpy as jnp
from functools import partial


@partial(jax.vmap, in_axes=(0, None))
Expand All @@ -14,7 +15,7 @@ def pure_jax_spherical_harmonics(xyz, l_max):
prefactors = jnp.empty(((l_max + 1) ** 2,))
ylm = jnp.empty(((l_max + 1) ** 2,))

for l in range(l_max + 1):
for l in range(l_max + 1): # noqa: E741
prefactors = prefactors.at[l**2 + l].set(jnp.sqrt((2 * l + 1) / (2 * jnp.pi)))
for m in range(1, l + 1):
prefactors = prefactors.at[l**2 + l + m].set(
Expand Down
18 changes: 14 additions & 4 deletions sphericart-jax/python/tests/test_autograd.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import pytest
import jax

import jax.numpy as jnp
import jax.test_util as jtu
import pytest

import sphericart.jax


Expand All @@ -14,7 +14,12 @@ def test_autograd(normalized):
key = jax.random.PRNGKey(0)
xyz = 6 * jax.random.normal(key, (100, 3))

function = sphericart.jax.spherical_harmonics if normalized else sphericart.jax.solid_harmonics
function = (
sphericart.jax.spherical_harmonics
if normalized
else sphericart.jax.solid_harmonics
)

def compute(xyz):
sph = function(xyz, 4)
return jnp.sum(sph)
Expand All @@ -33,7 +38,12 @@ def test_autograd_second_derivatives(normalized):
key = jax.random.PRNGKey(0)
xyz = 6 * jax.random.normal(key, (100, 3))

function = sphericart.jax.spherical_harmonics if normalized else sphericart.jax.solid_harmonics
function = (
sphericart.jax.spherical_harmonics
if normalized
else sphericart.jax.solid_harmonics
)

def compute(xyz):
sph = function(xyz, 4)
return jnp.sum(sph)
Expand Down
17 changes: 9 additions & 8 deletions sphericart-jax/python/tests/test_consistency.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import pytest
import jax
import numpy as np
import pytest

import sphericart
import sphericart.jax



import jax.numpy as jnp
import numpy as np

@pytest.fixture
def xyz():
key = jax.random.PRNGKey(0)
return 6 * jax.random.normal(key, (100, 3))


@pytest.mark.parametrize("normalized", [False, True])
@pytest.mark.parametrize("l_max", [4, 7, 10])
def test_consistency(xyz, l_max, normalized):
Expand All @@ -21,10 +20,12 @@ def test_consistency(xyz, l_max, normalized):
else:
calculator = sphericart.SolidHarmonics(l_max=l_max)

function = sphericart.jax.spherical_harmonics if normalized else sphericart.jax.solid_harmonics
sph = function(
l_max=l_max, xyz=xyz
function = (
sphericart.jax.spherical_harmonics
if normalized
else sphericart.jax.solid_harmonics
)
sph = function(l_max=l_max, xyz=xyz)

sph_ref = calculator.compute(np.asarray(xyz))

Expand Down
4 changes: 2 additions & 2 deletions sphericart-jax/python/tests/test_nn.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import equinox as eqx
import jax

import jax.numpy as jnp
import equinox as eqx

import sphericart.jax


Expand Down
8 changes: 5 additions & 3 deletions sphericart-jax/python/tests/test_no_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
def test_no_points(l_max, normalized):
xyz = jnp.empty((0, 3))

function = sphericart.jax.spherical_harmonics if normalized else sphericart.jax.solid_harmonics
sph = function(
l_max=l_max, xyz=xyz
function = (
sphericart.jax.spherical_harmonics
if normalized
else sphericart.jax.solid_harmonics
)
sph = function(l_max=l_max, xyz=xyz)
assert sph.shape == (0, l_max * l_max + 2 * l_max + 1)
4 changes: 2 additions & 2 deletions sphericart-jax/python/tests/test_precision.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import jax

import jax.numpy as jnp
import pytest

import sphericart.jax


Expand Down
22 changes: 12 additions & 10 deletions sphericart-jax/python/tests/test_pure_jax.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import numpy as np
import jax
import jax.numpy as jnp
import pytest
from pure_jax_sph import pure_jax_spherical_harmonics

import sphericart.jax
from pure_jax_sph import pure_jax_spherical_harmonics


@pytest.fixture
Expand Down Expand Up @@ -42,14 +41,17 @@ def test_jacrev(xyz, l_max):

@pytest.mark.parametrize("l_max", [2, 7])
def test_gradgrad(xyz, l_max):
sum_sph = lambda x, l_max: jnp.sum(
sphericart.jax.spherical_harmonics(x, l_max)
)
pure_jax_sum_sph = lambda x, l_max: jnp.sum(pure_jax_spherical_harmonics(x, l_max))
sum_grad_sph = lambda x, l_max: jnp.sum(
jax.grad(sum_sph)(x, l_max)
)
pure_jax_sum_grad_sph = lambda x, l_max: jnp.sum(

def sum_sph(x, l_max):
return jnp.sum(sphericart.jax.spherical_harmonics(x, l_max))

def pure_jax_sum_sph(x, l_max):
return jnp.sum(pure_jax_spherical_harmonics(x, l_max))

def sum_grad_sph(x, l_max):
return jnp.sum(jax.grad(sum_sph)(x, l_max))

pure_jax_sum_grad_sph = lambda x, l_max: jnp.sum( # noqa: E731
jax.grad(pure_jax_sum_sph)(x, l_max)
)
gradgrad_sph = jax.grad(sum_grad_sph)
Expand Down
Loading

0 comments on commit c6a73ef

Please sign in to comment.