Skip to content

Commit

Permalink
missing conditions in setup, segfault test
Browse files Browse the repository at this point in the history
  • Loading branch information
nickjbrowning committed Jan 14, 2024
1 parent 2b70a15 commit 343c333
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def run(self):
f"-DSPHERICART_ARCH_NATIVE={SPHERICART_ARCH_NATIVE}",
]

CUDA_HOME = os.environ.get("CUDA_HOME")
if CUDA_HOME is not None:
cmake_options.append(f"-DCUDA_TOOLKIT_ROOT_DIR={CUDA_HOME}")

if sys.platform.startswith("darwin"):
cmake_options.append("-DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=11.0")

Expand Down
22 changes: 22 additions & 0 deletions sphericart-jax/python/tests/segfault_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest
import jax

# jax.config.update("jax_platform_name", "cpu")
import jax.numpy as jnp
import jax._src.test_util as jtu
import sphericart.jax

def xyz():
key = jax.random.PRNGKey(0)
return 6 * jax.random.normal(key, (100, 3))


def compute(xyz):
sph = sphericart.jax.spherical_harmonics(l_max=4, normalized=False, xyz=xyz)
assert jnp.linalg.norm(sph) != 0.0
return sph.sum()


sph = compute(xyz())

print (sph)
5 changes: 4 additions & 1 deletion sphericart-jax/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ def run(self):
f"-DPYTHON_EXECUTABLE={sys.executable}",
f"-DSPHERICART_ARCH_NATIVE={SPHERICART_ARCH_NATIVE}",
f"-DCMAKE_PREFIX_PATH={';'.join(cmake_prefix_path)}",
"-DSPHERICART_ENABLE_CUDA=ON"
]

CUDA_HOME = os.environ.get("CUDA_HOME")
if CUDA_HOME is not None:
cmake_options.append(f"-DCUDA_TOOLKIT_ROOT_DIR={CUDA_HOME}")

if sys.platform.startswith("darwin"):
cmake_options.append("-DCMAKE_OSX_DEPLOYMENT_TARGET:STRING=11.0")

Expand Down

0 comments on commit 343c333

Please sign in to comment.