Merge pull request #17 from invrs-io/flat
Move additional logic out of pure callback
mfschubert authored Feb 7, 2024
2 parents 7e0de79 + e08429e commit 8766e02
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.toml
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "v0.3.1"
+current_version = "v0.3.2"
 commit = true
 commit_args = "--no-verify"
 tag = true
2 changes: 1 addition & 1 deletion README.md
@@ -1,5 +1,5 @@
 # invrs-opt - Optimization algorithms for inverse design
-`v0.3.1`
+`v0.3.2`
 
 ## Overview
 
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 
 name = "invrs_opt"
-version = "v0.3.1"
+version = "v0.3.2"
 description = "Algorithms for inverse design"
 keywords = ["topology", "optimization", "jax", "inverse design"]
 readme = "README.md"
2 changes: 1 addition & 1 deletion src/invrs_opt/__init__.py
@@ -3,7 +3,7 @@
 Copyright (c) 2023 The INVRS-IO authors.
 """
 
-__version__ = "v0.3.1"
+__version__ = "v0.3.2"
__author__ = "Martin F. Schubert <[email protected]>"
 
 from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
21 changes: 13 additions & 8 deletions src/invrs_opt/lbfgsb/lbfgsb.py
@@ -248,32 +248,37 @@ def update_fn(
         del params
 
         def _update_pure(
-            latent_grad: PyTree,
+            flat_latent_grad: PyTree,
             value: jnp.ndarray,
             jax_lbfgsb_state: JaxLbfgsbDict,
         ) -> Tuple[PyTree, JaxLbfgsbDict]:
             assert onp.size(value) == 1
             scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
             scipy_lbfgsb_state.update(
-                grad=_to_numpy(latent_grad), value=onp.asarray(value)
+                grad=onp.asarray(flat_latent_grad, dtype=onp.float64),
+                value=onp.asarray(value, dtype=onp.float64),
             )
-            latent_params = _to_pytree(scipy_lbfgsb_state.x, latent_grad)
-            return latent_params, scipy_lbfgsb_state.to_jax()
+            flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
+            return flat_latent_params, scipy_lbfgsb_state.to_jax()
 
-        params, latent_params, jax_lbfgsb_state = state
+        _, latent_params, jax_lbfgsb_state = state
         _, vjp_fn = jax.vjp(transform_fn, latent_params)
         (latent_grad,) = vjp_fn(grad)
+        flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(
+            latent_grad
+        )  # type: ignore[no-untyped-call]
 
         (
-            latent_params,
+            flat_latent_params,
             jax_lbfgsb_state,
         ) = jax.pure_callback(  # type: ignore[attr-defined]
             _update_pure,
-            (latent_params, jax_lbfgsb_state),
-            latent_grad,
+            (flat_latent_grad, jax_lbfgsb_state),
+            flat_latent_grad,
             value,
             jax_lbfgsb_state,
         )
+        latent_params = unflatten_fn(flat_latent_params)
         return transform_fn(latent_params), latent_params, jax_lbfgsb_state
 
     return base.Optimizer(
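For readers unfamiliar with the pattern: `jax.pure_callback` stages a host-side function into a traced computation, and the callback body should do as little as possible beyond the impure host call itself. The sketch below is not code from this repository; `host_update` and the doubling it performs are hypothetical stand-ins for the scipy L-BFGS-B update. It shows the shape of this change: `flatten_util.ravel_pytree` runs in traced JAX code before the callback, the callback handles only flat arrays, and unflattening happens in traced code afterward.

```python
import jax
import jax.numpy as jnp
import numpy as onp
from jax import flatten_util


def host_update(flat_x: onp.ndarray) -> onp.ndarray:
    # Hypothetical impure host-side computation, standing in for the scipy
    # L-BFGS-B update. It only ever sees (and returns) a flat numpy array.
    return onp.asarray(flat_x) * 2.0


def step(params):
    # Flatten in traced jax code, before the callback. `unflatten_fn`
    # captures the tree structure needed to invert the flattening.
    flat, unflatten_fn = flatten_util.ravel_pytree(params)
    # The callback receives only the flat array; its output shape and
    # dtype are declared up front via ShapeDtypeStruct.
    flat_out = jax.pure_callback(
        host_update,
        jax.ShapeDtypeStruct(flat.shape, flat.dtype),
        flat,
    )
    # Restore the pytree structure, again in traced jax code.
    return unflatten_fn(flat_out)


params = {"a": jnp.ones((2,)), "b": jnp.zeros((3,))}
print(jax.jit(step)(params))
```

Keeping the pytree logic in traced code means the callback itself stays a thin wrapper around the host call, which is what the commit title refers to.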
2 changes: 1 addition & 1 deletion tests/lbfgsb/test_lbfgsb.py
@@ -428,7 +428,7 @@ def step_fn(state):
             no_batch_values[-1].append(value)
 
         onp.testing.assert_allclose(
-            batch_values, onp.transpose(no_batch_values, (1, 0))
+            batch_values, onp.transpose(no_batch_values, (1, 0)), atol=1e-4
         )
 
 
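The diff does not say why the tolerance was added; a plausible reading (an assumption, not stated in the commit) is that `assert_allclose` defaults to `atol=0`, so near-zero entries fail under a purely relative comparison even when batched and unbatched runs agree closely:

```python
import numpy as onp

a = onp.array([1e-9])
b = onp.array([2e-9])

# Fails: with the default atol=0, near-zero values must match to rtol=1e-7,
# and here the relative difference is 0.5.
# onp.testing.assert_allclose(a, b)

# Passes: an absolute tolerance absorbs tiny near-zero differences.
onp.testing.assert_allclose(a, b, atol=1e-4)
```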
