From f2dc52b6ba062e459e18a29f282681a08c322f81 Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Tue, 6 Feb 2024 17:17:33 -0800 Subject: [PATCH 1/3] Move additional logic out of pure callback --- src/invrs_opt/lbfgsb/lbfgsb.py | 17 ++++++++++------- tests/lbfgsb/test_lbfgsb.py | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/invrs_opt/lbfgsb/lbfgsb.py b/src/invrs_opt/lbfgsb/lbfgsb.py index 9b950b0..7be67b1 100644 --- a/src/invrs_opt/lbfgsb/lbfgsb.py +++ b/src/invrs_opt/lbfgsb/lbfgsb.py @@ -248,32 +248,35 @@ def update_fn( del params def _update_pure( - latent_grad: PyTree, + flat_latent_grad: PyTree, value: jnp.ndarray, jax_lbfgsb_state: JaxLbfgsbDict, ) -> Tuple[PyTree, JaxLbfgsbDict]: assert onp.size(value) == 1 scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state) scipy_lbfgsb_state.update( - grad=_to_numpy(latent_grad), value=onp.asarray(value) + grad=onp.asarray(flat_latent_grad, dtype=onp.float64), + value=onp.asarray(value, dtype=onp.float64), ) - latent_params = _to_pytree(scipy_lbfgsb_state.x, latent_grad) - return latent_params, scipy_lbfgsb_state.to_jax() + flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x) + return flat_latent_params, scipy_lbfgsb_state.to_jax() params, latent_params, jax_lbfgsb_state = state _, vjp_fn = jax.vjp(transform_fn, latent_params) (latent_grad,) = vjp_fn(grad) + flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(latent_grad) ( - latent_params, + flat_latent_params, jax_lbfgsb_state, ) = jax.pure_callback( # type: ignore[attr-defined] _update_pure, - (latent_params, jax_lbfgsb_state), - latent_grad, + (flat_latent_grad, jax_lbfgsb_state), + flat_latent_grad, value, jax_lbfgsb_state, ) + latent_params = unflatten_fn(flat_latent_params) return transform_fn(latent_params), latent_params, jax_lbfgsb_state return base.Optimizer( diff --git a/tests/lbfgsb/test_lbfgsb.py b/tests/lbfgsb/test_lbfgsb.py index dec5cc1..29cb90e 100644 --- a/tests/lbfgsb/test_lbfgsb.py +++ b/tests/lbfgsb/test_lbfgsb.py @@ -428,7 +428,7 @@ def step_fn(state): no_batch_values[-1].append(value) onp.testing.assert_allclose( - batch_values, onp.transpose(no_batch_values, (1, 0)) + batch_values, onp.transpose(no_batch_values, (1, 0)), atol=1e-4 ) From 6baa542fb7eb1fd6b2eb21d7eb62130ceea97b3c Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Tue, 6 Feb 2024 17:24:08 -0800 Subject: [PATCH 2/3] Version updated from v0.3.1 to v0.3.2 --- .bumpversion.toml | 2 +- README.md | 2 +- pyproject.toml | 2 +- src/invrs_opt/__init__.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.bumpversion.toml b/.bumpversion.toml index 6539462..1ab3950 100644 --- a/.bumpversion.toml +++ b/.bumpversion.toml @@ -1,5 +1,5 @@ [tool.bumpversion] -current_version = "v0.3.1" +current_version = "v0.3.2" commit = true commit_args = "--no-verify" tag = true diff --git a/README.md b/README.md index 4073d80..3e0e87f 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # invrs-opt - Optimization algorithms for inverse design -`v0.3.1` +`v0.3.2` ## Overview diff --git a/pyproject.toml b/pyproject.toml index 4af0328..c592db5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "invrs_opt" -version = "v0.3.1" +version = "v0.3.2" description = "Algorithms for inverse design" keywords = ["topology", "optimization", "jax", "inverse design"] readme = "README.md" diff --git a/src/invrs_opt/__init__.py b/src/invrs_opt/__init__.py index 9c1dc57..f9846f6 100644 --- a/src/invrs_opt/__init__.py +++ b/src/invrs_opt/__init__.py @@ -3,7 +3,7 @@ Copyright (c) 2023 The INVRS-IO authors. """ -__version__ = "v0.3.1" +__version__ = "v0.3.2" __author__ = "Martin F. Schubert " from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb From e08429e19132efd8f87f3e2826dd6d6f73bd6d24 Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Tue, 6 Feb 2024 17:26:55 -0800 Subject: [PATCH 3/3] Formatting --- src/invrs_opt/lbfgsb/lbfgsb.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/invrs_opt/lbfgsb/lbfgsb.py b/src/invrs_opt/lbfgsb/lbfgsb.py index 7be67b1..b97a6a7 100644 --- a/src/invrs_opt/lbfgsb/lbfgsb.py +++ b/src/invrs_opt/lbfgsb/lbfgsb.py @@ -261,10 +261,12 @@ def _update_pure( flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x) return flat_latent_params, scipy_lbfgsb_state.to_jax() - params, latent_params, jax_lbfgsb_state = state + _, latent_params, jax_lbfgsb_state = state _, vjp_fn = jax.vjp(transform_fn, latent_params) (latent_grad,) = vjp_fn(grad) - flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(latent_grad) + flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree( + latent_grad + ) # type: ignore[no-untyped-call] ( flat_latent_params,