diff --git a/src/unxt/_src/quantity/register_primitives.py b/src/unxt/_src/quantity/register_primitives.py index d3fe5ed0..25ee4b08 100644 --- a/src/unxt/_src/quantity/register_primitives.py +++ b/src/unxt/_src/quantity/register_primitives.py @@ -30,6 +30,7 @@ from .base_parametric import AbstractParametricQuantity from .core import Quantity from unxt._src.units import unit, unit_of +from unxt._src.utils import promote_dtypes T = TypeVar("T") @@ -143,7 +144,7 @@ def _add_p_aqaq(x: AbstractQuantity, y: AbstractQuantity) -> AbstractQuantity: >>> import quaxed.numpy as jnp >>> from unxt.quantity import UncheckedQuantity - >>> q1 = UncheckedQuantity(1.0, "km") + >>> q1 = UncheckedQuantity(1, "km") >>> q2 = UncheckedQuantity(500.0, "m") >>> jnp.add(q1, q2) UncheckedQuantity(Array(1.5, dtype=float32, ...), unit='km') @@ -151,7 +152,7 @@ def _add_p_aqaq(x: AbstractQuantity, y: AbstractQuantity) -> AbstractQuantity: UncheckedQuantity(Array(1.5, dtype=float32, ...), unit='km') >>> from unxt.quantity import Quantity - >>> q1 = Quantity(1.0, "km") + >>> q1 = Quantity(1, "km") >>> q2 = Quantity(500.0, "m") >>> jnp.add(q1, q2) Quantity['length'](Array(1.5, dtype=float32, ...), unit='km') @@ -170,7 +171,7 @@ def _add_p_vaq(x: ArrayLike, y: AbstractQuantity) -> AbstractQuantity: -------- >>> import quaxed.numpy as jnp - >>> x = jnp.asarray(500.0) + >>> x = jnp.asarray(500) `unxt.UncheckedQuantity`: @@ -237,7 +238,7 @@ def _add_p_aqv(x: AbstractQuantity, y: ArrayLike) -> AbstractQuantity: -------- >>> import quaxed.numpy as jnp - >>> y = jnp.asarray(500.0) + >>> y = jnp.asarray(500) `unxt.UncheckedQuantity`: @@ -317,7 +318,7 @@ def _add_any_p( >>> import quaxed.numpy as jnp >>> import unxt as u - >>> q1 = u.Quantity(1.0, "km") + >>> q1 = u.Quantity(1, "km") >>> q2 = u.Quantity(500.0, "m") >>> @jax.jit @@ -495,14 +496,14 @@ def _atan2_p_aqaq(x: AbstractQuantity, y: AbstractQuantity) -> AbstractQuantity: >>> import quaxed.numpy as jnp >>> from unxt.quantity import UncheckedQuantity >>> q1 = UncheckedQuantity(1, "m") - >>> q2 = UncheckedQuantity(3, "m") + >>> q2 = UncheckedQuantity(3.0, "m") >>> jnp.atan2(q1, q2) UncheckedQuantity(Array(0.32175055, dtype=float32, ...), unit='rad') """ x, y = promote(x, y) # e.g. Distance -> Quantity - y_ = ustrip(x.unit, y) - return type_np(x)(lax.atan2(ustrip(x), y_), unit=radian) + yv = ustrip(x.unit, y) + return type_np(x)(lax.atan2(ustrip(x), yv), unit=radian) @register(lax.atan2_p) @@ -516,14 +517,14 @@ def _atan2_p_qq( >>> import quaxed.numpy as jnp >>> from unxt.quantity import Quantity >>> q1 = Quantity(1, "m") - >>> q2 = Quantity(3, "m") + >>> q2 = Quantity(3.0, "m") >>> jnp.atan2(q1, q2) Quantity['angle'](Array(0.32175055, dtype=float32, ...), unit='rad') """ x, y = promote(x, y) # e.g. Distance -> Quantity - y_ = ustrip(x.unit, y) - return type_np(x)(lax.atan2(ustrip(x), y_), unit=radian) + yv = ustrip(x.unit, y) + return type_np(x)(lax.atan2(ustrip(x), yv), unit=radian) # --------------------------- @@ -538,13 +539,13 @@ def _atan2_p_vaq(x: ArrayLike, y: AbstractQuantity) -> AbstractQuantity: >>> import quaxed.numpy as jnp >>> from unxt.quantity import UncheckedQuantity >>> x1 = jnp.asarray(1.0) - >>> q2 = UncheckedQuantity(3.0, "") + >>> q2 = UncheckedQuantity(3, "") >>> jnp.atan2(x1, q2) UncheckedQuantity(Array(0.32175055, dtype=float32, ...), unit='rad') """ - y_ = ustrip(one, y) - return type_np(y)(lax.atan2(x, y_), unit=radian) + yv = ustrip(one, y) + return type_np(y)(lax.atan2(x, yv), unit=radian) @register(lax.atan2_p) @@ -558,13 +559,13 @@ def _atan2_p_vq( >>> import quaxed.numpy as jnp >>> from unxt.quantity import Quantity >>> x1 = jnp.asarray(1.0) - >>> q2 = Quantity(3.0, "") + >>> q2 = Quantity(3, "") >>> jnp.atan2(x1, q2) Quantity['angle'](Array(0.32175055, dtype=float32, ...), unit='rad') """ - y_ = ustrip(one, y) - return Quantity(lax.atan2(x, y_), unit=radian) + yv = ustrip(one, y) + return Quantity(lax.atan2(x, yv), unit=radian) # --------------------------- @@ -579,13 +580,13 @@ def _atan2_p_aqv(x: AbstractQuantity, y: ArrayLike) -> AbstractQuantity: >>> import quaxed.numpy as jnp >>> from unxt.quantity import UncheckedQuantity >>> q1 = UncheckedQuantity(1.0, "") - >>> x2 = jnp.asarray(3.0) + >>> x2 = jnp.asarray(3) >>> jnp.atan2(q1, x2) UncheckedQuantity(Array(0.32175055, dtype=float32, ...), unit='rad') """ - x_ = ustrip(one, x) - return type_np(x)(lax.atan2(x_, y), unit=radian) + xv = ustrip(one, x) + return type_np(x)(lax.atan2(xv, y), unit=radian) @register(lax.atan2_p) @@ -599,13 +600,13 @@ def _atan2_p_qv( >>> import quaxed.numpy as jnp >>> from unxt.quantity import Quantity >>> q1 = Quantity(1.0, "") - >>> x2 = jnp.asarray(3.0) + >>> x2 = jnp.asarray(3) >>> jnp.atan2(q1, x2) Quantity['angle'](Array(0.32175055, dtype=float32, ...), unit='rad') """ - x_ = ustrip(one, x) - return type_np(x)(lax.atan2(x_, y), unit=radian) + xv = ustrip(one, x) + return type_np(x)(lax.atan2(xv, y), unit=radian) # ============================================================================== @@ -761,6 +762,9 @@ def _clamp_p( >>> lax.clamp(min, q, max) UncheckedQuantity(Array([0, 1, 2], dtype=int32), unit='m') + >>> jnp.clip(q.astype(float), min, max) + UncheckedQuantity(Array([0., 1., 2.], dtype=float32), unit='m') + >>> from unxt.quantity import Quantity >>> min = Quantity(0, "m") >>> max = Quantity(2, "m") @@ -768,6 +772,9 @@ def _clamp_p( >>> lax.clamp(min, q, max) Quantity['length'](Array([0, 1, 2], dtype=int32), unit='m') + >>> jnp.clip(q.astype(float), min, max) + Quantity['length'](Array([0., 1., 2.], dtype=float32), unit='m') + """ return replace( x, value=qlax.clamp(ustrip(x.unit, min), ustrip(x), ustrip(x.unit, max)) @@ -958,12 +965,10 @@ def _concatenate_p_aq(*operands: AbstractQuantity, dimension: Any) -> AbstractQu """ operand0 = operands[0] - units_ = operand0.unit + u = operand0.unit return replace( operand0, - value=qlax.concatenate( - [ustrip(units_, op) for op in operands], dimension=dimension - ), + value=qlax.concatenate([ustrip(u, op) for op in operands], dimension=dimension), ) @@ -1484,8 +1489,8 @@ def _div_p_vq(x: ArrayLike, y: AbstractQuantity) -> AbstractQuantity: Quantity['wavenumber'](Array([0.5, 1. , 1.5], dtype=float32), unit='1 / m') """ - units_ = (1 / y.unit).unit # TODO: better construction of the unit - return type_np(y)(lax.div(x, ustrip(y)), unit=units_) + u = (1 / y.unit).unit # TODO: better construction of the unit + return type_np(y)(lax.div(x, ustrip(y)), unit=u) @register(lax.div_p) @@ -1495,21 +1500,21 @@ def _div_p_qv(x: AbstractQuantity, y: ArrayLike) -> AbstractQuantity: Examples -------- >>> import quaxed.numpy as jnp - >>> y = jnp.asarray([1.0, 2, 3]) + >>> y = jnp.asarray([1, 2, 3]) >>> from unxt.quantity import UncheckedQuantity >>> q = UncheckedQuantity(6.0, "m") >>> jnp.divide(q, y) - UncheckedQuantity(Array([6., 3., 2.], dtype=float32), unit='m') + UncheckedQuantity(Array([6., 3., 2.], dtype=float32, ...), unit='m') >>> q / y - UncheckedQuantity(Array([6., 3., 2.], dtype=float32), unit='m') + UncheckedQuantity(Array([6., 3., 2.], dtype=float32, ...), unit='m') >>> from unxt.quantity import Quantity >>> q = Quantity(6.0, "m") >>> jnp.divide(q, y) - Quantity['length'](Array([6., 3., 2.], dtype=float32), unit='m') + Quantity['length'](Array([6., 3., 2.], dtype=float32, ...), unit='m') >>> q / y - Quantity['length'](Array([6., 3., 2.], dtype=float32), unit='m') + Quantity['length'](Array([6., 3., 2.], dtype=float32, ...), unit='m') """ return replace(x, value=qlax.div(ustrip(x), y)) @@ -2536,12 +2541,22 @@ def _lt_p_qq(x: AbstractQuantity, y: AbstractQuantity, /) -> ArrayLike: Array([ True, False, False], dtype=bool) """ + # Check if the units are convertible. x = eqx.error_if( # TODO: customize Exception type x, not is_unit_convertible(x.unit, y.unit), f"Cannot compare Q(x, {x.unit}) < Q(y, {y.unit}).", ) - return qlax.lt(ustrip(x), ustrip(x.unit, y)) # re-dispatch on the values + # Strip the units to compare the values. + xv = ustrip(x) + yv = ustrip(x.unit, y) # this can change the dtype + # `lax.lt` requires that the dtypes are the same. Since `ustrip` can change + # the dtype, we need to special-case the situation where the dtypes started + # off the same, but `ustrip` changed them. + if x.dtype == y.dtype and xv.dtype != yv.dtype: + xv, yv = promote_dtypes(xv, yv) + + return qlax.lt(xv, yv) # re-dispatch on the values @register(lax.lt_p) diff --git a/src/unxt/_src/utils.py b/src/unxt/_src/utils.py index 07a6891f..2b389c44 100644 --- a/src/unxt/_src/utils.py +++ b/src/unxt/_src/utils.py @@ -7,6 +7,12 @@ from typing import TYPE_CHECKING, Any, cast +from jax import dtypes +from jaxtyping import ArrayLike +from quax import quaxify + +import quaxed.lax as qlax + if TYPE_CHECKING: from typing import Self @@ -39,3 +45,30 @@ def __new__(cls, /, *_: Any, **__: Any) -> "Self": self = object.__new__(cls) _singleton_insts[cls] = self return self + + +@quaxify # type: ignore[misc] +def promote_dtypes(*arrays: ArrayLike) -> tuple[ArrayLike, ...]: + """Promotes all input arrays to a common dtype. + + Examples + -------- + >>> import jax.numpy as jnp + >>> import unxt as u + + >>> x1 = jnp.array([1, 2, 3], dtype=jnp.int32) + >>> x2 = jnp.array([4, 5, 6], dtype=jnp.float32) + + >>> x1, x2 = promote_dtypes(x1, x2) + >>> x1.dtype, x2.dtype + (dtype('float32'), dtype('float32')) + + >>> q1 = u.Quantity.from_([1, 2, 3], "m", dtype=int) + >>> q2 = u.Quantity([4.0, 5, 6], unit="km") + >>> q1, q2 = promote_dtypes(q1, q2) + >>> q1.dtype, q2.dtype + (dtype('float32'), dtype('float32')) + + """ + common_dtype = dtypes.result_type(*arrays) + return tuple(qlax.convert_element_type(arr, common_dtype) for arr in arrays)