Skip to content

Commit

Permalink
♻️ refactor(quantity): ensure abstract quantities are abstract (#348)
Browse files Browse the repository at this point in the history
  • Loading branch information
nstarman authored Dec 22, 2024
1 parent acfb136 commit 5bb134d
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 11 deletions.
6 changes: 3 additions & 3 deletions src/unxt/_src/quantity/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from .api import is_unit_convertible, uconvert, ustrip
from .mixins import AstropyQuantityCompatMixin, IPythonReprMixin, NumPyCompatMixin
from unxt._src.units import unit as parse_unit, unit_of
from unxt._src.units import unit_of
from unxt._src.units.api import AbstractUnits

if TYPE_CHECKING:
Expand Down Expand Up @@ -103,10 +103,10 @@ class AbstractQuantity(
"""

value: Shaped[Array, "*shape"] = eqx.field(converter=jax.numpy.asarray)
value: eqx.AbstractVar[Shaped[Array, "*shape"]]
"""The value of the `AbstractQuantity`."""

unit: AbstractUnits = eqx.field(static=True, converter=parse_unit)
unit: eqx.AbstractVar[AbstractUnits]
"""The unit associated with this value."""

# ---------------------------------------------------------------
Expand Down
6 changes: 2 additions & 4 deletions src/unxt/_src/quantity/base_parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from typing import Any

import equinox as eqx
import jax
import jax.core
from astropy.units import PhysicalType, Unit
from jaxtyping import Array, ArrayLike, Shaped
from plum import dispatch, parametric, type_nonparametric, type_unparametrized
Expand All @@ -30,10 +28,10 @@ class AbstractParametricQuantity(AbstractQuantity):
"""

value: Shaped[Array, "*shape"] = eqx.field(converter=jax.numpy.asarray)
value: eqx.AbstractVar[Shaped[Array, "*shape"]]
"""The value of the `AbstractQuantity`."""

unit: Unit = eqx.field(static=True, converter=parse_unit)
unit: eqx.AbstractVar[Unit]
"""The unit associated with this value."""

def __post_init__(self) -> None:
Expand Down
12 changes: 11 additions & 1 deletion src/unxt/_src/quantity/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@
from dataclasses import replace
from typing import final

from jaxtyping import ArrayLike
import equinox as eqx
import jax
from jaxtyping import Array, ArrayLike, Shaped
from plum import parametric

from .base import AbstractQuantity
from .base_parametric import AbstractParametricQuantity
from unxt._src.units import unit as parse_unit
from unxt._src.units.api import AbstractUnits


@final
Expand Down Expand Up @@ -102,6 +106,12 @@ class Quantity(AbstractParametricQuantity):
"""

value: Shaped[Array, "*shape"] = eqx.field(converter=jax.numpy.asarray)
"""The value of the `AbstractQuantity`."""

unit: AbstractUnits = eqx.field(static=True, converter=parse_unit)
"""The unit associated with this value."""


@AbstractQuantity.__mod__.dispatch # type: ignore[misc]
def mod(self: Quantity["dimensionless"], other: ArrayLike) -> Quantity["dimensionless"]:
Expand Down
12 changes: 12 additions & 0 deletions src/unxt/_src/quantity/fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@

from typing import Any

import equinox as eqx
import jax
from jaxtyping import Array, Shaped

from .base import AbstractQuantity
from unxt._src.units import unit as parse_unit
from unxt._src.units.api import AbstractUnits


class UncheckedQuantity(AbstractQuantity):
Expand All @@ -14,6 +20,12 @@ class UncheckedQuantity(AbstractQuantity):
This class is not parametrized by its dimensionality.
"""

value: Shaped[Array, "*shape"] = eqx.field(converter=jax.numpy.asarray)
"""The value of the `AbstractQuantity`."""

unit: AbstractUnits = eqx.field(static=True, converter=parse_unit)
"""The unit associated with this value."""

def __class_getitem__(
cls: type["UncheckedQuantity"], item: Any
) -> type["UncheckedQuantity"]:
Expand Down
5 changes: 3 additions & 2 deletions src/unxt/_src/quantity/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, cast

import equinox as eqx
import numpy as np
from astropy.units import CompositeUnit
from jax.typing import ArrayLike
Expand All @@ -23,8 +24,8 @@
class AstropyQuantityCompatMixin:
"""Mixin for compatibility with `astropy.units.Quantity`."""

value: ArrayLike
unit: AbstractUnits
value: eqx.AbstractVar[ArrayLike]
unit: eqx.AbstractVar[AbstractUnits]
uconvert: Callable[[Any], "unxt.quantity.AbstractQuantity"]
ustrip: Callable[[Any], ArrayLike]

Expand Down
4 changes: 3 additions & 1 deletion tests/unit/test_quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from hypothesis.extra.array_api import make_strategies_namespace
from hypothesis.extra.numpy import array_shapes as np_array_shapes, arrays as np_arrays
from jax.dtypes import canonicalize_dtype
from jaxtyping import TypeCheckError
from jaxtyping import Array, TypeCheckError
from plum import convert, parametric

import quaxed.numpy as jnp
Expand Down Expand Up @@ -504,6 +504,8 @@ def test_at():
class NewQuantity(AbstractParametricQuantity):
"""Quantity with a flag."""

value: Array = eqx.field(converter=jnp.asarray)
unit: str = eqx.field(converter=u.unit)
flag: bool = eqx.field(static=True, kw_only=True)


Expand Down

0 comments on commit 5bb134d

Please sign in to comment.