Skip to content

Commit

Permalink
Now using pytest -Werror
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Dec 24, 2024
1 parent 952d0d2 commit 396aac6
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 122 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ build-backend = "hatchling.build"
include = ["optimistix/*"]

[tool.pytest.ini_options]
addopts = "--jaxtyping-packages=optimistix,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))"
addopts = "-Werror --jaxtyping-packages=optimistix,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))"

[tool.ruff]
extend-include = ["*.ipynb"]
Expand Down
124 changes: 65 additions & 59 deletions tests/test_fixed_point.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import random

import equinox as eqx
Expand Down Expand Up @@ -47,66 +48,71 @@ def test_fixed_point(solver, _fn, init, args):
@pytest.mark.parametrize("_fn, init, args", fixed_point_fn_init_args)
@pytest.mark.parametrize("dtype", [jnp.float64, jnp.complex128])
def test_fixed_point_jvp(getkey, solver, _fn, init, dtype, args):
args = jtu.tree_map(lambda x: x.astype(dtype), args)
init = jtu.tree_map(lambda x: x.astype(dtype), init)
atol = rtol = 1e-3
has_aux = random.choice([True, False])
if has_aux:
fn = lambda x, args: (_fn(x, args), smoke_aux)
if dtype == jnp.complex128:
context = pytest.warns(match="Complex support in Optimistix is a work in")
else:
fn = _fn

dynamic_args, static_args = eqx.partition(args, eqx.is_array)
t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape, dtype=dtype), init)
t_dynamic_args = jtu.tree_map(
lambda x: jr.normal(getkey(), x.shape, dtype=dtype), dynamic_args
)

def fixed_point(x, dynamic_args, *, adjoint):
args = eqx.combine(dynamic_args, static_args)
return optx.fixed_point(
fn,
solver,
x,
has_aux=has_aux,
args=args,
max_steps=10_000,
adjoint=adjoint,
throw=False,
).value

otd = optx.ImplicitAdjoint()
expected_out, t_expected_out = finite_difference_jvp(
fixed_point,
(init, dynamic_args),
(t_init, t_dynamic_args),
adjoint=otd,
)
out, t_out = eqx.filter_jvp(
fixed_point,
(init, dynamic_args),
(t_init, t_dynamic_args),
adjoint=otd,
)
dto = PiggybackAdjoint()
expected_out2, t_expected_out2 = finite_difference_jvp(
fixed_point,
(init, dynamic_args),
(t_init, t_dynamic_args),
adjoint=dto,
)
out2, t_out2 = eqx.filter_jvp(
fixed_point,
(init, dynamic_args),
(t_init, t_dynamic_args),
adjoint=dto,
)
assert tree_allclose(expected_out2, expected_out, atol=atol, rtol=rtol)
assert tree_allclose(out, expected_out, atol=atol, rtol=rtol)
assert tree_allclose(out2, expected_out, atol=atol, rtol=rtol)
assert tree_allclose(t_expected_out2, t_expected_out, atol=atol, rtol=rtol)
assert tree_allclose(t_out, t_expected_out, atol=atol, rtol=rtol)
assert tree_allclose(t_out2, t_expected_out, atol=atol, rtol=rtol)
context = contextlib.nullcontext()
with context:
args = jtu.tree_map(lambda x: x.astype(dtype), args)
init = jtu.tree_map(lambda x: x.astype(dtype), init)
atol = rtol = 1e-3
has_aux = random.choice([True, False])
if has_aux:
fn = lambda x, args: (_fn(x, args), smoke_aux)
else:
fn = _fn

dynamic_args, static_args = eqx.partition(args, eqx.is_array)
t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape, dtype=dtype), init)
t_dynamic_args = jtu.tree_map(
lambda x: jr.normal(getkey(), x.shape, dtype=dtype), dynamic_args
)

def fixed_point(x, dynamic_args, *, adjoint):
args = eqx.combine(dynamic_args, static_args)
return optx.fixed_point(
fn,
solver,
x,
has_aux=has_aux,
args=args,
max_steps=10_000,
adjoint=adjoint,
throw=False,
).value

otd = optx.ImplicitAdjoint()
expected_out, t_expected_out = finite_difference_jvp(
fixed_point,
(init, dynamic_args),
(t_init, t_dynamic_args),
adjoint=otd,
)
out, t_out = eqx.filter_jvp(
fixed_point,
(init, dynamic_args),
(t_init, t_dynamic_args),
adjoint=otd,
)
dto = PiggybackAdjoint()
expected_out2, t_expected_out2 = finite_difference_jvp(
fixed_point,
(init, dynamic_args),
(t_init, t_dynamic_args),
adjoint=dto,
)
out2, t_out2 = eqx.filter_jvp(
fixed_point,
(init, dynamic_args),
(t_init, t_dynamic_args),
adjoint=dto,
)
assert tree_allclose(expected_out2, expected_out, atol=atol, rtol=rtol)
assert tree_allclose(out, expected_out, atol=atol, rtol=rtol)
assert tree_allclose(out2, expected_out, atol=atol, rtol=rtol)
assert tree_allclose(t_expected_out2, t_expected_out, atol=atol, rtol=rtol)
assert tree_allclose(t_out, t_expected_out, atol=atol, rtol=rtol)
assert tree_allclose(t_out2, t_expected_out, atol=atol, rtol=rtol)


@pytest.mark.parametrize(
Expand Down
130 changes: 68 additions & 62 deletions tests/test_root_find.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import random

import equinox as eqx
Expand Down Expand Up @@ -54,70 +55,75 @@ def root_find_problem(y, args):
@pytest.mark.parametrize("_fn, init, args", fixed_point_fn_init_args)
@pytest.mark.parametrize("dtype", [jnp.float64, jnp.complex128])
def test_root_find_jvp(getkey, solver, _fn, init, dtype, args):
args = jtu.tree_map(lambda x: x.astype(dtype), args)
init = jtu.tree_map(lambda x: x.astype(dtype), init)
atol = rtol = 1e-3
has_aux = random.choice([True, False])

def root_find_problem(y, args):
f_val = _fn(y, args)
return (f_val**ω - y**ω).ω

if has_aux:
fn = lambda x, args: (root_find_problem(x, args), smoke_aux)
if dtype == jnp.complex128:
context = pytest.warns(match="Complex support in Optimistix is a work in")
else:
fn = root_find_problem
dynamic_args, static_args = eqx.partition(args, eqx.is_array)
t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape, dtype=dtype), init)
t_dynamic_args = jtu.tree_map(
lambda x: jr.normal(getkey(), x.shape, dtype=dtype), dynamic_args
)
context = contextlib.nullcontext()
with context:
args = jtu.tree_map(lambda x: x.astype(dtype), args)
init = jtu.tree_map(lambda x: x.astype(dtype), init)
atol = rtol = 1e-3
has_aux = random.choice([True, False])

def root_find_problem(y, args):
f_val = _fn(y, args)
return (f_val**ω - y**ω).ω

if has_aux:
fn = lambda x, args: (root_find_problem(x, args), smoke_aux)
else:
fn = root_find_problem
dynamic_args, static_args = eqx.partition(args, eqx.is_array)
t_init = jtu.tree_map(lambda x: jr.normal(getkey(), x.shape, dtype=dtype), init)
t_dynamic_args = jtu.tree_map(
lambda x: jr.normal(getkey(), x.shape, dtype=dtype), dynamic_args
)

def root_find(x, dynamic_args, *, adjoint):
args = eqx.combine(dynamic_args, static_args)
return optx.root_find(
fn,
solver,
x,
has_aux=has_aux,
args=args,
max_steps=10_000,
adjoint=adjoint,
throw=False,
).value

otd = optx.ImplicitAdjoint()
expected_out, t_expected_out = finite_difference_jvp(
root_find,
(init, dynamic_args),
(t_init, t_dynamic_args),
adjoint=otd,
)
out, t_out = eqx.filter_jvp(
root_find,
(init, dynamic_args),
(t_init, t_dynamic_args),
adjoint=otd,
)
dto = PiggybackAdjoint()
expected_out2, t_expected_out2 = finite_difference_jvp(
root_find,
(init, dynamic_args),
(t_init, t_dynamic_args),
adjoint=dto,
)
out2, t_out2 = eqx.filter_jvp(
root_find,
(init, dynamic_args),
(t_init, t_dynamic_args),
adjoint=dto,
)
assert tree_allclose(expected_out2, expected_out, atol=atol, rtol=rtol)
assert tree_allclose(out, expected_out, atol=atol, rtol=rtol)
assert tree_allclose(out2, expected_out, atol=atol, rtol=rtol)
assert tree_allclose(t_expected_out2, t_expected_out, atol=atol, rtol=rtol)
assert tree_allclose(t_out, t_expected_out, atol=atol, rtol=rtol)
assert tree_allclose(t_out2, t_expected_out, atol=atol, rtol=rtol)
def root_find(x, dynamic_args, *, adjoint):
args = eqx.combine(dynamic_args, static_args)
return optx.root_find(
fn,
solver,
x,
has_aux=has_aux,
args=args,
max_steps=10_000,
adjoint=adjoint,
throw=False,
).value

otd = optx.ImplicitAdjoint()
expected_out, t_expected_out = finite_difference_jvp(
root_find,
(init, dynamic_args),
(t_init, t_dynamic_args),
adjoint=otd,
)
out, t_out = eqx.filter_jvp(
root_find,
(init, dynamic_args),
(t_init, t_dynamic_args),
adjoint=otd,
)
dto = PiggybackAdjoint()
expected_out2, t_expected_out2 = finite_difference_jvp(
root_find,
(init, dynamic_args),
(t_init, t_dynamic_args),
adjoint=dto,
)
out2, t_out2 = eqx.filter_jvp(
root_find,
(init, dynamic_args),
(t_init, t_dynamic_args),
adjoint=dto,
)
assert tree_allclose(expected_out2, expected_out, atol=atol, rtol=rtol)
assert tree_allclose(out, expected_out, atol=atol, rtol=rtol)
assert tree_allclose(out2, expected_out, atol=atol, rtol=rtol)
assert tree_allclose(t_expected_out2, t_expected_out, atol=atol, rtol=rtol)
assert tree_allclose(t_out, t_expected_out, atol=atol, rtol=rtol)
assert tree_allclose(t_out2, t_expected_out, atol=atol, rtol=rtol)


def test_bisection_flip():
Expand Down

0 comments on commit 396aac6

Please sign in to comment.