Skip to content

Commit

Permalink
👷 build(benchmarks): add more benchmarks and config (#363)
Browse files Browse the repository at this point in the history
* 👷 build(benchmarks): add more benchmarks and config
* ⬆️ dep-bump(deps): bump and set optional dependency pins

Signed-off-by: Nathaniel Starkman <[email protected]>
  • Loading branch information
nstarman authored Jan 8, 2025
1 parent aefd9d4 commit a9e0afb
Show file tree
Hide file tree
Showing 8 changed files with 363 additions and 202 deletions.
2 changes: 1 addition & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.11
3.12
26 changes: 15 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
all = ["unxt[backend-astropy,interop-gala,interop-mpl]"]
backend-astropy = ["astropy>=6.0"]
interop-gala = ["gala>=1.8"]
interop-mpl = ["matplotlib>=3.4"]
interop-mpl = ["matplotlib>=3.5"]

[project.urls]
"Bug Tracker" = "https://github.com/GalacticDynamics/unxt/issues"
Expand All @@ -80,12 +80,12 @@
"pytz>=2024.2", # for copyright date
"sphinx-book-theme==1.1.3",
"sphinx-prompt>=1.8.0",
"sphinx-tippy",
"sphinx-tippy>=0.4.3",
"sphinx>=7.0",
"sphinx_autodoc_typehints",
"sphinx_copybutton",
"sphinx_design",
"sphinx_togglebutton",
"sphinx_autodoc_typehints>=3.0.0",
"sphinx_copybutton>=0.5.2",
"sphinx_design>=0.6.1",
"sphinx_togglebutton>=0.3.2",
"sphinxext-opengraph>=0.9.1",
"sphinxext-rediraffe>=0.2.7",
]
Expand All @@ -94,13 +94,17 @@
"hypothesis[numpy]>=6.112.2",
"pytest>=8.3.3",
"pytest-arraydiff>=0.6.1",
"pytest-codspeed",
"pytest-cov >=3",
"pytest-env",
"pytest-github-actions-annotate-failures",
"pytest-benchmark>=5.1",
"pytest-codspeed>=3.1",
"pytest-cov>=3",
"pytest-env>=1.1.5",
"pytest-github-actions-annotate-failures>=0.2.0",
"sybil>=8.0.0",
]
pytest-benchmark-parallel = [
"pytest-xdist>=3.6.1",
]
test-mpl = ["pytest-mpl"]
test-mpl = ["pytest-mpl>=0.17.0"]
test-all = [{ include-group = "test" }, { include-group = "test-mpl" }]


Expand Down
94 changes: 0 additions & 94 deletions tests/benchmark/test_dimensions.py

This file was deleted.

78 changes: 78 additions & 0 deletions tests/benchmark/test_dims.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Benchmark tests for quaxified jax."""

import equinox as eqx
import jax
import pytest
from jaxlib.xla_extension import PjitFunction

import unxt as u


@pytest.fixture
def func_dimension() -> PjitFunction:
return eqx.filter_jit(u.dimension)


@pytest.fixture
def func_dimension_of() -> PjitFunction:
return eqx.filter_jit(u.dimension_of)


#####################################################################
# `dimension`

args = [(u.dimension("length"),), ("length",)]


@pytest.mark.parametrize("args", args, ids=str)
@pytest.mark.benchmark(group="dimensions", warmup=False, max_time=1.0)
def test_dimension(args):
"""Test calling `unxt.dimension`."""
_ = u.dimension(*args)


@pytest.mark.parametrize("args", args, ids=str)
@pytest.mark.benchmark(group="dimensions", warmup=False, max_time=1.0)
def test_dimension_jit_compile(func_dimension, args):
"""Test the speed of jitting."""
_ = func_dimension.lower(*args).compile()


@pytest.mark.parametrize("args", args, ids=str)
@pytest.mark.benchmark(group="dimensions", warmup=True, max_time=1.0)
def test_dimension_execute(func_dimension, args):
"""Test the speed of calling the function."""
_ = jax.block_until_ready(func_dimension(*args))


#####################################################################
# `dimension_of`


args = [
(u.dimension("length"),), # -> Dimension('length')
(u.unit("m"),), # -> Dimension('length')
(u.Quantity(1, "m"),), # -> Dimension('length')
(2,), # -> None
]


@pytest.mark.parametrize("args", args, ids=str)
@pytest.mark.benchmark(group="dimensions", warmup=False, max_time=1.0)
def test_dimension_of(args):
"""Test calling `unxt.dimension_of`."""
_ = u.dimension_of(*args)


@pytest.mark.parametrize("args", args, ids=str)
@pytest.mark.benchmark(group="dimensions", warmup=False, max_time=1.0)
def test_dimension_of_jit_compile(func_dimension_of, args):
"""Test the speed of jitting."""
_ = func_dimension_of.lower(*args).compile()


@pytest.mark.parametrize("args", args, ids=str)
@pytest.mark.benchmark(group="dimensions", warmup=True, max_time=1.0)
def test_dimension_of_execute(func_dimension_of, args):
"""Test the speed of calling the function."""
_ = jax.block_until_ready(func_dimension_of(*args))
19 changes: 7 additions & 12 deletions tests/benchmark/test_quaxed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Benchmark tests for quaxified jax."""
"""Benchmark tests for quaxed functions on quantities."""

from collections.abc import Callable
from typing import Any, TypeAlias, TypedDict
Expand All @@ -12,25 +12,19 @@

import unxt as u

x_nodim = u.Quantity(jnp.linspace(0, 1, 1000), "")
x_length = u.Quantity(jnp.linspace(0, 1, 1000), "m")
x_angle = u.Quantity(jnp.linspace(0, 1, 1000), "rad")


Args: TypeAlias = tuple[Any, ...]

x = jnp.linspace(0, 1, 1000)
x_nodim = u.Quantity(x, "")
x_length = u.Quantity(x, "m")
x_angle = u.Quantity(x, "rad")


def process_func(func: Callable[..., Any], args: Args) -> tuple[Compiled, Args]:
"""JIT and compile the function."""
return jax.jit(func), args


# def process_execute_func(func, args):
# """JIT and compile the function."""
# compiled_eager_func = jax.jit(func).lower(*args).compile()
# return compiled_eager_func, args


class ParameterizationKWArgs(TypedDict):
"""Keyword arguments for a pytest parameterization."""

Expand All @@ -57,6 +51,7 @@ def process_pytest_argvalues(
return {"argvalues": processed_argvalues, "ids": ids}


# TODO: also benchmark UncheckedQuantity
funcs_and_args: list[tuple[Callable[..., Any], Unpack[tuple[Args, ...]]]] = [
(jnp.abs, (x_nodim,), (x_length,)),
(jnp.acos, (x_nodim,)),
Expand Down
79 changes: 37 additions & 42 deletions tests/benchmark/test_units.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,73 @@
"""Benchmark tests for `unxt.units`."""

import equinox as eqx
import jax
import pytest
from jaxlib.xla_extension import PjitFunction

import unxt as u

METER = u.unit("m")


@pytest.fixture
def func_unit_is_length():
return lambda x: u.unit(x) == METER
def func_unit() -> PjitFunction:
return eqx.filter_jit(u.unit)


@pytest.fixture
def func_unit_of_length():
return lambda x: u.unit_of(x) == METER
def func_unit_of() -> PjitFunction:
# need to filter_jit because arg can be a array or other object
return eqx.filter_jit(u.unit_of)


#####################################################################
# `unit`

args = [(u.unit("meter"),), ("meter",)]


@pytest.mark.parametrize(
"args",
[
(u.unit("meter"),), # -> Unit('meter')
("meter",), # -> Unit('meter')
],
)
@pytest.mark.benchmark(group="units", warmup=False)
@pytest.mark.parametrize("args", args, ids=str)
@pytest.mark.benchmark(group="units", warmup=False, max_time=1.0)
def test_unit(args):
"""Test calling `unxt.unit`."""
_ = u.unit(*args)


@pytest.mark.parametrize(
"args",
[
(u.unit("meter"),), # -> Unit('meter')
("meter",), # -> Unit('meter')
],
)
@pytest.mark.benchmark(group="units", warmup=True)
def test_unit_execute(func_unit_is_length, args):
@pytest.mark.parametrize("args", args, ids=str)
@pytest.mark.benchmark(group="units", warmup=True, max_time=1.0)
def test_unit_jit_compile(func_unit, args):
"""Test the speed of calling the function."""
_ = func_unit.lower(*args).compile()


@pytest.mark.parametrize("args", args, ids=str)
@pytest.mark.benchmark(group="units", warmup=True, max_time=1.0)
def test_unit_execute(func_unit, args):
"""Test the speed of calling the function."""
_ = jax.block_until_ready(func_unit_is_length(*args))
_ = jax.block_until_ready(func_unit(*args))


#####################################################################
# `unit_of`

args = [(u.unit("meter"),), (u.Quantity(1, "m"),), (2,)]

@pytest.mark.parametrize(
"args",
[
(u.unit("meter"),), # -> Unit('meter')
(u.Quantity(1, "m"),), # -> Unit('meter')
(2,),
],
)
@pytest.mark.benchmark(group="units", warmup=False)

@pytest.mark.parametrize("args", args, ids=str)
@pytest.mark.benchmark(group="units", warmup=False, max_time=1.0)
def test_unit_of(args):
"""Test calling `unxt.unit_of`."""
_ = u.unit_of(*args)


@pytest.mark.parametrize(
"args",
[
(u.Quantity(1, "m"),), # -> Unit('meter')
],
)
@pytest.mark.benchmark(group="units", warmup=False)
def test_unit_of_jit_compile(func_unit_of_length, args):
@pytest.mark.parametrize("args", args, ids=str)
@pytest.mark.benchmark(group="units", warmup=False, max_time=1.0)
def test_unit_of_jit_compile(func_unit_of, args):
"""Test the speed of jitting a function."""
_ = jax.jit(func_unit_of_length).lower(*args).compile()
_ = func_unit_of.lower(*args).compile()


@pytest.mark.parametrize("args", args, ids=str)
@pytest.mark.benchmark(group="units", warmup=True, max_time=1.0)
def test_unit_of_execute(func_unit_of, args):
"""Test the speed of calling the function."""
_ = jax.block_until_ready(func_unit_of(*args))
Loading

0 comments on commit a9e0afb

Please sign in to comment.