Skip to content

Commit

Permalink
✨ feat(quantity): enable non-Array quax.ArrayValue as Quantity's value
Browse files Browse the repository at this point in the history
Signed-off-by: Nathaniel Starkman <[email protected]>
  • Loading branch information
nstarman committed Jan 21, 2025
1 parent ef1e1d6 commit 72e9b00
Show file tree
Hide file tree
Showing 10 changed files with 164 additions and 14 deletions.
28 changes: 27 additions & 1 deletion src/unxt/_interop/unxt_interop_astropy/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__all__: list[str] = []

from typing import Any, TypeAlias
from typing import Any, NoReturn, TypeAlias

import astropy.units as apyu
from astropy.coordinates import Angle as AstropyAngle, Distance as AstropyDistance
Expand All @@ -16,6 +16,32 @@
from unxt.quantity import AbstractQuantity, Quantity, UncheckedQuantity, ustrip
from unxt.units import unit, unit_of

# ============================================================================
# Value Converter


@dispatch
def value_converter(obj: AstropyQuantity, /) -> NoReturn:
"""Disallow conversion of `AstropyQuantity` to a value.
>>> import astropy.units as apyu
>>> from unxt.quantity import value_converter
>>> try:
... value_converter(apyu.Quantity(1, "m"))
... except TypeError as e:
... print(e)
Cannot convert 'Quantity' to a value.
For a Quantity, use the `.from_` constructor instead.
"""
msg = (
f"Cannot convert {type(obj).__name__!r} to a value. "
"For a Quantity, use the `.from_` constructor instead."
)
raise TypeError(msg)


# ============================================================================
# AbstractQuantity

Expand Down
2 changes: 2 additions & 0 deletions src/unxt/_src/quantity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
"is_unit_convertible",
"uconvert",
"ustrip",
"value_converter",
]

from .api import is_unit_convertible, uconvert, ustrip
from .base import AbstractQuantity
from .base_parametric import AbstractParametricQuantity
from .quantity import Quantity
from .unchecked import UncheckedQuantity
from .value import value_converter
6 changes: 3 additions & 3 deletions src/unxt/_src/quantity/base_parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

__all__ = ["AbstractParametricQuantity"]

from collections.abc import Callable, Sequence
from collections.abc import Callable
from functools import partial
from typing import Any

import equinox as eqx
from astropy.units import PhysicalType, Unit
from jaxtyping import Array, ArrayLike, Shaped
from jaxtyping import Array, Shaped
from plum import dispatch, parametric, type_nonparametric, type_unparametrized

from dataclassish import field_items, fields
Expand Down Expand Up @@ -79,7 +79,7 @@ def __init_type_parameter__(cls, unit: UnitTypes, /) -> tuple[AbstractDimension]

@classmethod
def __infer_type_parameter__(
cls, value: ArrayLike | Sequence[Any], unit: Any, **kwargs: Any
cls, value: Any, unit: Any, **kwargs: Any
) -> tuple[AbstractDimension]:
"""Infer the type parameter from the arguments."""
return (dimension_of(parse_unit(unit)),)
Expand Down
4 changes: 2 additions & 2 deletions src/unxt/_src/quantity/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from typing import final

import equinox as eqx
import jax
from jaxtyping import Array, ArrayLike, Shaped
from plum import parametric

from .base import AbstractQuantity
from .base_parametric import AbstractParametricQuantity
from .value import value_converter
from unxt._src.units import unit as parse_unit
from unxt._src.units.api import AbstractUnits

Expand Down Expand Up @@ -106,7 +106,7 @@ class Quantity(AbstractParametricQuantity):
"""

value: Shaped[Array, "*shape"] = eqx.field(converter=jax.numpy.asarray)
value: Shaped[Array, "*shape"] = eqx.field(converter=value_converter)
"""The value of the `AbstractQuantity`."""

unit: AbstractUnits = eqx.field(static=True, converter=parse_unit)
Expand Down
4 changes: 2 additions & 2 deletions src/unxt/_src/quantity/unchecked.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from typing import Any

import equinox as eqx
import jax
from jaxtyping import Array, Shaped

from .base import AbstractQuantity
from .value import value_converter
from unxt._src.units import unit as parse_unit
from unxt._src.units.api import AbstractUnits

Expand All @@ -20,7 +20,7 @@ class UncheckedQuantity(AbstractQuantity):
This class is not parametrized by its dimensionality.
"""

value: Shaped[Array, "*shape"] = eqx.field(converter=jax.numpy.asarray)
value: Shaped[Array, "*shape"] = eqx.field(converter=value_converter)
"""The value of the `AbstractQuantity`."""

unit: AbstractUnits = eqx.field(static=True, converter=parse_unit)
Expand Down
92 changes: 92 additions & 0 deletions src/unxt/_src/quantity/value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
__all__ = ["value_converter"]

import warnings
from typing import Any, NoReturn

import quax
from jaxtyping import Array, ArrayLike
from plum import dispatch

import quaxed.numpy as jnp

from .base import AbstractQuantity


@dispatch.abstract
def value_converter(obj: Any, /) -> Any:
"""Convert for the value field of an `AbstractQuantity` subclass."""
raise NotImplementedError # pragma: no cover


@dispatch
def value_converter(obj: quax.ArrayValue, /) -> Any:
"""Convert a `quax.ArrayValue` for the value field.
>>> import warnings
>>> import jax
>>> import jax.numpy as jnp
>>> from jaxtyping import Array
>>> from quax import ArrayValue
>>> class MyArray(ArrayValue):
... value: Array
...
... def aval(self):
... return jax.core.ShapedArray(self.value.shape, self.value.dtype)
...
... def materialise(self):
... return self.value
>>> x = MyArray(jnp.array([1, 2, 3]))
>>> with warnings.catch_warnings(record=True) as w:
... warnings.simplefilter("always")
... y = value_converter(x)
>>> y
MyArray(value=i32[3])
>>> print(f"Warning caught: {w[-1].message}")
Warning caught: 'quax.ArrayValue' subclass 'MyArray' ...
"""
warnings.warn(
f"'quax.ArrayValue' subclass {type(obj).__name__!r} does not have a registered "
"converter. Returning the object as is.",
category=UserWarning,
stacklevel=2,
)
return obj


@dispatch
def value_converter(obj: ArrayLike | list[Any] | tuple[Any, ...], /) -> Array:
"""Convert an array-like object to a `jax.numpy.ndarray`.
>>> import jax.numpy as jnp
>>> from unxt.quantity import value_converter
>>> value_converter([1, 2, 3])
Array([1, 2, 3], dtype=int32)
"""
return jnp.asarray(obj)


@dispatch
def value_converter(obj: AbstractQuantity, /) -> NoReturn:
"""Disallow conversion of `AbstractQuantity` to a value.
>>> import unxt as u
>>> from unxt.quantity import value_converter
>>> try:
... value_converter(u.Quantity(1, "m"))
... except TypeError as e:
... print(e)
Cannot convert 'Quantity[PhysicalType('length')]' to a value.
For a Quantity, use the `.from_` constructor instead.
"""
msg = (
f"Cannot convert '{type(obj).__name__}' to a value. "
"For a Quantity, use the `.from_` constructor instead."
)
raise TypeError(msg)
18 changes: 13 additions & 5 deletions src/unxt/quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,17 @@
"""
# ruff:noqa: F403

from ._src.quantity.api import is_unit_convertible, uconvert, ustrip
from ._src.quantity.base import AbstractQuantity
from ._src.quantity.base_parametric import AbstractParametricQuantity
from ._src.quantity.quantity import Quantity
from ._src.quantity.unchecked import UncheckedQuantity
from jaxtyping import install_import_hook

from .setup_package import RUNTIME_TYPECHECKER

with install_import_hook("unxt.quantity", RUNTIME_TYPECHECKER):
from ._src.quantity.api import is_unit_convertible, uconvert, ustrip
from ._src.quantity.base import AbstractQuantity
from ._src.quantity.base_parametric import AbstractParametricQuantity
from ._src.quantity.quantity import Quantity
from ._src.quantity.unchecked import UncheckedQuantity
from ._src.quantity.value import value_converter

# isort: split
# Register dispatches and conversions
Expand All @@ -40,6 +46,8 @@
"uconvert",
"ustrip",
"is_unit_convertible",
# utils
"value_converter",
]


Expand Down
1 change: 1 addition & 0 deletions tests/integration/quax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests."""
21 changes: 21 additions & 0 deletions tests/integration/quax/test_lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Tests."""

import re

import jax.numpy as jnp
import jax.random as jr
import pytest
from quax.examples import lora

import unxt as u


def test_lora_array_as_quantity_value():
lora_array = lora.LoraArray(jnp.asarray([[1.0, 2, 3]]), rank=1, key=jr.key(0))
with pytest.warns(
UserWarning, match=re.escape("'quax.ArrayValue' subclass 'LoraArray'")
):
quantity = u.Quantity(lora_array, "m")

assert quantity.value is lora_array
assert quantity.unit == "m"
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 72e9b00

Please sign in to comment.