Skip to content

Commit

Permalink
Backport PR #342: ✨ feat(promotion): enable promotion to common dtype (
Browse files Browse the repository at this point in the history
…#343)

Co-authored-by: Nathaniel Starkman <[email protected]>
  • Loading branch information
meeseeksmachine and nstarman authored Dec 18, 2024
1 parent 4ea4978 commit d0c0804
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 35 deletions.
85 changes: 50 additions & 35 deletions src/unxt/_src/quantity/register_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .base_parametric import AbstractParametricQuantity
from .core import Quantity
from unxt._src.units import unit, unit_of
from unxt._src.utils import promote_dtypes

T = TypeVar("T")

Expand Down Expand Up @@ -143,15 +144,15 @@ def _add_p_aqaq(x: AbstractQuantity, y: AbstractQuantity) -> AbstractQuantity:
>>> import quaxed.numpy as jnp
>>> from unxt.quantity import UncheckedQuantity
>>> q1 = UncheckedQuantity(1.0, "km")
>>> q1 = UncheckedQuantity(1, "km")
>>> q2 = UncheckedQuantity(500.0, "m")
>>> jnp.add(q1, q2)
UncheckedQuantity(Array(1.5, dtype=float32, ...), unit='km')
>>> q1 + q2
UncheckedQuantity(Array(1.5, dtype=float32, ...), unit='km')
>>> from unxt.quantity import Quantity
>>> q1 = Quantity(1.0, "km")
>>> q1 = Quantity(1, "km")
>>> q2 = Quantity(500.0, "m")
>>> jnp.add(q1, q2)
Quantity['length'](Array(1.5, dtype=float32, ...), unit='km')
Expand All @@ -170,7 +171,7 @@ def _add_p_vaq(x: ArrayLike, y: AbstractQuantity) -> AbstractQuantity:
--------
>>> import quaxed.numpy as jnp
>>> x = jnp.asarray(500.0)
>>> x = jnp.asarray(500)
`unxt.UncheckedQuantity`:
Expand Down Expand Up @@ -237,7 +238,7 @@ def _add_p_aqv(x: AbstractQuantity, y: ArrayLike) -> AbstractQuantity:
--------
>>> import quaxed.numpy as jnp
>>> y = jnp.asarray(500.0)
>>> y = jnp.asarray(500)
`unxt.UncheckedQuantity`:
Expand Down Expand Up @@ -317,7 +318,7 @@ def _add_any_p(
>>> import quaxed.numpy as jnp
>>> import unxt as u
>>> q1 = u.Quantity(1.0, "km")
>>> q1 = u.Quantity(1, "km")
>>> q2 = u.Quantity(500.0, "m")
>>> @jax.jit
Expand Down Expand Up @@ -495,14 +496,14 @@ def _atan2_p_aqaq(x: AbstractQuantity, y: AbstractQuantity) -> AbstractQuantity:
>>> import quaxed.numpy as jnp
>>> from unxt.quantity import UncheckedQuantity
>>> q1 = UncheckedQuantity(1, "m")
>>> q2 = UncheckedQuantity(3, "m")
>>> q2 = UncheckedQuantity(3.0, "m")
>>> jnp.atan2(q1, q2)
UncheckedQuantity(Array(0.32175055, dtype=float32, ...), unit='rad')
"""
x, y = promote(x, y) # e.g. Distance -> Quantity
y_ = ustrip(x.unit, y)
return type_np(x)(lax.atan2(ustrip(x), y_), unit=radian)
yv = ustrip(x.unit, y)
return type_np(x)(lax.atan2(ustrip(x), yv), unit=radian)


@register(lax.atan2_p)
Expand All @@ -516,14 +517,14 @@ def _atan2_p_qq(
>>> import quaxed.numpy as jnp
>>> from unxt.quantity import Quantity
>>> q1 = Quantity(1, "m")
>>> q2 = Quantity(3, "m")
>>> q2 = Quantity(3.0, "m")
>>> jnp.atan2(q1, q2)
Quantity['angle'](Array(0.32175055, dtype=float32, ...), unit='rad')
"""
x, y = promote(x, y) # e.g. Distance -> Quantity
y_ = ustrip(x.unit, y)
return type_np(x)(lax.atan2(ustrip(x), y_), unit=radian)
yv = ustrip(x.unit, y)
return type_np(x)(lax.atan2(ustrip(x), yv), unit=radian)


# ---------------------------
Expand All @@ -538,13 +539,13 @@ def _atan2_p_vaq(x: ArrayLike, y: AbstractQuantity) -> AbstractQuantity:
>>> import quaxed.numpy as jnp
>>> from unxt.quantity import UncheckedQuantity
>>> x1 = jnp.asarray(1.0)
>>> q2 = UncheckedQuantity(3.0, "")
>>> q2 = UncheckedQuantity(3, "")
>>> jnp.atan2(x1, q2)
UncheckedQuantity(Array(0.32175055, dtype=float32, ...), unit='rad')
"""
y_ = ustrip(one, y)
return type_np(y)(lax.atan2(x, y_), unit=radian)
yv = ustrip(one, y)
return type_np(y)(lax.atan2(x, yv), unit=radian)


@register(lax.atan2_p)
Expand All @@ -558,13 +559,13 @@ def _atan2_p_vq(
>>> import quaxed.numpy as jnp
>>> from unxt.quantity import Quantity
>>> x1 = jnp.asarray(1.0)
>>> q2 = Quantity(3.0, "")
>>> q2 = Quantity(3, "")
>>> jnp.atan2(x1, q2)
Quantity['angle'](Array(0.32175055, dtype=float32, ...), unit='rad')
"""
y_ = ustrip(one, y)
return Quantity(lax.atan2(x, y_), unit=radian)
yv = ustrip(one, y)
return Quantity(lax.atan2(x, yv), unit=radian)


# ---------------------------
Expand All @@ -579,13 +580,13 @@ def _atan2_p_aqv(x: AbstractQuantity, y: ArrayLike) -> AbstractQuantity:
>>> import quaxed.numpy as jnp
>>> from unxt.quantity import UncheckedQuantity
>>> q1 = UncheckedQuantity(1.0, "")
>>> x2 = jnp.asarray(3.0)
>>> x2 = jnp.asarray(3)
>>> jnp.atan2(q1, x2)
UncheckedQuantity(Array(0.32175055, dtype=float32, ...), unit='rad')
"""
x_ = ustrip(one, x)
return type_np(x)(lax.atan2(x_, y), unit=radian)
xv = ustrip(one, x)
return type_np(x)(lax.atan2(xv, y), unit=radian)


@register(lax.atan2_p)
Expand All @@ -599,13 +600,13 @@ def _atan2_p_qv(
>>> import quaxed.numpy as jnp
>>> from unxt.quantity import Quantity
>>> q1 = Quantity(1.0, "")
>>> x2 = jnp.asarray(3.0)
>>> x2 = jnp.asarray(3)
>>> jnp.atan2(q1, x2)
Quantity['angle'](Array(0.32175055, dtype=float32, ...), unit='rad')
"""
x_ = ustrip(one, x)
return type_np(x)(lax.atan2(x_, y), unit=radian)
xv = ustrip(one, x)
return type_np(x)(lax.atan2(xv, y), unit=radian)


# ==============================================================================
Expand Down Expand Up @@ -761,13 +762,19 @@ def _clamp_p(
>>> lax.clamp(min, q, max)
UncheckedQuantity(Array([0, 1, 2], dtype=int32), unit='m')
>>> jnp.clip(q.astype(float), min, max)
UncheckedQuantity(Array([0., 1., 2.], dtype=float32), unit='m')
>>> from unxt.quantity import Quantity
>>> min = Quantity(0, "m")
>>> max = Quantity(2, "m")
>>> q = Quantity([-1, 1, 3], "m")
>>> lax.clamp(min, q, max)
Quantity['length'](Array([0, 1, 2], dtype=int32), unit='m')
>>> jnp.clip(q.astype(float), min, max)
Quantity['length'](Array([0., 1., 2.], dtype=float32), unit='m')
"""
return replace(
x, value=qlax.clamp(ustrip(x.unit, min), ustrip(x), ustrip(x.unit, max))
Expand Down Expand Up @@ -958,12 +965,10 @@ def _concatenate_p_aq(*operands: AbstractQuantity, dimension: Any) -> AbstractQu
"""
operand0 = operands[0]
units_ = operand0.unit
u = operand0.unit
return replace(
operand0,
value=qlax.concatenate(
[ustrip(units_, op) for op in operands], dimension=dimension
),
value=qlax.concatenate([ustrip(u, op) for op in operands], dimension=dimension),
)


Expand Down Expand Up @@ -1484,8 +1489,8 @@ def _div_p_vq(x: ArrayLike, y: AbstractQuantity) -> AbstractQuantity:
Quantity['wavenumber'](Array([0.5, 1. , 1.5], dtype=float32), unit='1 / m')
"""
units_ = (1 / y.unit).unit # TODO: better construction of the unit
return type_np(y)(lax.div(x, ustrip(y)), unit=units_)
u = (1 / y.unit).unit # TODO: better construction of the unit
return type_np(y)(lax.div(x, ustrip(y)), unit=u)


@register(lax.div_p)
Expand All @@ -1495,21 +1500,21 @@ def _div_p_qv(x: AbstractQuantity, y: ArrayLike) -> AbstractQuantity:
Examples
--------
>>> import quaxed.numpy as jnp
>>> y = jnp.asarray([1.0, 2, 3])
>>> y = jnp.asarray([1, 2, 3])
>>> from unxt.quantity import UncheckedQuantity
>>> q = UncheckedQuantity(6.0, "m")
>>> jnp.divide(q, y)
UncheckedQuantity(Array([6., 3., 2.], dtype=float32), unit='m')
UncheckedQuantity(Array([6., 3., 2.], dtype=float32, ...), unit='m')
>>> q / y
UncheckedQuantity(Array([6., 3., 2.], dtype=float32), unit='m')
UncheckedQuantity(Array([6., 3., 2.], dtype=float32, ...), unit='m')
>>> from unxt.quantity import Quantity
>>> q = Quantity(6.0, "m")
>>> jnp.divide(q, y)
Quantity['length'](Array([6., 3., 2.], dtype=float32), unit='m')
Quantity['length'](Array([6., 3., 2.], dtype=float32, ...), unit='m')
>>> q / y
Quantity['length'](Array([6., 3., 2.], dtype=float32), unit='m')
Quantity['length'](Array([6., 3., 2.], dtype=float32, ...), unit='m')
"""
return replace(x, value=qlax.div(ustrip(x), y))
Expand Down Expand Up @@ -2536,12 +2541,22 @@ def _lt_p_qq(x: AbstractQuantity, y: AbstractQuantity, /) -> ArrayLike:
Array([ True, False, False], dtype=bool)
"""
# Check if the units are convertible.
x = eqx.error_if( # TODO: customize Exception type
x,
not is_unit_convertible(x.unit, y.unit),
f"Cannot compare Q(x, {x.unit}) < Q(y, {y.unit}).",
)
return qlax.lt(ustrip(x), ustrip(x.unit, y)) # re-dispatch on the values
# Strip the units to compare the values.
xv = ustrip(x)
yv = ustrip(x.unit, y) # this can change the dtype
# `lax.lt` requires that the dtypes are the same. Since `ustrip` can change
# the dtype, we need to special-case the situation where the dtypes started
# off the same, but `ustrip` changed them.
if x.dtype == y.dtype and xv.dtype != yv.dtype:
xv, yv = promote_dtypes(xv, yv)

return qlax.lt(xv, yv) # re-dispatch on the values


@register(lax.lt_p)
Expand Down
33 changes: 33 additions & 0 deletions src/unxt/_src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@

from typing import TYPE_CHECKING, Any, cast

from jax import dtypes
from jaxtyping import ArrayLike
from quax import quaxify

import quaxed.lax as qlax

if TYPE_CHECKING:
from typing import Self

Expand Down Expand Up @@ -39,3 +45,30 @@ def __new__(cls, /, *_: Any, **__: Any) -> "Self":
self = object.__new__(cls)
_singleton_insts[cls] = self
return self


@quaxify # type: ignore[misc]
def promote_dtypes(*arrays: ArrayLike) -> tuple[ArrayLike, ...]:
"""Promotes all input arrays to a common dtype.
Examples
--------
>>> import jax.numpy as jnp
>>> import unxt as u
>>> x1 = jnp.array([1, 2, 3], dtype=jnp.int32)
>>> x2 = jnp.array([4, 5, 6], dtype=jnp.float32)
>>> x1, x2 = promote_dtypes(x1, x2)
>>> x1.dtype, x2.dtype
(dtype('float32'), dtype('float32'))
>>> q1 = u.Quantity.from_([1, 2, 3], "m", dtype=int)
>>> q2 = u.Quantity([4.0, 5, 6], unit="km")
>>> q1, q2 = promote_dtypes(q1, q2)
>>> q1.dtype, q2.dtype
(dtype('float32'), dtype('float32'))
"""
common_dtype = dtypes.result_type(*arrays)
return tuple(qlax.convert_element_type(arr, common_dtype) for arr in arrays)

0 comments on commit d0c0804

Please sign in to comment.