diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c585724..d63cdd9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -80,6 +80,7 @@ repos:
files: src
additional_dependencies:
- plum-dispatch>=2.5.6
+ - quaxed>=0.8.1
- repo: https://github.com/codespell-project/codespell
rev: "v2.3.0"
@@ -91,11 +92,3 @@ repos:
# hooks:
# - id: validate-pyproject
# additional_dependencies: ["validate-pyproject-schema-store[all]"]
-
- - repo: local
- hooks:
- - id: disallow-caps
- name: Disallow improper capitalization
- language: pygrep
- entry: PyBind|Numpy|Cmake|CCache|Github|PyTest
- exclude: .pre-commit-config.yaml
diff --git a/pyproject.toml b/pyproject.toml
index 547e7d3..da888de 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,7 +36,7 @@
"optional-dependencies>=0.3.2",
"plum-dispatch>=2.5.6",
"quax>=0.0.5",
- "quaxed>=0.7.1",
+ "quaxed>=0.8.1",
"xmmutablemap>=0.1",
"zeroth>=1.0.0",
]
diff --git a/src/unxt/_src/experimental.py b/src/unxt/_src/experimental.py
index 0a6c524..1a4a468 100644
--- a/src/unxt/_src/experimental.py
+++ b/src/unxt/_src/experimental.py
@@ -24,19 +24,20 @@
from collections.abc import Callable
from functools import partial
-from typing import Any, ParamSpec, TypeVar
+from typing import Any, TypeVar, TypeVarTuple
+from typing_extensions import Unpack
import equinox as eqx
import jax
from jaxtyping import ArrayLike
from plum.parametric import type_unparametrized
-from .quantity import Quantity, ustrip
+from .quantity import AbstractQuantity, UncheckedQuantity as FastQ, ustrip
from .typing_ext import Unit
from .units import unit, unit_of
-P = ParamSpec("P")
-R = TypeVar("R", bound=Quantity)
+Args = TypeVarTuple("Args")
+R = TypeVar("R", bound=AbstractQuantity)
def unit_or_none(obj: Any) -> Unit | None:
@@ -44,8 +45,8 @@ def unit_or_none(obj: Any) -> Unit | None:
def grad(
- fun: Callable[P, R], argnums: int = 0, *, units: tuple[Unit | str, ...]
-) -> Callable[P, R]:
+ fun: Callable[[Unpack[Args]], R], argnums: int = 0, *, units: tuple[Unit | str, ...]
+) -> Callable[[Unpack[Args]], R]:
"""Gradient of a function with units.
In general, if you can use ``quax.quaxify(jax.grad(func))`` (or the
@@ -76,22 +77,22 @@ def grad(
# Gradient of function, stripping and adding units
@partial(jax.grad, argnums=argnums)
- def gradfun_mag(*args: P.args) -> ArrayLike:
+ def gradfun_mag(*args: Any) -> ArrayLike:
args_ = (
- (a if unit is None else Quantity(a, unit))
+ (a if unit is None else FastQ(a, unit))
for a, unit in zip(args, theunits, strict=True)
)
- return ustrip(fun(*args_)) # type: ignore[call-arg]
+ return ustrip(fun(*args_)) # type: ignore[arg-type]
- def gradfun(*args: P.args, **kw: P.kwargs) -> R:
+ def gradfun(*args: Unpack[Args]) -> R:
# Get the value of the args. They are turned back into Quantity
# inside the function we are taking the grad of.
- args_ = tuple(
+ args_ = tuple( # type: ignore[var-annotated]
(a if unit is None else ustrip(unit, a))
- for a, unit in zip(args, theunits, strict=True)
+ for a, unit in zip(args, theunits, strict=True) # type: ignore[arg-type]
)
# Call the grad, returning a Quantity
- value = fun(*args) # type: ignore[call-arg]
+ value = fun(*args)
grad_value = gradfun_mag(*args_)
# Adjust the Quantity by the units of the derivative
# TODO: get Quantity[unit] / unit2 -> Quantity[unit/unit2] working
@@ -103,8 +104,8 @@ def gradfun(*args: P.args, **kw: P.kwargs) -> R:
def jacfwd(
- fun: Callable[P, R], argnums: int = 0, *, units: tuple[Unit | str, ...]
-) -> Callable[P, R]:
+ fun: Callable[[Unpack[Args]], R], argnums: int = 0, *, units: tuple[Unit | str, ...]
+) -> Callable[[Unpack[Args]], R]:
"""Jacobian of ``fun`` evaluated column-by-column using forward-mode AD.
In general, if you can use `quax.quaxify(jax.jacfwd(func))` (or the
@@ -128,7 +129,7 @@ def jacfwd(
>>> jacfwd_cubbe_volume = u.experimental.jacfwd(cubbe_volume, units=("m",))
>>> jacfwd_cubbe_volume(u.Quantity(2.0, "m"))
- Quantity['area'](Array(12., dtype=float32, weak_type=True), unit='m2')
+ UncheckedQuantity(Array(12., dtype=float32, weak_type=True), unit='m2')
"""
argnums = eqx.error_if(
@@ -140,19 +141,19 @@ def jacfwd(
theunits: tuple[Unit | None, ...] = tuple(map(unit_or_none, units))
@partial(jax.jacfwd, argnums=argnums)
- def jacfun_mag(*args: P.args) -> R:
- args_ = (
- (a if unit is None else Quantity(a, unit))
+ def jacfun_mag(*args: Any) -> R:
+ args_ = tuple(
+ (a if unit is None else FastQ(a, unit))
for a, unit in zip(args, theunits, strict=True)
)
- return fun(*args_) # type: ignore[call-arg]
+ return fun(*args_) # type: ignore[arg-type]
- def jacfun(*args: P.args, **kw: P.kwargs) -> R:
+ def jacfun(*args: Unpack[Args]) -> R:
# Get the value of the args. They are turned back into Quantity
# inside the function we are taking the Jacobian of.
- args_ = tuple(
+ args_ = tuple( # type: ignore[var-annotated]
(a if unit is None else ustrip(unit, a))
- for a, unit in zip(args, theunits, strict=True)
+ for a, unit in zip(args, theunits, strict=True) # type: ignore[arg-type]
)
# Call the Jacobian, returning a Quantity
value = jacfun_mag(*args_)
@@ -167,8 +168,8 @@ def jacfun(*args: P.args, **kw: P.kwargs) -> R:
def hessian(
- fun: Callable[P, R], argnums: int = 0, *, units: tuple[Unit | str, ...]
-) -> Callable[P, R]:
+ fun: Callable[[Unpack[Args]], R], argnums: int = 0, *, units: tuple[Unit | str, ...]
+) -> Callable[[Unpack[Args]], R]:
"""Hessian.
In general, if you can use `quax.quaxify(jax.hessian(func))` (or the
@@ -192,25 +193,25 @@ def hessian(
>>> hessian_cubbe_volume = u.experimental.hessian(cubbe_volume, units=("m",))
>>> hessian_cubbe_volume(u.Quantity(2.0, "m"))
- Quantity['length'](Array(12., dtype=float32, weak_type=True), unit='m')
+ UncheckedQuantity(Array(12., dtype=float32, weak_type=True), unit='m')
"""
theunits: tuple[Unit, ...] = tuple(map(unit_or_none, units))
@partial(jax.hessian)
- def hessfun_mag(*args: P.args) -> R:
- args_ = (
- (a if unit is None else Quantity(a, unit))
+ def hessfun_mag(*args: Any) -> R:
+ args_ = tuple(
+ (a if unit is None else FastQ(a, unit))
for a, unit in zip(args, theunits, strict=True)
)
- return fun(*args_) # type: ignore[call-arg]
+ return fun(*args_) # type: ignore[arg-type]
- def hessfun(*args: P.args, **kw: P.kwargs) -> R:
+ def hessfun(*args: Unpack[Args]) -> R:
# Get the value of the args. They are turned back into Quantity
# inside the function we are taking the hessian of.
- args_ = tuple(
+ args_ = tuple( # type: ignore[var-annotated]
(a if unit is None else ustrip(unit, a))
- for a, unit in zip(args, units, strict=True)
+ for a, unit in zip(args, units, strict=True) # type: ignore[arg-type]
)
# Call the hessian, returning a Quantity
value = hessfun_mag(*args_)
diff --git a/src/unxt/_src/quantity/base.py b/src/unxt/_src/quantity/base.py
index 876a714..f6cabd3 100644
--- a/src/unxt/_src/quantity/base.py
+++ b/src/unxt/_src/quantity/base.py
@@ -7,19 +7,20 @@
from functools import partial
from types import ModuleType
from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, TypeAlias, TypeVar
+from typing_extensions import override
import equinox as eqx
import jax
import jax.core
from astropy.units import UnitConversionError
from jax._src.numpy.array_methods import _IndexUpdateHelper, _IndexUpdateRef
-from jaxtyping import Array, ArrayLike, ScalarLike, Shaped
+from jaxtyping import Array, ArrayLike, Bool, ScalarLike, Shaped
from plum import add_promotion_rule, dispatch
from quax import ArrayValue
import quaxed.numpy as jnp
-import quaxed.operator as qoperator
-from dataclassish import fields, replace
+from dataclassish import replace
+from quaxed.experimental import arrayish
from .api import is_unit_convertible, uconvert, ustrip
from .mixins import AstropyQuantityCompatMixin, IPythonReprMixin, NumPyCompatMixin
@@ -55,7 +56,16 @@ class AbstractQuantity(
AstropyQuantityCompatMixin,
NumPyCompatMixin,
IPythonReprMixin,
- ArrayValue, # type: ignore[misc]
+ ArrayValue,
+ arrayish.NumpyBinaryOpsMixin[Any, "AbstractQuantity"],
+ arrayish.NumpyComparisonMixin[Any, Bool[Array, "*shape"]], # TODO: shape hint
+ arrayish.NumpyUnaryMixin["AbstractQuantity"],
+ arrayish.NumpyRoundMixin["AbstractQuantity"],
+ arrayish.NumpyTruncMixin["AbstractQuantity"],
+ arrayish.NumpyFloorMixin["AbstractQuantity"],
+ arrayish.NumpyCeilMixin["AbstractQuantity"],
+ arrayish.LaxLenMixin,
+ arrayish.LaxLengthHintMixin,
):
"""Represents an array, with each axis bound to a name.
@@ -175,7 +185,7 @@ def materialise(self) -> NoReturn:
raise RuntimeError(msg)
def aval(self) -> jax.core.ShapedArray:
- return jax.core.get_aval(self.value)
+ return jax.core.get_aval(self.value) # type: ignore[no-untyped-call]
def enable_materialise(self, _: bool = True) -> "Self": # noqa: FBT001, FBT002
return replace(self, value=self.value, unit=self.unit)
@@ -289,47 +299,6 @@ def T(self) -> "AbstractQuantity": # noqa: N802
# ---------------------------------------------------------------
# arithmetic operators
- def __pos__(self) -> "AbstractQuantity":
- """Return the value of the array.
-
- Examples
- --------
- >>> import unxt as u
- >>> q = u.Quantity(1, "m")
- >>> +q
- Quantity['length'](Array(1, dtype=int32, ...), unit='m')
-
- """
- return replace(self, value=+self.value) # pylint: disable=E1130
-
- def __neg__(self) -> "AbstractQuantity":
- """Negate the value of the array.
-
- Examples
- --------
- >>> import unxt as u
- >>> q = u.Quantity(1, "m")
- >>> -q
- Quantity['length'](Array(-1, dtype=int32, ...), unit='m')
-
- """
- return replace(self, value=-self.value) # pylint: disable=E1130
-
- __add__ = qoperator.add
- __radd__ = _flip_binop(qoperator.add)
-
- __sub__ = qoperator.sub
- __rsub__ = _flip_binop(qoperator.sub)
-
- __mul__ = qoperator.mul
- __rmul__ = _flip_binop(qoperator.mul)
-
- __truediv__ = qoperator.truediv
- __rtruediv__ = _flip_binop(qoperator.truediv)
-
- __floordiv__ = qoperator.floordiv
- __rfloordiv__ = _flip_binop(qoperator.floordiv)
-
@dispatch
def __mod__(self: "AbstractQuantity", other: Any) -> "AbstractQuantity":
"""Take the modulus.
@@ -363,46 +332,12 @@ def __rmod__(self, other: Any) -> Any:
"""
return self % other
- __pow__ = qoperator.pow
- __rpow__ = _flip_binop(qoperator.pow)
-
- # ---------------------------------------------------------------
- # array operators
-
- __matmul__ = qoperator.matmul
- __rmatmul__ = _flip_binop(qoperator.matmul)
-
- # ---------------------------------------------------------------
- # bitwise operators
- # TODO: handle edge cases, e.g. boolean Quantity, not in Astropy
-
- # __invert__ = qoperator.invert
- # __and__ = qoperator.and_
- # __rand__ = _flip_binop(qoperator.and_)
- # __or__ = qoperator.or_
- # __ror__ = _flip_binop(qoperator.or_)
- # __xor__ = qoperator.xor
- # __rxor__ = _flip_binop(qoperator.xor)
- # __lshift__ = qoperator.lshift
- # __rlshift__ = _flip_binop(qoperator.lshift)
- # __rshift__ = qoperator.rshift
- # __rrshift__ = _flip_binop(qoperator.rshift)
-
- # ---------------------------------------------------------------
- # comparison operators
-
- __lt__ = bool_op(jnp.less)
- __le__ = bool_op(jnp.less_equal)
- __eq__ = bool_op(jnp.equal)
- __ge__ = bool_op(jnp.greater_equal)
- __gt__ = bool_op(jnp.greater)
- __ne__ = bool_op(jnp.not_equal)
+ # required to override mixin methods
+ __eq__ = arrayish.NumpyEqMixin.__eq__
# ---------------------------------------------------------------
# methods
- __abs__ = qoperator.abs
-
def __bool__(self) -> bool:
"""Convert a zero-dimensional array to a Python bool object.
@@ -538,29 +473,6 @@ def __iter__(self) -> Any:
"""
yield from (self[i] for i in range(len(self.value)))
- def __len__(self) -> int:
- """Return the length of the array.
-
- Examples
- --------
- >>> import unxt as u
-
- Length of an unsized array:
-
- >>> try:
- ... len(u.Quantity(1, "m"))
- ... except TypeError as e:
- ... print(e)
- len() of unsized object
-
- Length of a sized array:
-
- >>> len(u.Quantity([1, 2, 3], "m"))
- 3
-
- """
- return len(self.value)
-
def argmax(self, *args: Any, **kwargs: Any) -> Array:
"""Return the indices of the maximum value.
@@ -605,7 +517,7 @@ def astype(self, *args: Any, **kwargs: Any) -> "AbstractQuantity":
@partial(property, doc=jax.Array.at.__doc__)
def at(self) -> "_QuantityIndexUpdateHelper":
- return _QuantityIndexUpdateHelper(self)
+ return _QuantityIndexUpdateHelper(self) # type: ignore[no-untyped-call]
def block_until_ready(self) -> "AbstractQuantity":
"""Block until the array is ready.
@@ -757,25 +669,6 @@ def squeeze(self, *args: Any, **kwargs: Any) -> "AbstractQuantity":
# ===============================================================
# Python stuff
- def __hash__(self) -> int:
- """Hash the object as the tuple of its field values.
-
- This raises a `TypeError` if the object is unhashable,
- which JAX arrays are.
-
- Examples
- --------
- >>> import unxt as u
- >>> q1 = u.Quantity(1, "m")
- >>> try:
- ... hash(q1)
- ... except TypeError as e:
- ... print(e)
- unhashable type: ...
-
- """
- return hash(tuple(getattr(self, f.name) for f in fields(self)))
-
def __repr__(self) -> str:
return f"{type(self).__name__}({self.value!r}, unit={self.unit.to_string()!r})"
@@ -976,9 +869,9 @@ def from_(
# runtime-checkable type annotation in `AbstractQuantity.at`.
# `_QuantityIndexUpdateRef` is defined after `AbstractQuantity` because it
# references `AbstractQuantity` in its runtime-checkable type annotations.
-class _QuantityIndexUpdateHelper(_IndexUpdateHelper): # type: ignore[misc]
+class _QuantityIndexUpdateHelper(_IndexUpdateHelper):
def __getitem__(self, index: Any) -> "_IndexUpdateRef":
- return _QuantityIndexUpdateRef(self.array, index)
+ return _QuantityIndexUpdateRef(self.array, index) # type: ignore[no-untyped-call]
def __repr__(self) -> str:
"""Return a string representation of the object.
@@ -994,14 +887,15 @@ def __repr__(self) -> str:
return f"_QuantityIndexUpdateHelper({self.array!r})"
-class _QuantityIndexUpdateRef(_IndexUpdateRef): # type: ignore[misc]
+class _QuantityIndexUpdateRef(_IndexUpdateRef):
# This is a subclass of `_IndexUpdateRef` that is used to implement the `at`
# attribute of `AbstractQuantity`. See also `_QuantityIndexUpdateHelper`.
def __repr__(self) -> str:
return super().__repr__().replace("_IndexUpdateRef", "_QuantityIndexUpdateRef")
- def get(
+ @override
+ def get( # type: ignore[override]
self,
*,
indices_are_sorted: bool = False,
diff --git a/src/unxt/_src/quantity/mixins.py b/src/unxt/_src/quantity/mixins.py
index 348e768..65a93c4 100644
--- a/src/unxt/_src/quantity/mixins.py
+++ b/src/unxt/_src/quantity/mixins.py
@@ -157,7 +157,7 @@ def _repr_html_(self) -> str:
"""
unit_repr = getattr(self.unit, "_repr_html_", self.unit.__repr__)()
- value_repr = np.array2string(self.value, separator=", ")
+ value_repr = np.array2string(self.value, separator=", ") # type: ignore[arg-type]
return f"{value_repr} * {unit_repr}"
@@ -174,7 +174,7 @@ def _repr_latex_(self) -> str:
"""
unit_repr = getattr(self.unit, "_repr_latex_", self.unit.__repr__)()
- value_repr = np.array2string(self.value, separator=",~")
+ value_repr = np.array2string(self.value, separator=",~") # type: ignore[arg-type]
return f"${value_repr} \\; {unit_repr[1:-1]}$"
diff --git a/src/unxt/_src/quantity/register_primitives.py b/src/unxt/_src/quantity/register_primitives.py
index 5f25a50..8722a02 100644
--- a/src/unxt/_src/quantity/register_primitives.py
+++ b/src/unxt/_src/quantity/register_primitives.py
@@ -3,7 +3,7 @@
__all__: list[str] = []
-from collections.abc import Callable, Sequence
+from collections.abc import Sequence
from dataclasses import replace
from math import prod
from typing import Any, TypeAlias, TypeVar
@@ -17,11 +17,10 @@
)
from jax import lax, numpy as jnp
from jax._src.ad_util import add_any_p
-from jax.core import Primitive
from jaxtyping import Array, ArrayLike
from plum import promote
from plum.parametric import type_unparametrized as type_np
-from quax import register as register_
+from quax import register
from quaxed import lax as qlax
@@ -37,11 +36,6 @@
Axes: TypeAlias = tuple[int, ...]
-def register(primitive: Primitive, **kwargs: Any) -> Callable[[T], T]:
- """`quax.register`, but makes mypy happy."""
- return register_(primitive, **kwargs)
-
-
def _to_value_rad_or_one(q: AbstractQuantity) -> ArrayLike:
return ustrip(radian if is_unit_convertible(q.unit, radian) else one, q)
@@ -334,7 +328,7 @@ def _add_any_p(
Quantity['length'](Array(1.5, dtype=float32, ...), unit='km')
"""
- return replace(x, value=add_any_p.bind(ustrip(x), ustrip(y)))
+ return replace(x, value=add_any_p.bind(ustrip(x), ustrip(y))) # type: ignore[no-untyped-call]
# ==============================================================================
@@ -1078,9 +1072,7 @@ def _cond_p_vq(
>>> from unxt import Quantity
"""
- # print(branches)
- # raise AttributeError
- return lax.cond_p.bind(index, ustrip(consts), branches=branches)
+ return lax.cond_p.bind(index, ustrip(consts), branches=branches) # type: ignore[no-untyped-call]
# ==============================================================================
@@ -1119,7 +1111,8 @@ def _convert_element_type_p(
"""Convert the element type of a quantity."""
# TODO: examples
return replace(
- operand, value=lax.convert_element_type_p.bind(ustrip(operand), **kwargs)
+ operand,
+ value=lax.convert_element_type_p.bind(ustrip(operand), **kwargs), # type: ignore[no-untyped-call]
)
@@ -1146,7 +1139,7 @@ def _copy_p(x: AbstractQuantity) -> AbstractQuantity:
Quantity['length'](Array(1, dtype=int32, ...), unit='m')
"""
- return replace(x, value=lax.copy_p.bind(ustrip(x)))
+ return replace(x, value=lax.copy_p.bind(ustrip(x))) # type: ignore[no-untyped-call]
# ==============================================================================
@@ -1408,7 +1401,7 @@ def _device_put_p(x: AbstractQuantity, **kwargs: Any) -> AbstractQuantity:
Quantity['length'](Array(1, dtype=int32, ...), unit='m')
"""
- return jt.map(lambda y: lax.device_put_p.bind(y, **kwargs), x)
+ return jt.map(lambda y: lax.device_put_p.bind(y, **kwargs), x) # type: ignore[no-untyped-call]
# ==============================================================================
@@ -1865,7 +1858,7 @@ def _exp2_p(x: AbstractQuantity) -> AbstractQuantity:
Quantity['dimensionless'](Array(8., dtype=float32, ...), unit='')
"""
- return replace(x, value=qlax.exp2(ustrip(one, x)))
+ return replace(x, value=qlax.exp2(ustrip(one, x))) # type: ignore[attr-defined]
# ==============================================================================
@@ -3110,6 +3103,30 @@ def _neg_p(x: AbstractQuantity) -> AbstractQuantity:
return replace(x, value=qlax.neg(ustrip(x)))
+# =============================================================================
+
+
+@register(lax.not_p)
+def _not_p(x: AbstractQuantity) -> AbstractQuantity:
+ """Logical negation of a quantity.
+
+ Examples
+ --------
+ >>> from unxt.quantity import UncheckedQuantity
+
+ >>> q = UncheckedQuantity(1, "")
+ >>> ~q
+ UncheckedQuantity(Array(-2, dtype=int32, weak_type=True), unit='')
+
+ >>> from unxt.quantity import Quantity
+ >>> q = Quantity(1, "")
+ >>> ~q
+ Quantity['dimensionless'](Array(-2, dtype=int32, weak_type=True), unit='')
+
+ """
+ return replace(x, value=qlax.bitwise_not(ustrip(one, x)))
+
+
# ==============================================================================
@@ -3437,7 +3454,7 @@ def _scan_p(
u = unit_of(arg0)
arg0_ = ustrip(u, arg0)
arg1_ = ustrip(u, arg1)
- return lax.scan_p.bind(arg0_, arg1_, *args, **kwargs)
+ return lax.scan_p.bind(arg0_, arg1_, *args, **kwargs) # type: ignore[no-untyped-call]
# ==============================================================================
@@ -3596,7 +3613,7 @@ def _select_n_p_jqq(which: ArrayLike, *cases: AbstractQuantity) -> AbstractQuant
dtypes = tuple(case.dtype for case in cases)
casesv = promote_dtypes_if_needed(dtypes, *(ustrip(u, case) for case in cases))
- return replace(cases[0], value=qlax.select_n(which, *casesv))
+ return replace(cases[0], value=qlax.select_n(which, *casesv)) # type: ignore[arg-type]
# ==============================================================================
@@ -3719,7 +3736,7 @@ def _sort_p_two_operands(
is_stable: bool,
num_keys: int,
) -> tuple[AbstractQuantity, ArrayLike]:
- out0, out1 = lax.sort_p.bind(
+ out0, out1 = lax.sort_p.bind( # type: ignore[no-untyped-call]
ustrip(operand0),
operand1,
dimension=dimension,
@@ -3734,7 +3751,7 @@ def _sort_p_two_operands(
def _sort_p_one_operand(
operand: AbstractQuantity, *, dimension: int, is_stable: bool, num_keys: int
) -> tuple[AbstractQuantity]:
- (out,) = lax.sort_p.bind(
+ (out,) = lax.sort_p.bind( # type: ignore[no-untyped-call]
ustrip(operand), dimension=dimension, is_stable=is_stable, num_keys=num_keys
)
return (type_np(operand)(out, unit=operand.unit),)
diff --git a/src/unxt/_src/utils.py b/src/unxt/_src/utils.py
index ccb0133..08b7336 100644
--- a/src/unxt/_src/utils.py
+++ b/src/unxt/_src/utils.py
@@ -50,7 +50,9 @@ def __new__(cls, /, *_: Any, **__: Any) -> "Self":
class HasDType(Protocol):
"""Protocol for objects that have a dtype attribute."""
- dtype: DType
+ @property
+ def dtype(self) -> DType:
+ """The dtype of the object."""
def promote_dtypes(*arrays: HasDType) -> tuple[HasDType, ...]:
@@ -77,7 +79,7 @@ def promote_dtypes(*arrays: HasDType) -> tuple[HasDType, ...]:
"""
common_dtype = dtypes.result_type(*arrays)
# TODO: check if this copies.
- return tuple(qlax.convert_element_type(arr, common_dtype) for arr in arrays)
+ return tuple(qlax.convert_element_type(arr, common_dtype) for arr in arrays) # type: ignore[arg-type]
def promote_dtypes_if_needed(
diff --git a/tests/unit/test_quantity.py b/tests/unit/test_quantity.py
index 57e76c1..fa72ed1 100644
--- a/tests/unit/test_quantity.py
+++ b/tests/unit/test_quantity.py
@@ -138,9 +138,14 @@ def test_getitem():
def test_len():
"""Test the ``len(Quantity)`` method."""
+ # Length 3
q = u.Quantity([1, 2, 3], "m")
assert len(q) == 3
+ # Scalar
+ q = u.Quantity(1, "m")
+ assert len(q) == 0
+
@pytest.mark.skip("TODO")
def test_add():
diff --git a/uv.lock b/uv.lock
index 15bb138..41d26e7 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1541,6 +1541,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a1/45/582749493b8eddd0933b2d8de39b2f0cda20b3475a9d16b2447a804fd6b3/optional_dependencies-0.3.2-py3-none-any.whl", hash = "sha256:e377389e9db9d54e42b04319ba71cdd256dde561ad8d4e89cbb5847d1f8bae3f", size = 8506 },
]
+[[package]]
+name = "optype"
+version = "0.8.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "typing-extensions", marker = "python_full_version < '3.13' or (extra == 'extra-4-unxt-cpu' and extra == 'extra-4-unxt-cuda') or (extra == 'extra-4-unxt-cpu' and extra == 'extra-4-unxt-cuda12') or (extra == 'extra-4-unxt-cpu' and extra == 'extra-4-unxt-cuda12-local') or (extra == 'extra-4-unxt-cpu' and extra == 'extra-4-unxt-k8s') or (extra == 'extra-4-unxt-cpu' and extra == 'extra-4-unxt-rocm') or (extra == 'extra-4-unxt-cuda' and extra == 'extra-4-unxt-cuda12') or (extra == 'extra-4-unxt-cuda' and extra == 'extra-4-unxt-cuda12-local') or (extra == 'extra-4-unxt-cuda' and extra == 'extra-4-unxt-k8s') or (extra == 'extra-4-unxt-cuda' and extra == 'extra-4-unxt-rocm') or (extra == 'extra-4-unxt-cuda12' and extra == 'extra-4-unxt-cuda12-local') or (extra == 'extra-4-unxt-cuda12' and extra == 'extra-4-unxt-k8s') or (extra == 'extra-4-unxt-cuda12' and extra == 'extra-4-unxt-rocm') or (extra == 'extra-4-unxt-cuda12-local' and extra == 'extra-4-unxt-k8s') or (extra == 'extra-4-unxt-cuda12-local' and extra == 'extra-4-unxt-rocm') or (extra == 'extra-4-unxt-k8s' and extra == 'extra-4-unxt-rocm')" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/89/42/543e02c72aba7ebe78adb76bbfbed1bc1314eba633ad453984948e5a5f46/optype-0.8.0.tar.gz", hash = "sha256:8cbfd452d6f06c7c70502048f38a0d5451bc601054d3a577dd09c7d6363950e1", size = 85295 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/86/ff/604be975eb0e9fd02358cdacf496f4411db97bffc27279ce260e8f50aba4/optype-0.8.0-py3-none-any.whl", hash = "sha256:90a7760177f2e7feae379a60445fceec37b932b75a00c3d96067497573c5e84d", size = 74228 },
+]
+
[[package]]
name = "packaging"
version = "24.1"
@@ -2051,18 +2063,19 @@ wheels = [
[[package]]
name = "quaxed"
-version = "0.7.1"
+version = "0.8.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jax" },
{ name = "jaxlib" },
{ name = "jaxtyping" },
+ { name = "optype" },
{ name = "plum-dispatch" },
{ name = "quax" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/aa/c6/01e425d053ee37ae2fd2cc736aaf418088c0b71c402719d06bda768bdf83/quaxed-0.7.1.tar.gz", hash = "sha256:bbc3596be9211324c2e5045e7fd919281f1a76eab63ad0dd1c8027378b1760f9", size = 123702 }
+sdist = { url = "https://files.pythonhosted.org/packages/75/17/80c24f8f21e764f2819182eaf032d620fbdd72df56388929ac4781347d02/quaxed-0.8.1.tar.gz", hash = "sha256:3856a778360cf74f01860f6774bb699a1a49b2b73c4ec7df13dd518f144084fc", size = 131590 }
wheels = [
- { url = "https://files.pythonhosted.org/packages/0a/63/a28e96b31fc27edac3c62d4359b82bdcb61da978b6a891061309a5bfb422/quaxed-0.7.1-py3-none-any.whl", hash = "sha256:1e1f66f3a523ff87f0a70b69c7079d5809f23091636747d8af4a6a99fa9925ba", size = 36547 },
+ { url = "https://files.pythonhosted.org/packages/cd/54/3694e96202a6fea6b197071e4d874ebda49dca5c9eb53434f7c3310b3f06/quaxed-0.8.1-py3-none-any.whl", hash = "sha256:0ce637a7154ccaf7d045caeebb0b7d1433c1f75b6d20a233a29ea99810dbd0ba", size = 48267 },
]
[[package]]
@@ -2614,7 +2627,7 @@ wheels = [
[[package]]
name = "unxt"
-version = "1.0.1.dev26+g1d69f30.d20250115"
+version = "1.0.1.dev27+gef1e1d6.d20250121"
source = { editable = "." }
dependencies = [
{ name = "astropy" },
@@ -2767,7 +2780,7 @@ requires-dist = [
{ name = "optional-dependencies", specifier = ">=0.3.2" },
{ name = "plum-dispatch", specifier = ">=2.5.6" },
{ name = "quax", specifier = ">=0.0.5" },
- { name = "quaxed", specifier = ">=0.7.1" },
+ { name = "quaxed", specifier = ">=0.8.1" },
{ name = "xmmutablemap", specifier = ">=0.1" },
{ name = "zeroth", specifier = ">=1.0.0" },
]