diff --git a/src/jax_quantity/_register_primitives.py b/src/jax_quantity/_register_primitives.py index dbf0185f..be065489 100644 --- a/src/jax_quantity/_register_primitives.py +++ b/src/jax_quantity/_register_primitives.py @@ -367,9 +367,10 @@ def _concatenate_p_jqnd( ... [xp.sin(theta), xp.cos(theta), 0], ... [0, 0, 1]]) >>> Rz - Quantity['dimensionless'](Array([[ 0.70710678, -0.70710678, 0. ], - [ 0.70710678, 0.70710678, 0. ], - [ 0. , 0. , 1. ]], dtype=float64), unit='') + Quantity[...](Array([[ 0.70710677, -0.70710677, 0. ], + [ 0.70710677, 0.70710677, 0. ], + [ 0. , 0. , 1. ]], + dtype=float32), unit='') """ return Quantity( @@ -565,7 +566,7 @@ def _dot_general_jq( ... [0, 0, 1]]) >>> q = Quantity([1, 0, 0], "m") >>> Rz @ q - Quantity['length'](Array([0.70710678, 0.70710678, 0. ], dtype=float64), unit='m') + Quantity['length'](Array([0.70710677, 0.70710677, 0. ], dtype=float32), unit='m') """ return Quantity( lax.dot_general_p.bind( diff --git a/tests/test_array_api_jax_compat.py b/tests/test_array_api_jax_compat.py index 35d3cfbb..98e31018 100644 --- a/tests/test_array_api_jax_compat.py +++ b/tests/test_array_api_jax_compat.py @@ -1448,6 +1448,7 @@ def test_min(): assert jnp.array_equal(got.value, expected.value) +@pytest.mark.filterwarnings("ignore:Explicitly requested dtype") # TODO: Why? def test_prod(): """Test `prod`.""" x = Quantity(xp.asarray([1, 2, 3], dtype=float), u.m) @@ -1470,6 +1471,7 @@ def test_std(): assert jnp.array_equal(got.value, expected.value) +@pytest.mark.filterwarnings("ignore:Explicitly requested dtype") # TODO: Why? def test_sum(): """Test `sum`.""" x = Quantity(xp.asarray([1, 2, 3], dtype=float), u.m) diff --git a/tests/test_quantity.py b/tests/test_quantity.py index e9541c81..89c0b406 100644 --- a/tests/test_quantity.py +++ b/tests/test_quantity.py @@ -10,10 +10,8 @@ import pytest from hypothesis import example, given, strategies as st from hypothesis.extra.array_api import make_strategies_namespace -from hypothesis.extra.numpy import ( - array_shapes as np_array_shapes, - arrays as np_arrays, -) +from hypothesis.extra.numpy import array_shapes as np_array_shapes, arrays as np_arrays +from jax.dtypes import canonicalize_dtype from quax import quaxify import array_api_jax_compat @@ -22,13 +20,15 @@ xps = make_strategies_namespace(jax_xp) -jax.config.update("jax_enable_x64", val=True) + +jaxint = canonicalize_dtype(int) +jaxfloat = canonicalize_dtype(float) integers_strategy = st.integers( - min_value=np.iinfo(np.int64).min, max_value=np.iinfo(np.int64).max + min_value=np.iinfo(jaxint).min, max_value=np.iinfo(jaxint).max ) floats_strategy = st.floats( - min_value=np.finfo(np.float64).min, max_value=np.finfo(np.float64).max + min_value=np.finfo(jaxfloat).min, max_value=np.finfo(jaxfloat).max ) @@ -42,12 +42,12 @@ # | st.lists(st.lists(integers_strategy)) # TODO: enable nested lists # | st.lists(st.lists(floats_strategy)) | np_arrays( - dtype=np.float64, + dtype=np.float32, shape=np_array_shapes(), elements={"allow_nan": False, "allow_infinity": False}, ) | xps.arrays( - dtype=xps.floating_dtypes(), + dtype=np.float32, shape=xps.array_shapes(), elements={"allow_nan": False, "allow_infinity": False}, )