Skip to content

Commit

Permalink
feat: add multi-dispatch Quantity constructor (#14)
Browse files Browse the repository at this point in the history
* feat: add multi-dispatch Quantity constructor
* ci: pylint

Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Feb 10, 2024
1 parent d6619aa commit b51e968
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 3 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,11 @@ similarities.ignore-imports = "yes"
messages_control.disable = [
"design",
"fixme",
"function-redefined", # plum-dispatch
"function-redefined", # plum-dispatch
"line-too-long",
"missing-function-docstring", # TODO: resolve
"missing-module-docstring",
"protected-access", # handled by ruff
"redefined-builtin", # handled by ruff
"unused-argument", # handled by ruff
"wrong-import-position",
Expand Down
53 changes: 51 additions & 2 deletions src/jax_quantity/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
import operator
from collections.abc import Callable
from dataclasses import replace
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, final

import array_api_jax_compat
import equinox as eqx
import jax
import jax.core
from astropy.units import Unit, UnitConversionError
from array_api_jax_compat._dispatch import dispatcher
from astropy.units import Quantity as AstropyQuantity, Unit, UnitConversionError
from jaxtyping import ArrayLike
from quax import ArrayValue, quaxify
from typing_extensions import Self
Expand All @@ -28,12 +29,38 @@ def _binop(x: Any, y: Any) -> Any:
return _binop


@final
class Quantity(ArrayValue): # type: ignore[misc]
"""Represents an array, with each axis bound to a name."""

value: jax.Array = eqx.field(converter=jax.numpy.asarray)
unit: Unit = eqx.field(static=True, converter=Unit)

@classmethod
@dispatcher
def constructor(
cls: "type[Quantity]", value: ArrayLike, unit: Any, /
) -> "Quantity":
# Dispatch on both arguments.
# Construct using the standard `__init__` method.
return cls(value, unit)

@classmethod # type: ignore[no-redef]
@dispatcher
def constructor(
cls: "type[Quantity]", value: ArrayLike, *, unit: Any
) -> "Quantity":
# Dispatch on the `value` only. Dispatch to the full constructor.
return cls.constructor(value, unit)

@classmethod # type: ignore[no-redef]
@dispatcher
def constructor(
cls: "type[Quantity]", *, value: ArrayLike, unit: Any
) -> "Quantity":
# Dispatched on no argument. Dispatch to the full constructor.
return cls.constructor(value, unit)

# ===============================================================
# Quax

Expand Down Expand Up @@ -96,6 +123,28 @@ def __getitem__(self, key: Any) -> "Quantity":
__neg__ = quaxify(operator.__neg__)


# -----------------------------------------------
# Register additional constructors


@Quantity.constructor._f.register # noqa: SLF001
def constructor(cls: type[Quantity], value: Quantity, unit: Any, /) -> Quantity:
"""Construct a `Quantity` from another `Quantity`.
The `value` is converted to the new `unit`.
"""
return Quantity(value.to_value(unit), unit)


@Quantity.constructor._f.register # type: ignore[no-redef] # noqa: SLF001
def constructor(cls: type[Quantity], value: AstropyQuantity, unit: Any, /) -> Quantity:
"""Construct a `Quantity` from another `Quantity`.
The `value` is converted to the new `unit`.
"""
return Quantity(value.to_value(unit), unit)


# ===============================================================


Expand Down

0 comments on commit b51e968

Please sign in to comment.