Skip to content

Commit

Permalink
test: test on float32 (#51)
Browse files Browse the repository at this point in the history
* test: test on float32

Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Mar 3, 2024
1 parent 0f1bd4f commit ff717da
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 13 deletions.
9 changes: 5 additions & 4 deletions src/jax_quantity/_register_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,10 @@ def _concatenate_p_jqnd(
... [xp.sin(theta), xp.cos(theta), 0],
... [0, 0, 1]])
>>> Rz
Quantity['dimensionless'](Array([[ 0.70710678, -0.70710678, 0. ],
[ 0.70710678, 0.70710678, 0. ],
[ 0. , 0. , 1. ]], dtype=float64), unit='')
Quantity[...](Array([[ 0.70710677, -0.70710677, 0. ],
[ 0.70710677, 0.70710677, 0. ],
[ 0. , 0. , 1. ]],
dtype=float32), unit='')
"""
return Quantity(
Expand Down Expand Up @@ -565,7 +566,7 @@ def _dot_general_jq(
... [0, 0, 1]])
>>> q = Quantity([1, 0, 0], "m")
>>> Rz @ q
Quantity['length'](Array([0.70710678, 0.70710678, 0. ], dtype=float64), unit='m')
Quantity['length'](Array([0.70710677, 0.70710677, 0. ], dtype=float32), unit='m')
"""
return Quantity(
lax.dot_general_p.bind(
Expand Down
2 changes: 2 additions & 0 deletions tests/test_array_api_jax_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,6 +1448,7 @@ def test_min():
assert jnp.array_equal(got.value, expected.value)


@pytest.mark.filterwarnings("ignore:Explicitly requested dtype") # TODO: Why?
def test_prod():
"""Test `prod`."""
x = Quantity(xp.asarray([1, 2, 3], dtype=float), u.m)
Expand All @@ -1470,6 +1471,7 @@ def test_std():
assert jnp.array_equal(got.value, expected.value)


@pytest.mark.filterwarnings("ignore:Explicitly requested dtype") # TODO: Why?
def test_sum():
"""Test `sum`."""
x = Quantity(xp.asarray([1, 2, 3], dtype=float), u.m)
Expand Down
18 changes: 9 additions & 9 deletions tests/test_quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@
import pytest
from hypothesis import example, given, strategies as st
from hypothesis.extra.array_api import make_strategies_namespace
from hypothesis.extra.numpy import (
array_shapes as np_array_shapes,
arrays as np_arrays,
)
from hypothesis.extra.numpy import array_shapes as np_array_shapes, arrays as np_arrays
from jax.dtypes import canonicalize_dtype
from quax import quaxify

import array_api_jax_compat
Expand All @@ -22,13 +20,15 @@

xps = make_strategies_namespace(jax_xp)

jax.config.update("jax_enable_x64", val=True)

jaxint = canonicalize_dtype(int)
jaxfloat = canonicalize_dtype(float)

integers_strategy = st.integers(
min_value=np.iinfo(np.int64).min, max_value=np.iinfo(np.int64).max
min_value=np.iinfo(jaxint).min, max_value=np.iinfo(jaxint).max
)
floats_strategy = st.floats(
min_value=np.finfo(np.float64).min, max_value=np.finfo(np.float64).max
min_value=np.finfo(jaxfloat).min, max_value=np.finfo(jaxfloat).max
)


Expand All @@ -42,12 +42,12 @@
# | st.lists(st.lists(integers_strategy)) # TODO: enable nested lists
# | st.lists(st.lists(floats_strategy))
| np_arrays(
dtype=np.float64,
dtype=np.float32,
shape=np_array_shapes(),
elements={"allow_nan": False, "allow_infinity": False},
)
| xps.arrays(
dtype=xps.floating_dtypes(),
dtype=np.float32,
shape=xps.array_shapes(),
elements={"allow_nan": False, "allow_infinity": False},
)
Expand Down

0 comments on commit ff717da

Please sign in to comment.