Skip to content

Commit

Permalink
small adjustments.
Browse files Browse the repository at this point in the history
  • Loading branch information
nickjbrowning committed Jan 16, 2025
1 parent fec2534 commit cad746a
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions sphericart-jax/python/sphericart/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,42 +29,42 @@ def get_minimum_cuda_version_for_jax(jax_version):
), # JAX 0.4.11 - 0.4.25: CUDA 11.8+
]

# Parse the current JAX version
jax_ver = version.parse(jax_version)

# Find the appropriate CUDA version range
for start, end, cuda_version in version_ranges:
if start <= jax_ver <= end:
return cuda_version

# Default to a safe version if no range matches
raise ValueError(f"Unsupported JAX version: {jax_version}")


# register the operations to xla
for _name, _value in sphericart_jax_cpu.registrations().items():
jax.lib.xla_client.register_custom_call_target(_name, _value, platform="cpu")

has_sphericart_jax_cuda = False
try:
from .lib import sphericart_jax_cuda
from .lib.sphericart_jax_cuda import get_cuda_runtime_version

cuda_version = get_cuda_runtime_version()
cuda_version = (cuda_version["major"], cuda_version["minor"])
jax_version = jax.__version__
required_version = get_minimum_cuda_version_for_jax(jax_version)
if cuda_version < required_version:
raise RuntimeError(
f"Installed CUDA Toolkit: {cuda_version[0]}.{cuda_version[1]} \
is not compatible with installed JAX version {jax_version}. \
Minimum required CUDA Toolkit for your JAX version \
is {required_version[0]}.{required_version[1]}. \
Please upgrade your CUDA Toolkit to meet the requirements."
)

has_sphericart_jax_cuda = True
# register the operations to xla
for _name, _value in sphericart_jax_cuda.registrations().items():
jax.lib.xla_client.register_custom_call_target(_name, _value, platform="gpu")

except ImportError:
pass

if (has_sphericart_jax_cuda):
from .lib.sphericart_jax_cuda import get_cuda_runtime_version
#check the jaxlib version is suitable for the host cudatoolkit.
cuda_version = get_cuda_runtime_version()
cuda_version = (cuda_version["major"], cuda_version["minor"])
jax_version = jax.__version__
required_version = get_minimum_cuda_version_for_jax(jax_version)
if cuda_version < required_version:
raise RuntimeError(
f"Installed CUDA Toolkit: {cuda_version[0]}.{cuda_version[1]} \
is not compatible with installed JAX version {jax_version}. \
Minimum required CUDA Toolkit for your JAX version \
is {required_version[0]}.{required_version[1]}. \
Please upgrade your CUDA Toolkit to meet the requirements."
)

0 comments on commit cad746a

Please sign in to comment.