diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c585724..d63cdd9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -80,6 +80,7 @@ repos: files: src additional_dependencies: - plum-dispatch>=2.5.6 + - quaxed>=0.8.1 - repo: https://github.com/codespell-project/codespell rev: "v2.3.0" @@ -91,11 +92,3 @@ repos: # hooks: # - id: validate-pyproject # additional_dependencies: ["validate-pyproject-schema-store[all]"] - - - repo: local - hooks: - - id: disallow-caps - name: Disallow improper capitalization - language: pygrep - entry: PyBind|Numpy|Cmake|CCache|Github|PyTest - exclude: .pre-commit-config.yaml diff --git a/pyproject.toml b/pyproject.toml index 547e7d3..da888de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ "optional-dependencies>=0.3.2", "plum-dispatch>=2.5.6", "quax>=0.0.5", - "quaxed>=0.7.1", + "quaxed>=0.8.1", "xmmutablemap>=0.1", "zeroth>=1.0.0", ] diff --git a/src/unxt/_src/experimental.py b/src/unxt/_src/experimental.py index 0a6c524..1a4a468 100644 --- a/src/unxt/_src/experimental.py +++ b/src/unxt/_src/experimental.py @@ -24,19 +24,20 @@ from collections.abc import Callable from functools import partial -from typing import Any, ParamSpec, TypeVar +from typing import Any, TypeVar, TypeVarTuple +from typing_extensions import Unpack import equinox as eqx import jax from jaxtyping import ArrayLike from plum.parametric import type_unparametrized -from .quantity import Quantity, ustrip +from .quantity import AbstractQuantity, UncheckedQuantity as FastQ, ustrip from .typing_ext import Unit from .units import unit, unit_of -P = ParamSpec("P") -R = TypeVar("R", bound=Quantity) +Args = TypeVarTuple("Args") +R = TypeVar("R", bound=AbstractQuantity) def unit_or_none(obj: Any) -> Unit | None: @@ -44,8 +45,8 @@ def unit_or_none(obj: Any) -> Unit | None: def grad( - fun: Callable[P, R], argnums: int = 0, *, units: tuple[Unit | str, ...] -) -> Callable[P, R]: + fun: Callable[[Unpack[Args]], R], argnums: int = 0, *, units: tuple[Unit | str, ...] +) -> Callable[[Unpack[Args]], R]: """Gradient of a function with units. In general, if you can use ``quax.quaxify(jax.grad(func))`` (or the @@ -76,22 +77,22 @@ def grad( # Gradient of function, stripping and adding units @partial(jax.grad, argnums=argnums) - def gradfun_mag(*args: P.args) -> ArrayLike: + def gradfun_mag(*args: Any) -> ArrayLike: args_ = ( - (a if unit is None else Quantity(a, unit)) + (a if unit is None else FastQ(a, unit)) for a, unit in zip(args, theunits, strict=True) ) - return ustrip(fun(*args_)) # type: ignore[call-arg] + return ustrip(fun(*args_)) # type: ignore[arg-type] - def gradfun(*args: P.args, **kw: P.kwargs) -> R: + def gradfun(*args: Unpack[Args]) -> R: # Get the value of the args. They are turned back into Quantity # inside the function we are taking the grad of. - args_ = tuple( + args_ = tuple( # type: ignore[var-annotated] (a if unit is None else ustrip(unit, a)) - for a, unit in zip(args, theunits, strict=True) + for a, unit in zip(args, theunits, strict=True) # type: ignore[arg-type] ) # Call the grad, returning a Quantity - value = fun(*args) # type: ignore[call-arg] + value = fun(*args) grad_value = gradfun_mag(*args_) # Adjust the Quantity by the units of the derivative # TODO: get Quantity[unit] / unit2 -> Quantity[unit/unit2] working @@ -103,8 +104,8 @@ def gradfun(*args: P.args, **kw: P.kwargs) -> R: def jacfwd( - fun: Callable[P, R], argnums: int = 0, *, units: tuple[Unit | str, ...] -) -> Callable[P, R]: + fun: Callable[[Unpack[Args]], R], argnums: int = 0, *, units: tuple[Unit | str, ...] +) -> Callable[[Unpack[Args]], R]: """Jacobian of ``fun`` evaluated column-by-column using forward-mode AD. In general, if you can use `quax.quaxify(jax.jacfwd(func))` (or the @@ -128,7 +129,7 @@ def jacfwd( >>> jacfwd_cubbe_volume = u.experimental.jacfwd(cubbe_volume, units=("m",)) >>> jacfwd_cubbe_volume(u.Quantity(2.0, "m")) - Quantity['area'](Array(12., dtype=float32, weak_type=True), unit='m2') + UncheckedQuantity(Array(12., dtype=float32, weak_type=True), unit='m2') """ argnums = eqx.error_if( @@ -140,19 +141,19 @@ def jacfwd( theunits: tuple[Unit | None, ...] = tuple(map(unit_or_none, units)) @partial(jax.jacfwd, argnums=argnums) - def jacfun_mag(*args: P.args) -> R: - args_ = ( - (a if unit is None else Quantity(a, unit)) + def jacfun_mag(*args: Any) -> R: + args_ = tuple( + (a if unit is None else FastQ(a, unit)) for a, unit in zip(args, theunits, strict=True) ) - return fun(*args_) # type: ignore[call-arg] + return fun(*args_) # type: ignore[arg-type] - def jacfun(*args: P.args, **kw: P.kwargs) -> R: + def jacfun(*args: Unpack[Args]) -> R: # Get the value of the args. They are turned back into Quantity # inside the function we are taking the Jacobian of. - args_ = tuple( + args_ = tuple( # type: ignore[var-annotated] (a if unit is None else ustrip(unit, a)) - for a, unit in zip(args, theunits, strict=True) + for a, unit in zip(args, theunits, strict=True) # type: ignore[arg-type] ) # Call the Jacobian, returning a Quantity value = jacfun_mag(*args_) @@ -167,8 +168,8 @@ def jacfun(*args: P.args, **kw: P.kwargs) -> R: def hessian( - fun: Callable[P, R], argnums: int = 0, *, units: tuple[Unit | str, ...] -) -> Callable[P, R]: + fun: Callable[[Unpack[Args]], R], argnums: int = 0, *, units: tuple[Unit | str, ...] +) -> Callable[[Unpack[Args]], R]: """Hessian. In general, if you can use `quax.quaxify(jax.hessian(func))` (or the @@ -192,25 +193,25 @@ def hessian( >>> hessian_cubbe_volume = u.experimental.hessian(cubbe_volume, units=("m",)) >>> hessian_cubbe_volume(u.Quantity(2.0, "m")) - Quantity['length'](Array(12., dtype=float32, weak_type=True), unit='m') + UncheckedQuantity(Array(12., dtype=float32, weak_type=True), unit='m') """ theunits: tuple[Unit, ...] = tuple(map(unit_or_none, units)) @partial(jax.hessian) - def hessfun_mag(*args: P.args) -> R: - args_ = ( - (a if unit is None else Quantity(a, unit)) + def hessfun_mag(*args: Any) -> R: + args_ = tuple( + (a if unit is None else FastQ(a, unit)) for a, unit in zip(args, theunits, strict=True) ) - return fun(*args_) # type: ignore[call-arg] + return fun(*args_) # type: ignore[arg-type] - def hessfun(*args: P.args, **kw: P.kwargs) -> R: + def hessfun(*args: Unpack[Args]) -> R: # Get the value of the args. They are turned back into Quantity # inside the function we are taking the hessian of. - args_ = tuple( + args_ = tuple( # type: ignore[var-annotated] (a if unit is None else ustrip(unit, a)) - for a, unit in zip(args, units, strict=True) + for a, unit in zip(args, units, strict=True) # type: ignore[arg-type] ) # Call the hessian, returning a Quantity value = hessfun_mag(*args_) diff --git a/src/unxt/_src/quantity/base.py b/src/unxt/_src/quantity/base.py index 876a714..f6cabd3 100644 --- a/src/unxt/_src/quantity/base.py +++ b/src/unxt/_src/quantity/base.py @@ -7,19 +7,20 @@ from functools import partial from types import ModuleType from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, TypeAlias, TypeVar +from typing_extensions import override import equinox as eqx import jax import jax.core from astropy.units import UnitConversionError from jax._src.numpy.array_methods import _IndexUpdateHelper, _IndexUpdateRef -from jaxtyping import Array, ArrayLike, ScalarLike, Shaped +from jaxtyping import Array, ArrayLike, Bool, ScalarLike, Shaped from plum import add_promotion_rule, dispatch from quax import ArrayValue import quaxed.numpy as jnp -import quaxed.operator as qoperator -from dataclassish import fields, replace +from dataclassish import replace +from quaxed.experimental import arrayish from .api import is_unit_convertible, uconvert, ustrip from .mixins import AstropyQuantityCompatMixin, IPythonReprMixin, NumPyCompatMixin @@ -55,7 +56,16 @@ class AbstractQuantity( AstropyQuantityCompatMixin, NumPyCompatMixin, IPythonReprMixin, - ArrayValue, # type: ignore[misc] + ArrayValue, + arrayish.NumpyBinaryOpsMixin[Any, "AbstractQuantity"], + arrayish.NumpyComparisonMixin[Any, Bool[Array, "*shape"]], # TODO: shape hint + arrayish.NumpyUnaryMixin["AbstractQuantity"], + arrayish.NumpyRoundMixin["AbstractQuantity"], + arrayish.NumpyTruncMixin["AbstractQuantity"], + arrayish.NumpyFloorMixin["AbstractQuantity"], + arrayish.NumpyCeilMixin["AbstractQuantity"], + arrayish.LaxLenMixin, + arrayish.LaxLengthHintMixin, ): """Represents an array, with each axis bound to a name. @@ -175,7 +185,7 @@ def materialise(self) -> NoReturn: raise RuntimeError(msg) def aval(self) -> jax.core.ShapedArray: - return jax.core.get_aval(self.value) + return jax.core.get_aval(self.value) # type: ignore[no-untyped-call] def enable_materialise(self, _: bool = True) -> "Self": # noqa: FBT001, FBT002 return replace(self, value=self.value, unit=self.unit) @@ -289,47 +299,6 @@ def T(self) -> "AbstractQuantity": # noqa: N802 # --------------------------------------------------------------- # arithmetic operators - def __pos__(self) -> "AbstractQuantity": - """Return the value of the array. - - Examples - -------- - >>> import unxt as u - >>> q = u.Quantity(1, "m") - >>> +q - Quantity['length'](Array(1, dtype=int32, ...), unit='m') - - """ - return replace(self, value=+self.value) # pylint: disable=E1130 - - def __neg__(self) -> "AbstractQuantity": - """Negate the value of the array. - - Examples - -------- - >>> import unxt as u - >>> q = u.Quantity(1, "m") - >>> -q - Quantity['length'](Array(-1, dtype=int32, ...), unit='m') - - """ - return replace(self, value=-self.value) # pylint: disable=E1130 - - __add__ = qoperator.add - __radd__ = _flip_binop(qoperator.add) - - __sub__ = qoperator.sub - __rsub__ = _flip_binop(qoperator.sub) - - __mul__ = qoperator.mul - __rmul__ = _flip_binop(qoperator.mul) - - __truediv__ = qoperator.truediv - __rtruediv__ = _flip_binop(qoperator.truediv) - - __floordiv__ = qoperator.floordiv - __rfloordiv__ = _flip_binop(qoperator.floordiv) - @dispatch def __mod__(self: "AbstractQuantity", other: Any) -> "AbstractQuantity": """Take the modulus. @@ -363,46 +332,12 @@ def __rmod__(self, other: Any) -> Any: """ return self % other - __pow__ = qoperator.pow - __rpow__ = _flip_binop(qoperator.pow) - - # --------------------------------------------------------------- - # array operators - - __matmul__ = qoperator.matmul - __rmatmul__ = _flip_binop(qoperator.matmul) - - # --------------------------------------------------------------- - # bitwise operators - # TODO: handle edge cases, e.g. boolean Quantity, not in Astropy - - # __invert__ = qoperator.invert - # __and__ = qoperator.and_ - # __rand__ = _flip_binop(qoperator.and_) - # __or__ = qoperator.or_ - # __ror__ = _flip_binop(qoperator.or_) - # __xor__ = qoperator.xor - # __rxor__ = _flip_binop(qoperator.xor) - # __lshift__ = qoperator.lshift - # __rlshift__ = _flip_binop(qoperator.lshift) - # __rshift__ = qoperator.rshift - # __rrshift__ = _flip_binop(qoperator.rshift) - - # --------------------------------------------------------------- - # comparison operators - - __lt__ = bool_op(jnp.less) - __le__ = bool_op(jnp.less_equal) - __eq__ = bool_op(jnp.equal) - __ge__ = bool_op(jnp.greater_equal) - __gt__ = bool_op(jnp.greater) - __ne__ = bool_op(jnp.not_equal) + # required to override mixin methods + __eq__ = arrayish.NumpyEqMixin.__eq__ # --------------------------------------------------------------- # methods - __abs__ = qoperator.abs - def __bool__(self) -> bool: """Convert a zero-dimensional array to a Python bool object. @@ -538,29 +473,6 @@ def __iter__(self) -> Any: """ yield from (self[i] for i in range(len(self.value))) - def __len__(self) -> int: - """Return the length of the array. - - Examples - -------- - >>> import unxt as u - - Length of an unsized array: - - >>> try: - ... len(u.Quantity(1, "m")) - ... except TypeError as e: - ... print(e) - len() of unsized object - - Length of a sized array: - - >>> len(u.Quantity([1, 2, 3], "m")) - 3 - - """ - return len(self.value) - def argmax(self, *args: Any, **kwargs: Any) -> Array: """Return the indices of the maximum value. @@ -605,7 +517,7 @@ def astype(self, *args: Any, **kwargs: Any) -> "AbstractQuantity": @partial(property, doc=jax.Array.at.__doc__) def at(self) -> "_QuantityIndexUpdateHelper": - return _QuantityIndexUpdateHelper(self) + return _QuantityIndexUpdateHelper(self) # type: ignore[no-untyped-call] def block_until_ready(self) -> "AbstractQuantity": """Block until the array is ready. @@ -757,25 +669,6 @@ def squeeze(self, *args: Any, **kwargs: Any) -> "AbstractQuantity": # =============================================================== # Python stuff - def __hash__(self) -> int: - """Hash the object as the tuple of its field values. - - This raises a `TypeError` if the object is unhashable, - which JAX arrays are. - - Examples - -------- - >>> import unxt as u - >>> q1 = u.Quantity(1, "m") - >>> try: - ... hash(q1) - ... except TypeError as e: - ... print(e) - unhashable type: ... - - """ - return hash(tuple(getattr(self, f.name) for f in fields(self))) - def __repr__(self) -> str: return f"{type(self).__name__}({self.value!r}, unit={self.unit.to_string()!r})" @@ -976,9 +869,9 @@ def from_( # runtime-checkable type annotation in `AbstractQuantity.at`. # `_QuantityIndexUpdateRef` is defined after `AbstractQuantity` because it # references `AbstractQuantity` in its runtime-checkable type annotations. -class _QuantityIndexUpdateHelper(_IndexUpdateHelper): # type: ignore[misc] +class _QuantityIndexUpdateHelper(_IndexUpdateHelper): def __getitem__(self, index: Any) -> "_IndexUpdateRef": - return _QuantityIndexUpdateRef(self.array, index) + return _QuantityIndexUpdateRef(self.array, index) # type: ignore[no-untyped-call] def __repr__(self) -> str: """Return a string representation of the object. @@ -994,14 +887,15 @@ def __repr__(self) -> str: return f"_QuantityIndexUpdateHelper({self.array!r})" -class _QuantityIndexUpdateRef(_IndexUpdateRef): # type: ignore[misc] +class _QuantityIndexUpdateRef(_IndexUpdateRef): # This is a subclass of `_IndexUpdateRef` that is used to implement the `at` # attribute of `AbstractQuantity`. See also `_QuantityIndexUpdateHelper`. def __repr__(self) -> str: return super().__repr__().replace("_IndexUpdateRef", "_QuantityIndexUpdateRef") - def get( + @override + def get( # type: ignore[override] self, *, indices_are_sorted: bool = False, diff --git a/src/unxt/_src/quantity/mixins.py b/src/unxt/_src/quantity/mixins.py index 348e768..65a93c4 100644 --- a/src/unxt/_src/quantity/mixins.py +++ b/src/unxt/_src/quantity/mixins.py @@ -157,7 +157,7 @@ def _repr_html_(self) -> str: """ unit_repr = getattr(self.unit, "_repr_html_", self.unit.__repr__)() - value_repr = np.array2string(self.value, separator=", ") + value_repr = np.array2string(self.value, separator=", ") # type: ignore[arg-type] return f"{value_repr} * {unit_repr}" @@ -174,7 +174,7 @@ def _repr_latex_(self) -> str: """ unit_repr = getattr(self.unit, "_repr_latex_", self.unit.__repr__)() - value_repr = np.array2string(self.value, separator=",~") + value_repr = np.array2string(self.value, separator=",~") # type: ignore[arg-type] return f"${value_repr} \\; {unit_repr[1:-1]}$" diff --git a/src/unxt/_src/quantity/register_primitives.py b/src/unxt/_src/quantity/register_primitives.py index 5f25a50..8722a02 100644 --- a/src/unxt/_src/quantity/register_primitives.py +++ b/src/unxt/_src/quantity/register_primitives.py @@ -3,7 +3,7 @@ __all__: list[str] = [] -from collections.abc import Callable, Sequence +from collections.abc import Sequence from dataclasses import replace from math import prod from typing import Any, TypeAlias, TypeVar @@ -17,11 +17,10 @@ ) from jax import lax, numpy as jnp from jax._src.ad_util import add_any_p -from jax.core import Primitive from jaxtyping import Array, ArrayLike from plum import promote from plum.parametric import type_unparametrized as type_np -from quax import register as register_ +from quax import register from quaxed import lax as qlax @@ -37,11 +36,6 @@ Axes: TypeAlias = tuple[int, ...] -def register(primitive: Primitive, **kwargs: Any) -> Callable[[T], T]: - """`quax.register`, but makes mypy happy.""" - return register_(primitive, **kwargs) - - def _to_value_rad_or_one(q: AbstractQuantity) -> ArrayLike: return ustrip(radian if is_unit_convertible(q.unit, radian) else one, q) @@ -334,7 +328,7 @@ def _add_any_p( Quantity['length'](Array(1.5, dtype=float32, ...), unit='km') """ - return replace(x, value=add_any_p.bind(ustrip(x), ustrip(y))) + return replace(x, value=add_any_p.bind(ustrip(x), ustrip(y))) # type: ignore[no-untyped-call] # ============================================================================== @@ -1078,9 +1072,7 @@ def _cond_p_vq( >>> from unxt import Quantity """ - # print(branches) - # raise AttributeError - return lax.cond_p.bind(index, ustrip(consts), branches=branches) + return lax.cond_p.bind(index, ustrip(consts), branches=branches) # type: ignore[no-untyped-call] # ============================================================================== @@ -1119,7 +1111,8 @@ def _convert_element_type_p( """Convert the element type of a quantity.""" # TODO: examples return replace( - operand, value=lax.convert_element_type_p.bind(ustrip(operand), **kwargs) + operand, + value=lax.convert_element_type_p.bind(ustrip(operand), **kwargs), # type: ignore[no-untyped-call] ) @@ -1146,7 +1139,7 @@ def _copy_p(x: AbstractQuantity) -> AbstractQuantity: Quantity['length'](Array(1, dtype=int32, ...), unit='m') """ - return replace(x, value=lax.copy_p.bind(ustrip(x))) + return replace(x, value=lax.copy_p.bind(ustrip(x))) # type: ignore[no-untyped-call] # ============================================================================== @@ -1408,7 +1401,7 @@ def _device_put_p(x: AbstractQuantity, **kwargs: Any) -> AbstractQuantity: Quantity['length'](Array(1, dtype=int32, ...), unit='m') """ - return jt.map(lambda y: lax.device_put_p.bind(y, **kwargs), x) + return jt.map(lambda y: lax.device_put_p.bind(y, **kwargs), x) # type: ignore[no-untyped-call] # ============================================================================== @@ -1865,7 +1858,7 @@ def _exp2_p(x: AbstractQuantity) -> AbstractQuantity: Quantity['dimensionless'](Array(8., dtype=float32, ...), unit='') """ - return replace(x, value=qlax.exp2(ustrip(one, x))) + return replace(x, value=qlax.exp2(ustrip(one, x))) # type: ignore[attr-defined] # ============================================================================== @@ -3110,6 +3103,30 @@ def _neg_p(x: AbstractQuantity) -> AbstractQuantity: return replace(x, value=qlax.neg(ustrip(x))) +# ============================================================================= + + +@register(lax.not_p) +def _not_p(x: AbstractQuantity) -> AbstractQuantity: + """Logical negation of a quantity. + + Examples + -------- + >>> from unxt.quantity import UncheckedQuantity + + >>> q = UncheckedQuantity(1, "") + >>> ~q + UncheckedQuantity(Array(-2, dtype=int32, weak_type=True), unit='') + + >>> from unxt.quantity import Quantity + >>> q = Quantity(1, "") + >>> ~q + Quantity['dimensionless'](Array(-2, dtype=int32, weak_type=True), unit='') + + """ + return replace(x, value=qlax.bitwise_not(ustrip(one, x))) + + # ============================================================================== @@ -3437,7 +3454,7 @@ def _scan_p( u = unit_of(arg0) arg0_ = ustrip(u, arg0) arg1_ = ustrip(u, arg1) - return lax.scan_p.bind(arg0_, arg1_, *args, **kwargs) + return lax.scan_p.bind(arg0_, arg1_, *args, **kwargs) # type: ignore[no-untyped-call] # ============================================================================== @@ -3596,7 +3613,7 @@ def _select_n_p_jqq(which: ArrayLike, *cases: AbstractQuantity) -> AbstractQuant dtypes = tuple(case.dtype for case in cases) casesv = promote_dtypes_if_needed(dtypes, *(ustrip(u, case) for case in cases)) - return replace(cases[0], value=qlax.select_n(which, *casesv)) + return replace(cases[0], value=qlax.select_n(which, *casesv)) # type: ignore[arg-type] # ============================================================================== @@ -3719,7 +3736,7 @@ def _sort_p_two_operands( is_stable: bool, num_keys: int, ) -> tuple[AbstractQuantity, ArrayLike]: - out0, out1 = lax.sort_p.bind( + out0, out1 = lax.sort_p.bind( # type: ignore[no-untyped-call] ustrip(operand0), operand1, dimension=dimension, @@ -3734,7 +3751,7 @@ def _sort_p_two_operands( def _sort_p_one_operand( operand: AbstractQuantity, *, dimension: int, is_stable: bool, num_keys: int ) -> tuple[AbstractQuantity]: - (out,) = lax.sort_p.bind( + (out,) = lax.sort_p.bind( # type: ignore[no-untyped-call] ustrip(operand), dimension=dimension, is_stable=is_stable, num_keys=num_keys ) return (type_np(operand)(out, unit=operand.unit),) diff --git a/src/unxt/_src/utils.py b/src/unxt/_src/utils.py index ccb0133..08b7336 100644 --- a/src/unxt/_src/utils.py +++ b/src/unxt/_src/utils.py @@ -50,7 +50,9 @@ def __new__(cls, /, *_: Any, **__: Any) -> "Self": class HasDType(Protocol): """Protocol for objects that have a dtype attribute.""" - dtype: DType + @property + def dtype(self) -> DType: + """The dtype of the object.""" def promote_dtypes(*arrays: HasDType) -> tuple[HasDType, ...]: @@ -77,7 +79,7 @@ def promote_dtypes(*arrays: HasDType) -> tuple[HasDType, ...]: """ common_dtype = dtypes.result_type(*arrays) # TODO: check if this copies. - return tuple(qlax.convert_element_type(arr, common_dtype) for arr in arrays) + return tuple(qlax.convert_element_type(arr, common_dtype) for arr in arrays) # type: ignore[arg-type] def promote_dtypes_if_needed( diff --git a/tests/unit/test_quantity.py b/tests/unit/test_quantity.py index 57e76c1..fa72ed1 100644 --- a/tests/unit/test_quantity.py +++ b/tests/unit/test_quantity.py @@ -138,9 +138,14 @@ def test_getitem(): def test_len(): """Test the ``len(Quantity)`` method.""" + # Length 3 q = u.Quantity([1, 2, 3], "m") assert len(q) == 3 + # Scalar + q = u.Quantity(1, "m") + assert len(q) == 0 + @pytest.mark.skip("TODO") def test_add(): diff --git a/uv.lock b/uv.lock index 15bb138..41d26e7 100644 --- a/uv.lock +++ b/uv.lock @@ -1541,6 +1541,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a1/45/582749493b8eddd0933b2d8de39b2f0cda20b3475a9d16b2447a804fd6b3/optional_dependencies-0.3.2-py3-none-any.whl", hash = "sha256:e377389e9db9d54e42b04319ba71cdd256dde561ad8d4e89cbb5847d1f8bae3f", size = 8506 }, ] +[[package]] +name = "optype" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.13' or (extra == 'extra-4-unxt-cpu' and extra == 'extra-4-unxt-cuda') or (extra == 'extra-4-unxt-cpu' and extra == 'extra-4-unxt-cuda12') or (extra == 'extra-4-unxt-cpu' and extra == 'extra-4-unxt-cuda12-local') or (extra == 'extra-4-unxt-cpu' and extra == 'extra-4-unxt-k8s') or (extra == 'extra-4-unxt-cpu' and extra == 'extra-4-unxt-rocm') or (extra == 'extra-4-unxt-cuda' and extra == 'extra-4-unxt-cuda12') or (extra == 'extra-4-unxt-cuda' and extra == 'extra-4-unxt-cuda12-local') or (extra == 'extra-4-unxt-cuda' and extra == 'extra-4-unxt-k8s') or (extra == 'extra-4-unxt-cuda' and extra == 'extra-4-unxt-rocm') or (extra == 'extra-4-unxt-cuda12' and extra == 'extra-4-unxt-cuda12-local') or (extra == 'extra-4-unxt-cuda12' and extra == 'extra-4-unxt-k8s') or (extra == 'extra-4-unxt-cuda12' and extra == 'extra-4-unxt-rocm') or (extra == 'extra-4-unxt-cuda12-local' and extra == 'extra-4-unxt-k8s') or (extra == 'extra-4-unxt-cuda12-local' and extra == 'extra-4-unxt-rocm') or (extra == 'extra-4-unxt-k8s' and extra == 'extra-4-unxt-rocm')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/89/42/543e02c72aba7ebe78adb76bbfbed1bc1314eba633ad453984948e5a5f46/optype-0.8.0.tar.gz", hash = "sha256:8cbfd452d6f06c7c70502048f38a0d5451bc601054d3a577dd09c7d6363950e1", size = 85295 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/86/ff/604be975eb0e9fd02358cdacf496f4411db97bffc27279ce260e8f50aba4/optype-0.8.0-py3-none-any.whl", hash = "sha256:90a7760177f2e7feae379a60445fceec37b932b75a00c3d96067497573c5e84d", size = 74228 }, +] + [[package]] name = "packaging" version = "24.1" @@ -2051,18 +2063,19 @@ wheels = [ [[package]] name = "quaxed" -version = "0.7.1" +version = "0.8.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jax" }, { name = "jaxlib" }, { name = "jaxtyping" }, + { name = "optype" }, { name = "plum-dispatch" }, { name = "quax" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/aa/c6/01e425d053ee37ae2fd2cc736aaf418088c0b71c402719d06bda768bdf83/quaxed-0.7.1.tar.gz", hash = "sha256:bbc3596be9211324c2e5045e7fd919281f1a76eab63ad0dd1c8027378b1760f9", size = 123702 } +sdist = { url = "https://files.pythonhosted.org/packages/75/17/80c24f8f21e764f2819182eaf032d620fbdd72df56388929ac4781347d02/quaxed-0.8.1.tar.gz", hash = "sha256:3856a778360cf74f01860f6774bb699a1a49b2b73c4ec7df13dd518f144084fc", size = 131590 } wheels = [ - { url = "https://files.pythonhosted.org/packages/0a/63/a28e96b31fc27edac3c62d4359b82bdcb61da978b6a891061309a5bfb422/quaxed-0.7.1-py3-none-any.whl", hash = "sha256:1e1f66f3a523ff87f0a70b69c7079d5809f23091636747d8af4a6a99fa9925ba", size = 36547 }, + { url = "https://files.pythonhosted.org/packages/cd/54/3694e96202a6fea6b197071e4d874ebda49dca5c9eb53434f7c3310b3f06/quaxed-0.8.1-py3-none-any.whl", hash = "sha256:0ce637a7154ccaf7d045caeebb0b7d1433c1f75b6d20a233a29ea99810dbd0ba", size = 48267 }, ] [[package]] @@ -2614,7 +2627,7 @@ wheels = [ [[package]] name = "unxt" -version = "1.0.1.dev26+g1d69f30.d20250115" +version = "1.0.1.dev27+gef1e1d6.d20250121" source = { editable = "." } dependencies = [ { name = "astropy" }, @@ -2767,7 +2780,7 @@ requires-dist = [ { name = "optional-dependencies", specifier = ">=0.3.2" }, { name = "plum-dispatch", specifier = ">=2.5.6" }, { name = "quax", specifier = ">=0.0.5" }, - { name = "quaxed", specifier = ">=0.7.1" }, + { name = "quaxed", specifier = ">=0.8.1" }, { name = "xmmutablemap", specifier = ">=0.1" }, { name = "zeroth", specifier = ">=1.0.0" }, ]