diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 662b86c0..6c9192ba 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -23,9 +23,10 @@ """ from functools import partial, reduce -import jax.numpy as jnp import numpy as np +import jax.numpy as jnp + from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container.traversal import ( rec_map_array_container, diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 6e04bdcd..c030de71 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -720,7 +720,6 @@ def __init__(self, unstable. """ import jax.numpy as jnp - import pytato as pt super().__init__(compile_trace_callback=compile_trace_callback) self.array_types = (pt.Array, jnp.ndarray) @@ -766,7 +765,6 @@ def zeros_like(self, ary): def from_numpy(self, array): import jax - import pytato as pt def _from_numpy(ary): @@ -791,7 +789,6 @@ def freeze(self, array): return array import jax.numpy as jnp - import pytato as pt from arraycontext.container.traversal import rec_keyed_map_array_container diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 3ea7d065..d3d719e5 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -206,7 +206,6 @@ def __init__(self, *args, **kwargs): def is_available(cls) -> bool: try: import jax # noqa: F401 - import pytato # noqa: F401 return True except ImportError: