From 791e48e117a8cca0c70f3ca1a2243619bfc95175 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Sun, 22 Dec 2024 13:40:42 -0800 Subject: [PATCH] =?UTF-8?q?Backport=20PR=20#348:=20=E2=99=BB=EF=B8=8F=20re?= =?UTF-8?q?factor(quantity):=20ensure=20abstract=20quantities=20are=20abst?= =?UTF-8?q?ract?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/unxt/_src/quantity/base.py | 6 +++--- src/unxt/_src/quantity/base_parametric.py | 6 ++---- src/unxt/_src/quantity/core.py | 12 +++++++++++- src/unxt/_src/quantity/fast.py | 12 ++++++++++++ src/unxt/_src/quantity/mixins.py | 5 +++-- tests/unit/test_quantity.py | 4 +++- 6 files changed, 34 insertions(+), 11 deletions(-) diff --git a/src/unxt/_src/quantity/base.py b/src/unxt/_src/quantity/base.py index 5d40c61f..b38c7751 100644 --- a/src/unxt/_src/quantity/base.py +++ b/src/unxt/_src/quantity/base.py @@ -23,7 +23,7 @@ from .api import is_unit_convertible, uconvert, ustrip from .mixins import AstropyQuantityCompatMixin, IPythonReprMixin, NumPyCompatMixin -from unxt._src.units import unit as parse_unit, unit_of +from unxt._src.units import unit_of from unxt._src.units.api import AbstractUnits if TYPE_CHECKING: @@ -103,10 +103,10 @@ class AbstractQuantity( """ - value: Shaped[Array, "*shape"] = eqx.field(converter=jax.numpy.asarray) + value: eqx.AbstractVar[Shaped[Array, "*shape"]] """The value of the `AbstractQuantity`.""" - unit: AbstractUnits = eqx.field(static=True, converter=parse_unit) + unit: eqx.AbstractVar[AbstractUnits] """The unit associated with this value.""" # --------------------------------------------------------------- diff --git a/src/unxt/_src/quantity/base_parametric.py b/src/unxt/_src/quantity/base_parametric.py index 0aba78fa..b6ee940f 100644 --- a/src/unxt/_src/quantity/base_parametric.py +++ b/src/unxt/_src/quantity/base_parametric.py @@ -8,8 +8,6 @@ from typing import Any import equinox as eqx -import jax -import jax.core from astropy.units import PhysicalType, Unit from jaxtyping import Array, ArrayLike, Shaped from plum import dispatch, parametric, type_nonparametric, type_unparametrized @@ -30,10 +28,10 @@ class AbstractParametricQuantity(AbstractQuantity): """ - value: Shaped[Array, "*shape"] = eqx.field(converter=jax.numpy.asarray) + value: eqx.AbstractVar[Shaped[Array, "*shape"]] """The value of the `AbstractQuantity`.""" - unit: Unit = eqx.field(static=True, converter=parse_unit) + unit: eqx.AbstractVar[Unit] """The unit associated with this value.""" def __post_init__(self) -> None: diff --git a/src/unxt/_src/quantity/core.py b/src/unxt/_src/quantity/core.py index 5f941683..cfa48fb2 100644 --- a/src/unxt/_src/quantity/core.py +++ b/src/unxt/_src/quantity/core.py @@ -6,11 +6,15 @@ from dataclasses import replace from typing import final -from jaxtyping import ArrayLike +import equinox as eqx +import jax +from jaxtyping import Array, ArrayLike, Shaped from plum import parametric from .base import AbstractQuantity from .base_parametric import AbstractParametricQuantity +from unxt._src.units import unit as parse_unit +from unxt._src.units.api import AbstractUnits @final @@ -102,6 +106,12 @@ class Quantity(AbstractParametricQuantity): """ + value: Shaped[Array, "*shape"] = eqx.field(converter=jax.numpy.asarray) + """The value of the `AbstractQuantity`.""" + + unit: AbstractUnits = eqx.field(static=True, converter=parse_unit) + """The unit associated with this value.""" + @AbstractQuantity.__mod__.dispatch # type: ignore[misc] def mod(self: Quantity["dimensionless"], other: ArrayLike) -> Quantity["dimensionless"]: diff --git a/src/unxt/_src/quantity/fast.py b/src/unxt/_src/quantity/fast.py index 31a69599..b2f0cab3 100644 --- a/src/unxt/_src/quantity/fast.py +++ b/src/unxt/_src/quantity/fast.py @@ -5,7 +5,13 @@ from typing import Any +import equinox as eqx +import jax +from jaxtyping import Array, Shaped + from .base import AbstractQuantity +from unxt._src.units import unit as parse_unit +from unxt._src.units.api import AbstractUnits class UncheckedQuantity(AbstractQuantity): @@ -14,6 +20,12 @@ class UncheckedQuantity(AbstractQuantity): This class is not parametrized by its dimensionality. """ + value: Shaped[Array, "*shape"] = eqx.field(converter=jax.numpy.asarray) + """The value of the `AbstractQuantity`.""" + + unit: AbstractUnits = eqx.field(static=True, converter=parse_unit) + """The unit associated with this value.""" + def __class_getitem__( cls: type["UncheckedQuantity"], item: Any ) -> type["UncheckedQuantity"]: diff --git a/src/unxt/_src/quantity/mixins.py b/src/unxt/_src/quantity/mixins.py index 7b9aabf3..348e7684 100644 --- a/src/unxt/_src/quantity/mixins.py +++ b/src/unxt/_src/quantity/mixins.py @@ -5,6 +5,7 @@ from collections.abc import Callable, Sequence from typing import TYPE_CHECKING, Any, cast +import equinox as eqx import numpy as np from astropy.units import CompositeUnit from jax.typing import ArrayLike @@ -23,8 +24,8 @@ class AstropyQuantityCompatMixin: """Mixin for compatibility with `astropy.units.Quantity`.""" - value: ArrayLike - unit: AbstractUnits + value: eqx.AbstractVar[ArrayLike] + unit: eqx.AbstractVar[AbstractUnits] uconvert: Callable[[Any], "unxt.quantity.AbstractQuantity"] ustrip: Callable[[Any], ArrayLike] diff --git a/tests/unit/test_quantity.py b/tests/unit/test_quantity.py index 5807029a..57e76c1f 100644 --- a/tests/unit/test_quantity.py +++ b/tests/unit/test_quantity.py @@ -14,7 +14,7 @@ from hypothesis.extra.array_api import make_strategies_namespace from hypothesis.extra.numpy import array_shapes as np_array_shapes, arrays as np_arrays from jax.dtypes import canonicalize_dtype -from jaxtyping import TypeCheckError +from jaxtyping import Array, TypeCheckError from plum import convert, parametric import quaxed.numpy as jnp @@ -504,6 +504,8 @@ def test_at(): class NewQuantity(AbstractParametricQuantity): """Quantity with a flag.""" + value: Array = eqx.field(converter=jnp.asarray) + unit: str = eqx.field(converter=u.unit) flag: bool = eqx.field(static=True, kw_only=True)