Skip to content

Commit

Permalink
Merge pull request #15 from invrs-io/less_pure
Browse files Browse the repository at this point in the history
Simplify pure_callback functions and improve latent initialization
  • Loading branch information
mfschubert authored Feb 6, 2024
2 parents bb1fcbd + 59be242 commit 4971830
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "v0.2.0"
current_version = "v0.3.0"
commit = true
commit_args = "--no-verify"
tag = true
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# invrs-opt - Optimization algorithms for inverse design
`v0.2.0`
`v0.3.0`

## Overview

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]

name = "invrs_opt"
version = "v0.2.0"
version = "v0.3.0"
description = "Algorithms for inverse design"
keywords = ["topology", "optimization", "jax", "inverse design"]
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion src/invrs_opt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Copyright (c) 2023 The INVRS-IO authors.
"""

__version__ = "v0.2.0"
__version__ = "v0.3.0"
__author__ = "Martin F. Schubert <[email protected]>"

from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
Expand Down
95 changes: 66 additions & 29 deletions src/invrs_opt/lbfgsb/lbfgsb.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
NDArray = onp.ndarray[Any, Any]
PyTree = Any
ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
LbfgsbState = Tuple[PyTree, Dict[str, jnp.ndarray]]
JaxLbfgsbDict = Dict[str, jnp.ndarray]
LbfgsbState = Tuple[PyTree, PyTree, JaxLbfgsbDict]


# Task message prefixes for the underlying L-BFGS-B implementation.
Expand Down Expand Up @@ -96,6 +97,7 @@ def fn(x):
maxcor=maxcor,
line_search_max_steps=line_search_max_steps,
transform_fn=lambda x: x,
initialize_latent_fn=lambda x: x,
)


Expand Down Expand Up @@ -127,11 +129,16 @@ def density_lbfgsb(

def transform_fn(tree: PyTree) -> PyTree:
return tree_util.tree_map(
lambda x: (
transform_density(x) if isinstance(x, types.Density2DArray) else x
),
lambda x: transform_density(x) if _is_density(x) else x,
tree,
is_leaf=lambda x: isinstance(x, types.CUSTOM_TYPES),
is_leaf=_is_density,
)

def initialize_latent_fn(tree: PyTree) -> PyTree:
return tree_util.tree_map(
lambda x: initialize_latent_density(x) if _is_density(x) else x,
tree,
is_leaf=_is_density,
)

def transform_density(density: types.Density2DArray) -> types.Density2DArray:
Expand All @@ -144,17 +151,29 @@ def transform_density(density: types.Density2DArray) -> types.Density2DArray:
)
return transform.apply_fixed_pixels(transformed)

def initialize_latent_density(
density: types.Density2DArray,
) -> types.Density2DArray:
array = transform.normalized_array_from_density(density)
array = jnp.clip(array, -1, 1)
array *= jnp.tanh(beta)
latent_array = jnp.arctanh(array) / beta
latent_array = transform.rescale_array_for_density(latent_array, density)
return dataclasses.replace(density, array=latent_array)

return transformed_lbfgsb(
maxcor=maxcor,
line_search_max_steps=line_search_max_steps,
transform_fn=transform_fn,
initialize_latent_fn=initialize_latent_fn,
)


def transformed_lbfgsb(
maxcor: int,
line_search_max_steps: int,
transform_fn: Callable[[PyTree], PyTree],
initialize_latent_fn: Callable[[PyTree], PyTree],
) -> base.Optimizer:
"""Construct an latent parameter L-BFGS-B optimizer.
Expand All @@ -169,6 +188,8 @@ def transformed_lbfgsb(
line_search_max_steps: The maximum number of steps in the line search.
transform_fn: Function which transforms the internal latent parameters to
the parameters returned by the optimizer.
initialize_latent_fn: Function which computes the initial latent parameters
given the initial parameters.
Returns:
The `base.Optimizer`.
Expand All @@ -188,7 +209,7 @@ def transformed_lbfgsb(
def init_fn(params: PyTree) -> LbfgsbState:
"""Initializes the optimization state."""

def _init_pure(params: PyTree) -> LbfgsbState:
def _init_pure(params: PyTree) -> Tuple[PyTree, JaxLbfgsbDict]:
lower_bound = types.extract_lower_bound(params)
upper_bound = types.extract_upper_bound(params)
scipy_lbfgsb_state = ScipyLbfgsbState.init(
Expand All @@ -199,16 +220,21 @@ def _init_pure(params: PyTree) -> LbfgsbState:
line_search_max_steps=line_search_max_steps,
)
latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
params = transform_fn(latent_params)
return params, scipy_lbfgsb_state.to_jax()

return jax.pure_callback( # type: ignore[no-any-return, attr-defined]
_init_pure, _example_state(params, maxcor), params
return latent_params, scipy_lbfgsb_state.to_jax()

(
latent_params,
jax_lbfgsb_state,
) = jax.pure_callback( # type: ignore[attr-defined]
_init_pure,
_example_state(params, maxcor),
initialize_latent_fn(params),
)
return transform_fn(latent_params), latent_params, jax_lbfgsb_state

def params_fn(state: LbfgsbState) -> PyTree:
"""Returns the parameters for the given `state`."""
params, _ = state
params, _, _ = state
return params

def update_fn(
Expand All @@ -219,30 +245,36 @@ def update_fn(
state: LbfgsbState,
) -> LbfgsbState:
"""Updates the state."""
del params

def _update_pure(
grad: PyTree, value: float, params: PyTree, state: LbfgsbState
) -> LbfgsbState:
del params

params, jax_lbfgsb_state = state
scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)

latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
_, vjp_fn = jax.vjp(transform_fn, latent_params)
(latent_grad,) = vjp_fn(grad)

latent_grad: PyTree,
value: jnp.ndarray,
jax_lbfgsb_state: JaxLbfgsbDict,
) -> Tuple[PyTree, JaxLbfgsbDict]:
assert onp.size(value) == 1
scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
scipy_lbfgsb_state.update(
grad=_to_numpy(latent_grad), value=onp.asarray(value)
)
latent_params = _to_pytree(scipy_lbfgsb_state.x, params)
params = transform_fn(latent_params)
return params, scipy_lbfgsb_state.to_jax()

return jax.pure_callback( # type: ignore[no-any-return, attr-defined]
_update_pure, state, grad, value, params, state
latent_params = _to_pytree(scipy_lbfgsb_state.x, latent_grad)
return latent_params, scipy_lbfgsb_state.to_jax()

params, latent_params, jax_lbfgsb_state = state
_, vjp_fn = jax.vjp(transform_fn, latent_params)
(latent_grad,) = vjp_fn(grad)

(
latent_params,
jax_lbfgsb_state,
) = jax.pure_callback( # type: ignore[attr-defined]
_update_pure,
(latent_params, jax_lbfgsb_state),
latent_grad,
value,
jax_lbfgsb_state,
)
return transform_fn(latent_params), latent_params, jax_lbfgsb_state

return base.Optimizer(
init=init_fn,
Expand All @@ -256,6 +288,11 @@ def _update_pure(
# ------------------------------------------------------------------------------


def _is_density(leaf: Any) -> Any:
"""Return `True` if `leaf` is a density array."""
return isinstance(leaf, types.Density2DArray)


def _to_numpy(params: PyTree) -> NDArray:
"""Flattens a `params` pytree into a single rank-1 numpy array."""
x, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
Expand Down
38 changes: 38 additions & 0 deletions tests/lbfgsb/test_lbfgsb.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,44 @@ def loss_fn(density):
onp.testing.assert_allclose(params.array, expected)


class DensityLbfgsbInitializeTest(unittest.TestCase):
@parameterized.expand(
[
[-1, 1, -0.95],
[-1, 1, -0.50],
[-1, 1, 0.00],
[-1, 1, 0.50],
[-1, 1, 0.95],
[0, 1, 0.05],
[0, 1, 0.25],
[0, 1, 0.00],
[0, 1, 0.25],
[0, 1, 0.95],
]
)
def test_initial_params_match_expected(self, lb, ub, value):
density = types.Density2DArray(
array=jnp.full((10, 10), value),
lower_bound=lb,
upper_bound=ub,
)
opt = lbfgsb.density_lbfgsb(beta=4)
state = opt.init(density)
params = opt.params(state)
onp.testing.assert_allclose(density.array, params.array, atol=1e-2)

def test_initial_params_out_of_bounds(self):
density = types.Density2DArray(
array=jnp.full((10, 10), 10),
lower_bound=-1,
upper_bound=1,
)
opt = lbfgsb.density_lbfgsb(beta=4)
state = opt.init(density)
params = opt.params(state)
onp.testing.assert_allclose(params.array, onp.ones_like(params.array))


class LbfgsbBoundsTest(unittest.TestCase):
@parameterized.expand([[-1, 1, 1], [-1, 1, -1], [0, 1, 1], [0, 1, -1]])
def test_respects_bounds(self, lower_bound, upper_bound, sign):
Expand Down

0 comments on commit 4971830

Please sign in to comment.