-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ feat(quantity): enable non-Array quax.ArrayValue as Quantity's value
Signed-off-by: Nathaniel Starkman <[email protected]>
- Loading branch information
Showing
10 changed files
with
164 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
__all__ = ["value_converter"] | ||
|
||
import warnings | ||
from typing import Any, NoReturn | ||
|
||
import quax | ||
from jaxtyping import Array, ArrayLike | ||
from plum import dispatch | ||
|
||
import quaxed.numpy as jnp | ||
|
||
from .base import AbstractQuantity | ||
|
||
|
||
@dispatch.abstract | ||
def value_converter(obj: Any, /) -> Any: | ||
"""Convert for the value field of an `AbstractQuantity` subclass.""" | ||
raise NotImplementedError # pragma: no cover | ||
|
||
|
||
@dispatch | ||
def value_converter(obj: quax.ArrayValue, /) -> Any: | ||
"""Convert a `quax.ArrayValue` for the value field. | ||
>>> import warnings | ||
>>> import jax | ||
>>> import jax.numpy as jnp | ||
>>> from jaxtyping import Array | ||
>>> from quax import ArrayValue | ||
>>> class MyArray(ArrayValue): | ||
... value: Array | ||
... | ||
... def aval(self): | ||
... return jax.core.ShapedArray(self.value.shape, self.value.dtype) | ||
... | ||
... def materialise(self): | ||
... return self.value | ||
>>> x = MyArray(jnp.array([1, 2, 3])) | ||
>>> with warnings.catch_warnings(record=True) as w: | ||
... warnings.simplefilter("always") | ||
... y = value_converter(x) | ||
>>> y | ||
MyArray(value=i32[3]) | ||
>>> print(f"Warning caught: {w[-1].message}") | ||
Warning caught: 'quax.ArrayValue' subclass 'MyArray' ... | ||
""" | ||
warnings.warn( | ||
f"'quax.ArrayValue' subclass {type(obj).__name__!r} does not have a registered " | ||
"converter. Returning the object as is.", | ||
category=UserWarning, | ||
stacklevel=2, | ||
) | ||
return obj | ||
|
||
|
||
@dispatch | ||
def value_converter(obj: ArrayLike | list[Any] | tuple[Any, ...], /) -> Array: | ||
"""Convert an array-like object to a `jax.numpy.ndarray`. | ||
>>> import jax.numpy as jnp | ||
>>> from unxt.quantity import value_converter | ||
>>> value_converter([1, 2, 3]) | ||
Array([1, 2, 3], dtype=int32) | ||
""" | ||
return jnp.asarray(obj) | ||
|
||
|
||
@dispatch | ||
def value_converter(obj: AbstractQuantity, /) -> NoReturn: | ||
"""Disallow conversion of `AbstractQuantity` to a value. | ||
>>> import unxt as u | ||
>>> from unxt.quantity import value_converter | ||
>>> try: | ||
... value_converter(u.Quantity(1, "m")) | ||
... except TypeError as e: | ||
... print(e) | ||
Cannot convert 'Quantity[PhysicalType('length')]' to a value. | ||
For a Quantity, use the `.from_` constructor instead. | ||
""" | ||
msg = ( | ||
f"Cannot convert '{type(obj).__name__}' to a value. " | ||
"For a Quantity, use the `.from_` constructor instead." | ||
) | ||
raise TypeError(msg) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Tests.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
"""Tests.""" | ||
|
||
import re | ||
|
||
import jax.numpy as jnp | ||
import jax.random as jr | ||
import pytest | ||
from quax.examples import lora | ||
|
||
import unxt as u | ||
|
||
|
||
def test_lora_array_as_quantity_value(): | ||
lora_array = lora.LoraArray(jnp.asarray([[1.0, 2, 3]]), rank=1, key=jr.key(0)) | ||
with pytest.warns( | ||
UserWarning, match=re.escape("'quax.ArrayValue' subclass 'LoraArray'") | ||
): | ||
quantity = u.Quantity(lora_array, "m") | ||
|
||
assert quantity.value is lora_array | ||
assert quantity.unit == "m" |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.