From 6b51e5391df47d91a8c6da844da46ba794cf70a3 Mon Sep 17 00:00:00 2001 From: Filippo Bigi <98903385+frostedoyster@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:47:25 +0200 Subject: [PATCH] Lint `sphericart-jax` (#128) --- .../python/sphericart/jax/__init__.py | 5 +++-- sphericart-jax/python/sphericart/jax/ddsph.py | 14 +++++++++----- sphericart-jax/python/sphericart/jax/dsph.py | 18 ++++++++++-------- sphericart-jax/python/sphericart/jax/sph.py | 16 ++++++++-------- .../sphericart/jax/spherical_harmonics.py | 19 ++++++++++--------- sphericart-jax/python/sphericart/jax/utils.py | 2 ++ sphericart-jax/python/tests/pure_jax_sph.py | 5 +++-- sphericart-jax/python/tests/test_autograd.py | 6 +++--- .../python/tests/test_consistency.py | 9 ++++----- sphericart-jax/python/tests/test_nn.py | 4 ++-- sphericart-jax/python/tests/test_precision.py | 4 ++-- sphericart-jax/python/tests/test_pure_jax.py | 13 +++++++------ .../python/tests/test_transforms.py | 16 ++++++---------- sphericart-jax/setup.py | 5 ++--- tox.ini | 2 +- 15 files changed, 72 insertions(+), 66 deletions(-) diff --git a/sphericart-jax/python/sphericart/jax/__init__.py b/sphericart-jax/python/sphericart/jax/__init__.py index d35415eff..95ba63047 100644 --- a/sphericart-jax/python/sphericart/jax/__init__.py +++ b/sphericart-jax/python/sphericart/jax/__init__.py @@ -1,6 +1,6 @@ import jax from .lib import sphericart_jax_cpu -from .spherical_harmonics import spherical_harmonics +from .spherical_harmonics import spherical_harmonics # noqa: F401 # register the operations to xla @@ -9,9 +9,10 @@ try: from .lib import sphericart_jax_cuda + # register the operations to xla for _name, _value in sphericart_jax_cuda.registrations().items(): jax.lib.xla_client.register_custom_call_target(_name, _value, platform="gpu") - + except ImportError: pass diff --git a/sphericart-jax/python/sphericart/jax/ddsph.py b/sphericart-jax/python/sphericart/jax/ddsph.py index 8de05b9e9..c80e2a982 100644 --- a/sphericart-jax/python/sphericart/jax/ddsph.py +++ b/sphericart-jax/python/sphericart/jax/ddsph.py @@ -1,11 +1,13 @@ -import jax import math from functools import partial + +import jax from jax import core from jax.core import ShapedArray from jax.interpreters import mlir, xla -from jax.interpreters.mlir import ir, custom_call -from .utils import default_layouts, build_sph_descriptor +from jax.interpreters.mlir import custom_call, ir + +from .utils import build_sph_descriptor, default_layouts # This file registers the _ddsph_p primitive and defines its implementation, @@ -18,7 +20,9 @@ def ddsph(xyz, l_max, normalized): - sph, dsph, ddsph = _ddsph_p.bind(xyz, l_max, normalized, l_max_c=l_max, normalized_c=normalized) + sph, dsph, ddsph = _ddsph_p.bind( + xyz, l_max, normalized, l_max_c=l_max, normalized_c=normalized + ) return sph, dsph, ddsph @@ -105,7 +109,7 @@ def ddsph_lowering_cuda(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c): operands=[xyz], operand_layouts=default_layouts(xyz_shape), result_layouts=default_layouts(sph_shape, dsph_shape, ddsph_shape), - backend_config=descriptor + backend_config=descriptor, ).results diff --git a/sphericart-jax/python/sphericart/jax/dsph.py b/sphericart-jax/python/sphericart/jax/dsph.py index 82ff0a461..93a594c3a 100644 --- a/sphericart-jax/python/sphericart/jax/dsph.py +++ b/sphericart-jax/python/sphericart/jax/dsph.py @@ -1,15 +1,15 @@ -import jax -import jax.numpy as jnp import math from functools import partial + +import jax +import jax.numpy as jnp from jax import core from jax.core import ShapedArray -from jax.interpreters import mlir, xla -from jax.interpreters.mlir import ir, custom_call -from jax.interpreters import ad +from jax.interpreters import ad, mlir, xla +from jax.interpreters.mlir import custom_call, ir -from .utils import default_layouts, build_sph_descriptor from .ddsph import ddsph +from .utils import build_sph_descriptor, default_layouts # This file registers the _dsph_p primitive and defines its implementation, @@ -22,7 +22,9 @@ def dsph(xyz, l_max, normalized): - sph, dsph = _dsph_p.bind(xyz, l_max, normalized, l_max_c=l_max, normalized_c=normalized) + sph, dsph = _dsph_p.bind( + xyz, l_max, normalized, l_max_c=l_max, normalized_c=normalized + ) return sph, dsph @@ -100,7 +102,7 @@ def dsph_lowering_cuda(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c): operands=[xyz], operand_layouts=default_layouts(xyz_shape), result_layouts=default_layouts(sph_shape, dsph_shape), - backend_config=descriptor + backend_config=descriptor, ).results diff --git a/sphericart-jax/python/sphericart/jax/sph.py b/sphericart-jax/python/sphericart/jax/sph.py index 586f2d039..d71587138 100644 --- a/sphericart-jax/python/sphericart/jax/sph.py +++ b/sphericart-jax/python/sphericart/jax/sph.py @@ -1,15 +1,15 @@ -import jax -import jax.numpy as jnp import math from functools import partial + +import jax +import jax.numpy as jnp from jax import core from jax.core import ShapedArray -from jax.interpreters import mlir, xla -from jax.interpreters.mlir import ir, custom_call -from jax.interpreters import ad +from jax.interpreters import ad, mlir, xla +from jax.interpreters.mlir import custom_call, ir from .dsph import dsph -from .utils import default_layouts, build_sph_descriptor +from .utils import build_sph_descriptor, default_layouts # register the sph primitive @@ -101,7 +101,7 @@ def sph_lowering_cuda(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c): raise NotImplementedError(f"Unsupported dtype {dtype}") descriptor = build_sph_descriptor(n_samples, l_max_c, normalized_c) - + return custom_call( op_name, # Output types @@ -113,7 +113,7 @@ def sph_lowering_cuda(ctx, xyz, l_max, normalized, *, l_max_c, normalized_c): # Layout specification: operand_layouts=default_layouts(xyz_shape), result_layouts=default_layouts(out_shape), - backend_config=descriptor + backend_config=descriptor, ).results diff --git a/sphericart-jax/python/sphericart/jax/spherical_harmonics.py b/sphericart-jax/python/sphericart/jax/spherical_harmonics.py index 8751d0b1b..7421446cd 100644 --- a/sphericart-jax/python/sphericart/jax/spherical_harmonics.py +++ b/sphericart-jax/python/sphericart/jax/spherical_harmonics.py @@ -4,27 +4,28 @@ def spherical_harmonics(xyz, l_max, normalized=False): """Computes the Spherical harmonics and their derivatives within the JAX framework. - + This function supports ``jit``, ``vmap``, and up to two rounds of forward and/or - backward automatic differentiation (``grad``, ``jacfwd``, ``jacrev``, ``hessian``, ...). - For the moment, it does not support ``pmap``. + backward automatic differentiation (``grad``, ``jacfwd``, ``jacrev``, ``hessian``, + ...). For the moment, it does not support ``pmap``. - Note that the ``l_max`` and ``normalized`` arguments (positions 1 and 2 in the signature) - should be tagged as static when jit-ing the function: + Note that the ``l_max`` and ``normalized`` arguments (positions 1 and 2 in the + signature) should be tagged as static when jit-ing the function: >>> import jax >>> import sphericart.jax - >>> jitted_sph_function = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=(1, 2)) + >>> sph_fn_jit = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=(1, 2)) Parameters ---------- xyz : jax array [..., 3] - single vector or set of vectors in 3D. All dimensions are optional except for the last + single vector or set of vectors in 3D. All dimensions are optional except for + the last l_max : int maximum order of the spherical harmonics (included) normalized : bool - whether the function computes Cartesian solid harmonics (``normalized=False``, default) - or normalized spherical harmonicsi (``normalized=True``) + whether the function computes Cartesian solid harmonics (``normalized=False``, + default) or normalized spherical harmonicsi (``normalized=True``) Returns ------- diff --git a/sphericart-jax/python/sphericart/jax/utils.py b/sphericart-jax/python/sphericart/jax/utils.py index f3bdf6800..3579f75a8 100644 --- a/sphericart-jax/python/sphericart/jax/utils.py +++ b/sphericart-jax/python/sphericart/jax/utils.py @@ -1,9 +1,11 @@ def default_layouts(*shapes): return [range(len(shape) - 1, -1, -1) for shape in shapes] + try: from .lib.sphericart_jax_cuda import build_sph_descriptor except ImportError: + def build_sph_descriptor(a, b, c): raise ValueError( "Trying to use sphericart-jax on CUDA, " diff --git a/sphericart-jax/python/tests/pure_jax_sph.py b/sphericart-jax/python/tests/pure_jax_sph.py index b2ef64c31..37b83a93a 100644 --- a/sphericart-jax/python/tests/pure_jax_sph.py +++ b/sphericart-jax/python/tests/pure_jax_sph.py @@ -1,6 +1,7 @@ +from functools import partial + import jax import jax.numpy as jnp -from functools import partial @partial(jax.vmap, in_axes=(0, None)) @@ -14,7 +15,7 @@ def pure_jax_spherical_harmonics(xyz, l_max): prefactors = jnp.empty(((l_max + 1) ** 2,)) ylm = jnp.empty(((l_max + 1) ** 2,)) - for l in range(l_max + 1): + for l in range(l_max + 1): # noqa: E741 prefactors = prefactors.at[l**2 + l].set(jnp.sqrt((2 * l + 1) / (2 * jnp.pi))) for m in range(1, l + 1): prefactors = prefactors.at[l**2 + l + m].set( diff --git a/sphericart-jax/python/tests/test_autograd.py b/sphericart-jax/python/tests/test_autograd.py index ad8ec7101..ebe00e547 100644 --- a/sphericart-jax/python/tests/test_autograd.py +++ b/sphericart-jax/python/tests/test_autograd.py @@ -1,8 +1,8 @@ -import pytest import jax - import jax.numpy as jnp -import jax._src.test_util as jtu +import jax.test_util as jtu +import pytest + import sphericart.jax diff --git a/sphericart-jax/python/tests/test_consistency.py b/sphericart-jax/python/tests/test_consistency.py index e02415b2c..0563228cd 100644 --- a/sphericart-jax/python/tests/test_consistency.py +++ b/sphericart-jax/python/tests/test_consistency.py @@ -1,18 +1,17 @@ -import pytest import jax +import numpy as np +import pytest + import sphericart import sphericart.jax - -import jax.numpy as jnp -import numpy as np - @pytest.fixture def xyz(): key = jax.random.PRNGKey(0) return 6 * jax.random.normal(key, (100, 3)) + @pytest.mark.parametrize("normalized", [False, True]) @pytest.mark.parametrize("l_max", [4, 7, 10]) def test_consistency(xyz, l_max, normalized): diff --git a/sphericart-jax/python/tests/test_nn.py b/sphericart-jax/python/tests/test_nn.py index d11fd224a..b6128e44c 100644 --- a/sphericart-jax/python/tests/test_nn.py +++ b/sphericart-jax/python/tests/test_nn.py @@ -1,7 +1,7 @@ +import equinox as eqx import jax - import jax.numpy as jnp -import equinox as eqx + import sphericart.jax diff --git a/sphericart-jax/python/tests/test_precision.py b/sphericart-jax/python/tests/test_precision.py index 12ea29312..43c840f48 100644 --- a/sphericart-jax/python/tests/test_precision.py +++ b/sphericart-jax/python/tests/test_precision.py @@ -1,7 +1,7 @@ -import pytest import jax - import jax.numpy as jnp +import pytest + import sphericart.jax diff --git a/sphericart-jax/python/tests/test_pure_jax.py b/sphericart-jax/python/tests/test_pure_jax.py index 66f6c1512..c90a293bc 100644 --- a/sphericart-jax/python/tests/test_pure_jax.py +++ b/sphericart-jax/python/tests/test_pure_jax.py @@ -1,10 +1,9 @@ -import numpy as np import jax import jax.numpy as jnp import pytest +from pure_jax_sph import pure_jax_spherical_harmonics import sphericart.jax -from pure_jax_sph import pure_jax_spherical_harmonics @pytest.fixture @@ -42,14 +41,16 @@ def test_jacrev(xyz, l_max): @pytest.mark.parametrize("l_max", [2, 7]) def test_gradgrad(xyz, l_max): - sum_sph = lambda x, l_max, normalized: jnp.sum( + sum_sph = lambda x, l_max, normalized: jnp.sum( # noqa: E731 sphericart.jax.spherical_harmonics(x, l_max, normalized) ) - pure_jax_sum_sph = lambda x, l_max: jnp.sum(pure_jax_spherical_harmonics(x, l_max)) - sum_grad_sph = lambda x, l_max, normalized: jnp.sum( + pure_jax_sum_sph = lambda x, l_max: jnp.sum( # noqa: E731 + pure_jax_spherical_harmonics(x, l_max) + ) + sum_grad_sph = lambda x, l_max, normalized: jnp.sum( # noqa: E731 jax.grad(sum_sph)(x, l_max, normalized) ) - pure_jax_sum_grad_sph = lambda x, l_max: jnp.sum( + pure_jax_sum_grad_sph = lambda x, l_max: jnp.sum( # noqa: E731 jax.grad(pure_jax_sum_sph)(x, l_max) ) gradgrad_sph = jax.grad(sum_grad_sph) diff --git a/sphericart-jax/python/tests/test_transforms.py b/sphericart-jax/python/tests/test_transforms.py index 6561be3d8..ca2725b20 100644 --- a/sphericart-jax/python/tests/test_transforms.py +++ b/sphericart-jax/python/tests/test_transforms.py @@ -1,7 +1,7 @@ -import numpy as np -import pytest import jax import jax.numpy as jnp +import numpy as np +import pytest import sphericart.jax @@ -19,18 +19,16 @@ def compute(xyz): # jit compile the function jcompute = jax.jit(compute) - out = jcompute(xyz) + jcompute(xyz) # get gradients for the compiled function - dout = jax.grad(jcompute)(xyz) + jax.grad(jcompute)(xyz) @pytest.mark.parametrize("normalized", [True, False]) @pytest.mark.parametrize("l_max", [4, 7, 10]) def test_jit(xyz, l_max, normalized): jitted_sph = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=(1, 2)) - calculator = sphericart.SphericalHarmonics( - l_max=l_max, normalized=normalized - ) + calculator = sphericart.SphericalHarmonics(l_max=l_max, normalized=normalized) sph = jitted_sph(xyz=xyz, l_max=l_max, normalized=normalized) sph_ref = calculator.compute(np.asarray(xyz)) np.testing.assert_allclose(sph, sph_ref, rtol=2e-5, atol=1e-6) @@ -40,9 +38,7 @@ def test_jit(xyz, l_max, normalized): @pytest.mark.parametrize("l_max", [4, 7, 10]) def test_vmap(xyz, l_max, normalized): vmapped_sph = jax.vmap(sphericart.jax.spherical_harmonics, in_axes=(0, None, None)) - calculator = sphericart.SphericalHarmonics( - l_max=l_max, normalized=normalized - ) + calculator = sphericart.SphericalHarmonics(l_max=l_max, normalized=normalized) sph = vmapped_sph(xyz, l_max, normalized) sph_ref = calculator.compute(np.asarray(xyz)) np.testing.assert_allclose(sph, sph_ref, rtol=2e-5, atol=1e-6) diff --git a/sphericart-jax/setup.py b/sphericart-jax/setup.py index f9946a765..9100b520a 100644 --- a/sphericart-jax/setup.py +++ b/sphericart-jax/setup.py @@ -2,6 +2,7 @@ import subprocess import sys +import pybind11 from setuptools import Extension, setup from setuptools.command.bdist_egg import bdist_egg from setuptools.command.build_ext import build_ext @@ -10,10 +11,10 @@ ROOT = os.path.realpath(os.path.dirname(__file__)) SPHERICART_ARCH_NATIVE = os.environ.get("SPHERICART_ARCH_NATIVE", "ON") + class cmake_ext(build_ext): """Build the native library using cmake""" - def run(self): source_dir = ROOT build_dir = os.path.join(ROOT, "build", "cmake-build") @@ -21,8 +22,6 @@ def run(self): os.makedirs(build_dir, exist_ok=True) - import pybind11 - cmake_prefix_path = [pybind11.get_cmake_dir()] cmake_options = [ diff --git a/tox.ini b/tox.ini index 5e09ae857..5916498b2 100644 --- a/tox.ini +++ b/tox.ini @@ -20,7 +20,7 @@ allowlist_externals = bash pip_install_flags = --no-deps --no-cache --no-build-isolation --check-build-dependencies --force-reinstall -lint_folders = python setup.py sphericart-torch/python sphericart-torch/setup.py +lint_folders = python setup.py sphericart-torch/python sphericart-torch/setup.py sphericart-jax/python sphericart-jax/setup.py [testenv:tests] # this environement runs Python tests