Skip to content

Commit

Permalink
Lint sphericart-jax (#128)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster authored Aug 20, 2024
1 parent ba49a10 commit 6b51e53
Show file tree
Hide file tree
Showing 15 changed files with 72 additions and 66 deletions.
5 changes: 3 additions & 2 deletions sphericart-jax/python/sphericart/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import jax
from .lib import sphericart_jax_cpu
from .spherical_harmonics import spherical_harmonics
from .spherical_harmonics import spherical_harmonics # noqa: F401


# register the operations to xla
Expand All @@ -9,9 +9,10 @@

try:
from .lib import sphericart_jax_cuda

# register the operations to xla
for _name, _value in sphericart_jax_cuda.registrations().items():
jax.lib.xla_client.register_custom_call_target(_name, _value, platform="gpu")

except ImportError:
pass
14 changes: 9 additions & 5 deletions sphericart-jax/python/sphericart/jax/ddsph.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import jax
import math
from functools import partial

import jax
from jax import core
from jax.core import ShapedArray
from jax.interpreters import mlir, xla
from jax.interpreters.mlir import ir, custom_call
from .utils import default_layouts, build_sph_descriptor
from jax.interpreters.mlir import custom_call, ir

from .utils import build_sph_descriptor, default_layouts


# This file registers the _ddsph_p primitive and defines its implementation,
Expand All @@ -18,7 +20,9 @@


def ddsph(xyz, l_max, normalized):
sph, dsph, ddsph = _ddsph_p.bind(xyz, l_max, normalized, l_max_c=l_max, normalized_c=normalized)
sph, dsph, ddsph = _ddsph_p.bind(
xyz, l_max, normalized, l_max_c=l_max, normalized_c=normalized
)
return sph, dsph, ddsph


Expand Down Expand Up @@ -105,7 +109,7 @@ def ddsph_lowering_cuda(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c):
operands=[xyz],
operand_layouts=default_layouts(xyz_shape),
result_layouts=default_layouts(sph_shape, dsph_shape, ddsph_shape),
backend_config=descriptor
backend_config=descriptor,
).results


Expand Down
18 changes: 10 additions & 8 deletions sphericart-jax/python/sphericart/jax/dsph.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import jax
import jax.numpy as jnp
import math
from functools import partial

import jax
import jax.numpy as jnp
from jax import core
from jax.core import ShapedArray
from jax.interpreters import mlir, xla
from jax.interpreters.mlir import ir, custom_call
from jax.interpreters import ad
from jax.interpreters import ad, mlir, xla
from jax.interpreters.mlir import custom_call, ir

from .utils import default_layouts, build_sph_descriptor
from .ddsph import ddsph
from .utils import build_sph_descriptor, default_layouts


# This file registers the _dsph_p primitive and defines its implementation,
Expand All @@ -22,7 +22,9 @@


def dsph(xyz, l_max, normalized):
sph, dsph = _dsph_p.bind(xyz, l_max, normalized, l_max_c=l_max, normalized_c=normalized)
sph, dsph = _dsph_p.bind(
xyz, l_max, normalized, l_max_c=l_max, normalized_c=normalized
)
return sph, dsph


Expand Down Expand Up @@ -100,7 +102,7 @@ def dsph_lowering_cuda(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c):
operands=[xyz],
operand_layouts=default_layouts(xyz_shape),
result_layouts=default_layouts(sph_shape, dsph_shape),
backend_config=descriptor
backend_config=descriptor,
).results


Expand Down
16 changes: 8 additions & 8 deletions sphericart-jax/python/sphericart/jax/sph.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import jax
import jax.numpy as jnp
import math
from functools import partial

import jax
import jax.numpy as jnp
from jax import core
from jax.core import ShapedArray
from jax.interpreters import mlir, xla
from jax.interpreters.mlir import ir, custom_call
from jax.interpreters import ad
from jax.interpreters import ad, mlir, xla
from jax.interpreters.mlir import custom_call, ir

from .dsph import dsph
from .utils import default_layouts, build_sph_descriptor
from .utils import build_sph_descriptor, default_layouts


# register the sph primitive
Expand Down Expand Up @@ -101,7 +101,7 @@ def sph_lowering_cuda(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c):
raise NotImplementedError(f"Unsupported dtype {dtype}")

descriptor = build_sph_descriptor(n_samples, l_max_c, normalized_c)

return custom_call(
op_name,
# Output types
Expand All @@ -113,7 +113,7 @@ def sph_lowering_cuda(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c):
# Layout specification:
operand_layouts=default_layouts(xyz_shape),
result_layouts=default_layouts(out_shape),
backend_config=descriptor
backend_config=descriptor,
).results


Expand Down
19 changes: 10 additions & 9 deletions sphericart-jax/python/sphericart/jax/spherical_harmonics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,28 @@
def spherical_harmonics(xyz, l_max, normalized=False):
"""Computes the Spherical harmonics and their derivatives within
the JAX framework.
This function supports ``jit``, ``vmap``, and up to two rounds of forward and/or
backward automatic differentiation (``grad``, ``jacfwd``, ``jacrev``, ``hessian``, ...).
For the moment, it does not support ``pmap``.
backward automatic differentiation (``grad``, ``jacfwd``, ``jacrev``, ``hessian``,
...). For the moment, it does not support ``pmap``.
Note that the ``l_max`` and ``normalized`` arguments (positions 1 and 2 in the signature)
should be tagged as static when jit-ing the function:
Note that the ``l_max`` and ``normalized`` arguments (positions 1 and 2 in the
signature) should be tagged as static when jit-ing the function:
>>> import jax
>>> import sphericart.jax
>>> jitted_sph_function = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=(1, 2))
>>> sph_fn_jit = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=(1, 2))
Parameters
----------
xyz : jax array [..., 3]
single vector or set of vectors in 3D. All dimensions are optional except for the last
single vector or set of vectors in 3D. All dimensions are optional except for
the last
l_max : int
maximum order of the spherical harmonics (included)
normalized : bool
whether the function computes Cartesian solid harmonics (``normalized=False``, default)
or normalized spherical harmonicsi (``normalized=True``)
whether the function computes Cartesian solid harmonics (``normalized=False``,
default) or normalized spherical harmonicsi (``normalized=True``)
Returns
-------
Expand Down
2 changes: 2 additions & 0 deletions sphericart-jax/python/sphericart/jax/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
def default_layouts(*shapes):
return [range(len(shape) - 1, -1, -1) for shape in shapes]


try:
from .lib.sphericart_jax_cuda import build_sph_descriptor
except ImportError:

def build_sph_descriptor(a, b, c):
raise ValueError(
"Trying to use sphericart-jax on CUDA, "
Expand Down
5 changes: 3 additions & 2 deletions sphericart-jax/python/tests/pure_jax_sph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import partial

import jax
import jax.numpy as jnp
from functools import partial


@partial(jax.vmap, in_axes=(0, None))
Expand All @@ -14,7 +15,7 @@ def pure_jax_spherical_harmonics(xyz, l_max):
prefactors = jnp.empty(((l_max + 1) ** 2,))
ylm = jnp.empty(((l_max + 1) ** 2,))

for l in range(l_max + 1):
for l in range(l_max + 1): # noqa: E741
prefactors = prefactors.at[l**2 + l].set(jnp.sqrt((2 * l + 1) / (2 * jnp.pi)))
for m in range(1, l + 1):
prefactors = prefactors.at[l**2 + l + m].set(
Expand Down
6 changes: 3 additions & 3 deletions sphericart-jax/python/tests/test_autograd.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import pytest
import jax

import jax.numpy as jnp
import jax._src.test_util as jtu
import jax.test_util as jtu
import pytest

import sphericart.jax


Expand Down
9 changes: 4 additions & 5 deletions sphericart-jax/python/tests/test_consistency.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import pytest
import jax
import numpy as np
import pytest

import sphericart
import sphericart.jax



import jax.numpy as jnp
import numpy as np

@pytest.fixture
def xyz():
key = jax.random.PRNGKey(0)
return 6 * jax.random.normal(key, (100, 3))


@pytest.mark.parametrize("normalized", [False, True])
@pytest.mark.parametrize("l_max", [4, 7, 10])
def test_consistency(xyz, l_max, normalized):
Expand Down
4 changes: 2 additions & 2 deletions sphericart-jax/python/tests/test_nn.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import equinox as eqx
import jax

import jax.numpy as jnp
import equinox as eqx

import sphericart.jax


Expand Down
4 changes: 2 additions & 2 deletions sphericart-jax/python/tests/test_precision.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import jax

import jax.numpy as jnp
import pytest

import sphericart.jax


Expand Down
13 changes: 7 additions & 6 deletions sphericart-jax/python/tests/test_pure_jax.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import numpy as np
import jax
import jax.numpy as jnp
import pytest
from pure_jax_sph import pure_jax_spherical_harmonics

import sphericart.jax
from pure_jax_sph import pure_jax_spherical_harmonics


@pytest.fixture
Expand Down Expand Up @@ -42,14 +41,16 @@ def test_jacrev(xyz, l_max):

@pytest.mark.parametrize("l_max", [2, 7])
def test_gradgrad(xyz, l_max):
sum_sph = lambda x, l_max, normalized: jnp.sum(
sum_sph = lambda x, l_max, normalized: jnp.sum( # noqa: E731
sphericart.jax.spherical_harmonics(x, l_max, normalized)
)
pure_jax_sum_sph = lambda x, l_max: jnp.sum(pure_jax_spherical_harmonics(x, l_max))
sum_grad_sph = lambda x, l_max, normalized: jnp.sum(
pure_jax_sum_sph = lambda x, l_max: jnp.sum( # noqa: E731
pure_jax_spherical_harmonics(x, l_max)
)
sum_grad_sph = lambda x, l_max, normalized: jnp.sum( # noqa: E731
jax.grad(sum_sph)(x, l_max, normalized)
)
pure_jax_sum_grad_sph = lambda x, l_max: jnp.sum(
pure_jax_sum_grad_sph = lambda x, l_max: jnp.sum( # noqa: E731
jax.grad(pure_jax_sum_sph)(x, l_max)
)
gradgrad_sph = jax.grad(sum_grad_sph)
Expand Down
16 changes: 6 additions & 10 deletions sphericart-jax/python/tests/test_transforms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest
import jax
import jax.numpy as jnp
import numpy as np
import pytest

import sphericart.jax

Expand All @@ -19,18 +19,16 @@ def compute(xyz):

# jit compile the function
jcompute = jax.jit(compute)
out = jcompute(xyz)
jcompute(xyz)
# get gradients for the compiled function
dout = jax.grad(jcompute)(xyz)
jax.grad(jcompute)(xyz)


@pytest.mark.parametrize("normalized", [True, False])
@pytest.mark.parametrize("l_max", [4, 7, 10])
def test_jit(xyz, l_max, normalized):
jitted_sph = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=(1, 2))
calculator = sphericart.SphericalHarmonics(
l_max=l_max, normalized=normalized
)
calculator = sphericart.SphericalHarmonics(l_max=l_max, normalized=normalized)
sph = jitted_sph(xyz=xyz, l_max=l_max, normalized=normalized)
sph_ref = calculator.compute(np.asarray(xyz))
np.testing.assert_allclose(sph, sph_ref, rtol=2e-5, atol=1e-6)
Expand All @@ -40,9 +38,7 @@ def test_jit(xyz, l_max, normalized):
@pytest.mark.parametrize("l_max", [4, 7, 10])
def test_vmap(xyz, l_max, normalized):
vmapped_sph = jax.vmap(sphericart.jax.spherical_harmonics, in_axes=(0, None, None))
calculator = sphericart.SphericalHarmonics(
l_max=l_max, normalized=normalized
)
calculator = sphericart.SphericalHarmonics(l_max=l_max, normalized=normalized)
sph = vmapped_sph(xyz, l_max, normalized)
sph_ref = calculator.compute(np.asarray(xyz))
np.testing.assert_allclose(sph, sph_ref, rtol=2e-5, atol=1e-6)
Expand Down
5 changes: 2 additions & 3 deletions sphericart-jax/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import subprocess
import sys

import pybind11
from setuptools import Extension, setup
from setuptools.command.bdist_egg import bdist_egg
from setuptools.command.build_ext import build_ext
Expand All @@ -10,19 +11,17 @@
ROOT = os.path.realpath(os.path.dirname(__file__))
SPHERICART_ARCH_NATIVE = os.environ.get("SPHERICART_ARCH_NATIVE", "ON")


class cmake_ext(build_ext):
"""Build the native library using cmake"""


def run(self):
source_dir = ROOT
build_dir = os.path.join(ROOT, "build", "cmake-build")
install_dir = os.path.join(os.path.realpath(self.build_lib), "sphericart/jax")

os.makedirs(build_dir, exist_ok=True)

import pybind11

cmake_prefix_path = [pybind11.get_cmake_dir()]

cmake_options = [
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ allowlist_externals =
bash

pip_install_flags = --no-deps --no-cache --no-build-isolation --check-build-dependencies --force-reinstall
lint_folders = python setup.py sphericart-torch/python sphericart-torch/setup.py
lint_folders = python setup.py sphericart-torch/python sphericart-torch/setup.py sphericart-jax/python sphericart-jax/setup.py

[testenv:tests]
# this environement runs Python tests
Expand Down

0 comments on commit 6b51e53

Please sign in to comment.