Skip to content

Commit

Permalink
👽️ external(quaxed): use new quaxed.experimental.arrayish module
Browse files Browse the repository at this point in the history
  • Loading branch information
nstarman committed Jan 21, 2025
1 parent ef1e1d6 commit cec6837
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 200 deletions.
9 changes: 1 addition & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ repos:
files: src
additional_dependencies:
- plum-dispatch>=2.5.6
- quaxed>=0.8.1

- repo: https://github.com/codespell-project/codespell
rev: "v2.3.0"
Expand All @@ -91,11 +92,3 @@ repos:
# hooks:
# - id: validate-pyproject
# additional_dependencies: ["validate-pyproject-schema-store[all]"]

- repo: local
hooks:
- id: disallow-caps
name: Disallow improper capitalization
language: pygrep
entry: PyBind|Numpy|Cmake|CCache|Github|PyTest
exclude: .pre-commit-config.yaml
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"optional-dependencies>=0.3.2",
"plum-dispatch>=2.5.6",
"quax>=0.0.5",
"quaxed>=0.7.1",
"quaxed>=0.8.1",
"xmmutablemap>=0.1",
"zeroth>=1.0.0",
]
Expand Down
67 changes: 34 additions & 33 deletions src/unxt/_src/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,29 @@

from collections.abc import Callable
from functools import partial
from typing import Any, ParamSpec, TypeVar
from typing import Any, TypeVar, TypeVarTuple
from typing_extensions import Unpack

import equinox as eqx
import jax
from jaxtyping import ArrayLike
from plum.parametric import type_unparametrized

from .quantity import Quantity, ustrip
from .quantity import AbstractQuantity, UncheckedQuantity as FastQ, ustrip
from .typing_ext import Unit
from .units import unit, unit_of

P = ParamSpec("P")
R = TypeVar("R", bound=Quantity)
Args = TypeVarTuple("Args")
R = TypeVar("R", bound=AbstractQuantity)


def unit_or_none(obj: Any) -> Unit | None:
return obj if obj is None else unit(obj)


def grad(
fun: Callable[P, R], argnums: int = 0, *, units: tuple[Unit | str, ...]
) -> Callable[P, R]:
fun: Callable[[Unpack[Args]], R], argnums: int = 0, *, units: tuple[Unit | str, ...]
) -> Callable[[Unpack[Args]], R]:
"""Gradient of a function with units.
In general, if you can use ``quax.quaxify(jax.grad(func))`` (or the
Expand Down Expand Up @@ -76,22 +77,22 @@ def grad(

# Gradient of function, stripping and adding units
@partial(jax.grad, argnums=argnums)
def gradfun_mag(*args: P.args) -> ArrayLike:
def gradfun_mag(*args: Any) -> ArrayLike:
args_ = (
(a if unit is None else Quantity(a, unit))
(a if unit is None else FastQ(a, unit))
for a, unit in zip(args, theunits, strict=True)
)
return ustrip(fun(*args_)) # type: ignore[call-arg]
return ustrip(fun(*args_)) # type: ignore[arg-type]

def gradfun(*args: P.args, **kw: P.kwargs) -> R:
def gradfun(*args: Unpack[Args]) -> R:
# Get the value of the args. They are turned back into Quantity
# inside the function we are taking the grad of.
args_ = tuple(
args_ = tuple( # type: ignore[var-annotated]
(a if unit is None else ustrip(unit, a))
for a, unit in zip(args, theunits, strict=True)
for a, unit in zip(args, theunits, strict=True) # type: ignore[arg-type]
)
# Call the grad, returning a Quantity
value = fun(*args) # type: ignore[call-arg]
value = fun(*args)
grad_value = gradfun_mag(*args_)
# Adjust the Quantity by the units of the derivative
# TODO: get Quantity[unit] / unit2 -> Quantity[unit/unit2] working
Expand All @@ -103,8 +104,8 @@ def gradfun(*args: P.args, **kw: P.kwargs) -> R:


def jacfwd(
fun: Callable[P, R], argnums: int = 0, *, units: tuple[Unit | str, ...]
) -> Callable[P, R]:
fun: Callable[[Unpack[Args]], R], argnums: int = 0, *, units: tuple[Unit | str, ...]
) -> Callable[[Unpack[Args]], R]:
"""Jacobian of ``fun`` evaluated column-by-column using forward-mode AD.
In general, if you can use `quax.quaxify(jax.jacfwd(func))` (or the
Expand All @@ -128,7 +129,7 @@ def jacfwd(
>>> jacfwd_cubbe_volume = u.experimental.jacfwd(cubbe_volume, units=("m",))
>>> jacfwd_cubbe_volume(u.Quantity(2.0, "m"))
Quantity['area'](Array(12., dtype=float32, weak_type=True), unit='m2')
UncheckedQuantity(Array(12., dtype=float32, weak_type=True), unit='m2')
"""
argnums = eqx.error_if(
Expand All @@ -140,19 +141,19 @@ def jacfwd(
theunits: tuple[Unit | None, ...] = tuple(map(unit_or_none, units))

@partial(jax.jacfwd, argnums=argnums)
def jacfun_mag(*args: P.args) -> R:
args_ = (
(a if unit is None else Quantity(a, unit))
def jacfun_mag(*args: Any) -> R:
args_ = tuple(
(a if unit is None else FastQ(a, unit))
for a, unit in zip(args, theunits, strict=True)
)
return fun(*args_) # type: ignore[call-arg]
return fun(*args_) # type: ignore[arg-type]

def jacfun(*args: P.args, **kw: P.kwargs) -> R:
def jacfun(*args: Unpack[Args]) -> R:
# Get the value of the args. They are turned back into Quantity
# inside the function we are taking the Jacobian of.
args_ = tuple(
args_ = tuple( # type: ignore[var-annotated]
(a if unit is None else ustrip(unit, a))
for a, unit in zip(args, theunits, strict=True)
for a, unit in zip(args, theunits, strict=True) # type: ignore[arg-type]
)
# Call the Jacobian, returning a Quantity
value = jacfun_mag(*args_)
Expand All @@ -167,8 +168,8 @@ def jacfun(*args: P.args, **kw: P.kwargs) -> R:


def hessian(
fun: Callable[P, R], argnums: int = 0, *, units: tuple[Unit | str, ...]
) -> Callable[P, R]:
fun: Callable[[Unpack[Args]], R], argnums: int = 0, *, units: tuple[Unit | str, ...]
) -> Callable[[Unpack[Args]], R]:
"""Hessian.
In general, if you can use `quax.quaxify(jax.hessian(func))` (or the
Expand All @@ -192,25 +193,25 @@ def hessian(
>>> hessian_cubbe_volume = u.experimental.hessian(cubbe_volume, units=("m",))
>>> hessian_cubbe_volume(u.Quantity(2.0, "m"))
Quantity['length'](Array(12., dtype=float32, weak_type=True), unit='m')
UncheckedQuantity(Array(12., dtype=float32, weak_type=True), unit='m')
"""
theunits: tuple[Unit, ...] = tuple(map(unit_or_none, units))

@partial(jax.hessian)
def hessfun_mag(*args: P.args) -> R:
args_ = (
(a if unit is None else Quantity(a, unit))
def hessfun_mag(*args: Any) -> R:
args_ = tuple(
(a if unit is None else FastQ(a, unit))
for a, unit in zip(args, theunits, strict=True)
)
return fun(*args_) # type: ignore[call-arg]
return fun(*args_) # type: ignore[arg-type]

def hessfun(*args: P.args, **kw: P.kwargs) -> R:
def hessfun(*args: Unpack[Args]) -> R:
# Get the value of the args. They are turned back into Quantity
# inside the function we are taking the hessian of.
args_ = tuple(
args_ = tuple( # type: ignore[var-annotated]
(a if unit is None else ustrip(unit, a))
for a, unit in zip(args, units, strict=True)
for a, unit in zip(args, units, strict=True) # type: ignore[arg-type]
)
# Call the hessian, returning a Quantity
value = hessfun_mag(*args_)
Expand Down
Loading

0 comments on commit cec6837

Please sign in to comment.