Skip to content

Commit

Permalink
minor edits.
Browse files Browse the repository at this point in the history
  • Loading branch information
nickjbrowning committed Jan 15, 2024
1 parent cfa089c commit c87d07a
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions sphericart-jax/python/tests/simple_jax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ def xyz():
def compute(xyz):
sph = spherical_harmonics(l_max=4, normalized=False, xyz=xyz)
assert jnp.linalg.norm(sph) != 0.0
return sph.sum()
return sph

xyzs = jax.device_put(xyz(), device=jax.devices('gpu')[0])

sph = compute(xyz())

print ("sum sph:", sph)
print ("cuda jax succesful")
print (sph)
print ("sum sph:", sph.sum())
print ("cuda jax successful")

0 comments on commit c87d07a

Please sign in to comment.