split parameterization latents and metadata
Martin Schubert authored and committed on Aug 16, 2024
1 parent e1fc836 · commit c9d904b
Showing 8 changed files with 342 additions and 252 deletions.
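
The central refactor, visible in both optimizers below: a latent pytree is split into static metadata and array-valued latents, only the latents cross array-only boundaries (the scipy L-BFGS-B callback, the optax state), and the two halves are recombined afterwards. A minimal usage sketch of that round trip, using the helper names that appear in the diff (`latent_params` here stands for any latent pytree produced by `_init_latents`):

    from invrs_opt.parameterization import base as param_base

    metadata, latents = param_base.partition_density_metadata(latent_params)
    # ... pass `latents` (arrays only) through a callback or optimizer ...
    latent_params = param_base.combine_density_metadata(metadata, latents)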
180 changes: 103 additions & 77 deletions src/invrs_opt/optimizers/lbfgsb.py
@@ -19,7 +19,7 @@

from invrs_opt.optimizers import base
from invrs_opt.parameterization import (
base as parameterization_base,
base as param_base,
filter_project,
gaussian_levelset,
pixel,
Expand Down Expand Up @@ -252,7 +252,7 @@ def levelset_lbfgsb(


def parameterized_lbfgsb(
density_parameterization: Optional[parameterization_base.Density2DParameterization],
density_parameterization: Optional[param_base.Density2DParameterization],
penalty: float,
maxcor: int = DEFAULT_MAXCOR,
line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
Expand Down Expand Up @@ -296,59 +296,6 @@ def parameterized_lbfgsb(
if density_parameterization is None:
density_parameterization = pixel.pixel()

def _init_latents(params: PyTree) -> PyTree:
def _leaf_init_latents(leaf: Any) -> Any:
leaf = _clip(leaf)
if not _is_density(leaf) or density_parameterization is None:
return leaf
return density_parameterization.from_density(leaf)

return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)

def _params_from_latents(latent_params: PyTree) -> PyTree:
def _leaf_params_from_latents(leaf: Any) -> Any:
if not _is_parameterized_density(leaf) or density_parameterization is None:
return leaf
return density_parameterization.to_density(leaf)

return tree_util.tree_map(
_leaf_params_from_latents,
latent_params,
is_leaf=_is_parameterized_density,
)

def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
def _constraint_loss_leaf(
params: parameterization_base.ParameterizedDensity2DArrayBase,
) -> jnp.ndarray:
constraints = density_parameterization.constraints(params)
constraints = tree_util.tree_map(
lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
constraints,
)
return jnp.sum(jnp.asarray(constraints))

losses = [0.0] + [
_constraint_loss_leaf(p)
for p in tree_util.tree_leaves(
latent_params, is_leaf=_is_parameterized_density
)
if _is_parameterized_density(p)
]
return penalty * jnp.sum(jnp.asarray(losses))

def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
def _update_leaf(leaf: Any) -> Any:
if not _is_parameterized_density(leaf):
return leaf
return density_parameterization.update(leaf, step)

return tree_util.tree_map(
_update_leaf,
latent_params,
is_leaf=_is_parameterized_density,
)

def init_fn(params: PyTree) -> LbfgsbState:
"""Initializes the optimization state."""

@@ -368,10 +315,17 @@ def _init_state_pure(latent_params: PyTree) -> Tuple[PyTree, JaxLbfgsbDict]:
return latent_params, scipy_lbfgsb_state.to_jax()

latent_params = _init_latents(params)
latent_params, jax_lbfgsb_state = jax.pure_callback(
_init_state_pure, _example_state(latent_params, maxcor), latent_params
metadata, latents = param_base.partition_density_metadata(latent_params)
latents, jax_lbfgsb_state = jax.pure_callback(
_init_state_pure, _example_state(latents, maxcor), latents
)
latent_params = param_base.combine_density_metadata(metadata, latents)
return (
0, # step
_params_from_latent_params(latent_params), # params
latent_params, # latent params
jax_lbfgsb_state, # opt state
)
return 0, _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
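
The partition happens immediately before `jax.pure_callback`, presumably because the callback can only describe and ship array leaves across the tracing boundary; static metadata is set aside and recombined afterwards. A self-contained sketch of the callback pattern itself (toy host function, not from this commit):

    import jax
    import jax.numpy as jnp
    import numpy as np

    def _host_sqrt(x):
        # Runs on the host with concrete numpy arrays, outside of tracing.
        return np.sqrt(x)

    @jax.jit
    def f(x):
        # The second argument declares the output shape/dtype; every value
        # crossing the boundary must be an array.
        return jax.pure_callback(_host_sqrt, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

    f(jnp.asarray([1.0, 4.0, 9.0]))  # -> [1., 2., 3.]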

def params_fn(state: LbfgsbState) -> PyTree:
"""Returns the parameters for the given `state`."""
Expand Down Expand Up @@ -403,46 +357,118 @@ def _update_pure(
return flat_latent_params, scipy_lbfgsb_state.to_jax()

step, _, latent_params, jax_lbfgsb_state = state
_, vjp_fn = jax.vjp(_params_from_latents, latent_params)
(latent_grad,) = vjp_fn(grad)
metadata, latents = param_base.partition_density_metadata(latent_params)

def _params_from_latents(latents: PyTree) -> PyTree:
latent_params = param_base.combine_density_metadata(metadata, latents)
return _params_from_latent_params(latent_params)

def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
latent_params = param_base.combine_density_metadata(metadata, latents)
return _constraint_loss(latent_params)

_, vjp_fn = jax.vjp(_params_from_latents, latents)
(latents_grad,) = vjp_fn(grad)

if not (
tree_util.tree_structure(latent_grad)
== tree_util.tree_structure(latent_params) # type: ignore[operator]
tree_util.tree_structure(latents_grad)
== tree_util.tree_structure(latents) # type: ignore[operator]
):
raise ValueError(
f"Tree structure of `latent_grad` was different than expected, got \n"
f"{tree_util.tree_structure(latent_grad)} but expected \n"
f"{tree_util.tree_structure(latent_params)}."
f"Tree structure of `latents_grad` was different than expected, got \n"
f"{tree_util.tree_structure(latents_grad)} but expected \n"
f"{tree_util.tree_structure(latents)}."
)

(
constraint_loss_value,
constraint_loss_grad,
) = jax.value_and_grad(
_constraint_loss
)(latent_params)
_constraint_loss_latents
)(latents)
value += constraint_loss_value
latent_grad = tree_util.tree_map(
lambda a, b: a + b, latent_grad, constraint_loss_grad
latents_grad = tree_util.tree_map(
lambda a, b: a + b, latents_grad, constraint_loss_grad
)

flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(
latent_grad
flat_latents_grad, unflatten_fn = flatten_util.ravel_pytree(
latents_grad
) # type: ignore[no-untyped-call]

flat_latent_params, jax_lbfgsb_state = jax.pure_callback(
flat_latents, jax_lbfgsb_state = jax.pure_callback(
_update_pure,
(flat_latent_grad, jax_lbfgsb_state),
flat_latent_grad,
(flat_latents_grad, jax_lbfgsb_state),
flat_latents_grad,
value,
jax_lbfgsb_state,
)
latent_params = unflatten_fn(flat_latent_params)
latents = unflatten_fn(flat_latents)
latent_params = param_base.combine_density_metadata(metadata, latents)
latent_params = _update_parameterized_densities(latent_params, step)
params = _params_from_latents(latent_params)
params = _params_from_latent_params(latent_params)
return step + 1, params, latent_params, jax_lbfgsb_state
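
Throughout `update_fn`, `flatten_util.ravel_pytree` bridges pytrees and the flat 1-D vectors that the scipy L-BFGS-B implementation expects. A standalone sketch of that round trip (toy tree, not from this commit):

    import jax.numpy as jnp
    from jax import flatten_util

    tree = {"w": jnp.ones((2, 2)), "b": jnp.zeros(3)}
    flat, unflatten = flatten_util.ravel_pytree(tree)
    flat.shape                   # (7,): all leaves concatenated
    restored = unflatten(flat)   # same structure and leaf shapes as `tree`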

# -------------------------------------------------------------------------
# Functions related to the density parameterization.
# -------------------------------------------------------------------------

def _init_latents(params: PyTree) -> PyTree:
def _leaf_init_latents(leaf: Any) -> Any:
leaf = _clip(leaf)
if not _is_density(leaf) or density_parameterization is None:
return leaf
return density_parameterization.from_density(leaf)

return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)

def _params_from_latent_params(latent_params: PyTree) -> PyTree:
def _leaf_params_from_latents(leaf: Any) -> Any:
if not _is_parameterized_density(leaf) or density_parameterization is None:
return leaf
return density_parameterization.to_density(leaf)

return tree_util.tree_map(
_leaf_params_from_latents,
latent_params,
is_leaf=_is_parameterized_density,
)

def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
def _update_leaf(leaf: Any) -> Any:
if not _is_parameterized_density(leaf):
return leaf
return density_parameterization.update(leaf, step)

return tree_util.tree_map(
_update_leaf,
latent_params,
is_leaf=_is_parameterized_density,
)
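
Each of these helpers relies on `tree_map` with an `is_leaf` predicate, which stops traversal at a custom container and hands the whole object to the mapped function. A standalone sketch of that mechanism (toy predicate, not from this commit):

    import jax

    def _is_pair(x):
        return isinstance(x, tuple) and len(x) == 2

    tree = {"a": (1, 2), "b": 3}
    jax.tree_util.tree_map(
        lambda leaf: tuple(2 * v for v in leaf) if _is_pair(leaf) else 2 * leaf,
        tree,
        is_leaf=_is_pair,
    )  # -> {"a": (2, 4), "b": 6}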

# -------------------------------------------------------------------------
# Functions related to the constraints to be minimized.
# -------------------------------------------------------------------------

def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
def _constraint_loss_leaf(
leaf: param_base.ParameterizedDensity2DArray,
) -> jnp.ndarray:
constraints = density_parameterization.constraints(leaf)
constraints = tree_util.tree_map(
lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
constraints,
)
return jnp.sum(jnp.asarray(constraints))

losses = [0.0] + [
_constraint_loss_leaf(p)
for p in tree_util.tree_leaves(
latent_params, is_leaf=_is_parameterized_density
)
if _is_parameterized_density(p)
]
return penalty * jnp.sum(jnp.asarray(losses))
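
`_constraint_loss` is a standard quadratic penalty: constraint entries are feasible when non-positive, so only violations contribute, and the squared violation is scaled by `penalty`. A standalone numeric sketch of the same form (toy values, not from this commit):

    import jax.numpy as jnp

    def quadratic_penalty(constraints, penalty):
        terms = [jnp.sum(jnp.maximum(c, 0.0) ** 2) for c in constraints]
        return penalty * jnp.sum(jnp.asarray(terms))

    # Only the violated entry (2.0) contributes: 10.0 * 2.0**2 = 40.0
    quadratic_penalty([jnp.asarray(-0.5), jnp.asarray(2.0)], penalty=10.0)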

return base.Optimizer(
init=init_fn,
params=params_fn,
@@ -467,7 +493,7 @@ def _is_density(leaf: Any) -> Any:

def _is_parameterized_density(leaf: Any) -> Any:
"""Return `True` if `leaf` is a parameterized density array."""
return isinstance(leaf, parameterization_base.ParameterizedDensity2DArrayBase)
return isinstance(leaf, param_base.ParameterizedDensity2DArray)


def _is_custom_type(leaf: Any) -> bool:
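In both optimizers the incoming gradient is taken with respect to the user-visible params, and `jax.vjp` pulls it back to the latents: one reverse-mode application of the chain rule through the parameterization. A standalone sketch with `tanh` standing in for `to_density` (not from this commit):

    import jax
    import jax.numpy as jnp

    def to_density(latents):
        return jnp.tanh(latents)  # stand-in for the real parameterization

    latents = jnp.asarray([0.1, -0.3, 0.7])
    params, vjp_fn = jax.vjp(to_density, latents)
    grad_wrt_params = jnp.ones_like(params)        # e.g. d(loss)/d(params)
    (grad_wrt_latents,) = vjp_fn(grad_wrt_params)  # chain rule back to latents
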
64 changes: 32 additions & 32 deletions src/invrs_opt/optimizers/wrapped_optax.py
@@ -13,7 +13,7 @@

from invrs_opt.optimizers import base
from invrs_opt.parameterization import (
base as parameterization_base,
base as param_base,
filter_project,
gaussian_levelset,
pixel,
Expand Down Expand Up @@ -158,7 +158,7 @@ def levelset_wrapped_optax(

def parameterized_wrapped_optax(
opt: optax.GradientTransformation,
density_parameterization: Optional[parameterization_base.Density2DParameterization],
density_parameterization: Optional[param_base.Density2DParameterization],
penalty: float,
) -> base.Optimizer:
"""Wrapped optax optimizer with specified density parameterization.
@@ -181,11 +181,12 @@ def parameterized_wrapped_optax(
def init_fn(params: PyTree) -> WrappedOptaxState:
"""Initializes the optimization state."""
latent_params = _init_latents(params)
_, latents = param_base.partition_density_metadata(latent_params)
return (
0, # step
_params_from_latents(latent_params), # params
_params_from_latent_params(latent_params), # params
latent_params, # latent params
opt.init(tree_util.tree_leaves(latent_params)), # opt state
opt.init(latents), # opt state
)

def params_fn(state: WrappedOptaxState) -> PyTree:
@@ -204,42 +205,41 @@ def update_fn(
del value, params

step, params, latent_params, opt_state = state
metadata, latents = param_base.partition_density_metadata(latent_params)

_, vjp_fn = jax.vjp(_params_from_latents, latent_params)
(latent_grad,) = vjp_fn(grad)
def _params_from_latents(latents: PyTree) -> PyTree:
latent_params = param_base.combine_density_metadata(metadata, latents)
return _params_from_latent_params(latent_params)

def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
latent_params = param_base.combine_density_metadata(metadata, latents)
return _constraint_loss(latent_params)

_, vjp_fn = jax.vjp(_params_from_latents, latents)
(latents_grad,) = vjp_fn(grad)

if not (
tree_util.tree_structure(latent_grad)
== tree_util.tree_structure(latent_params) # type: ignore[operator]
tree_util.tree_structure(latents_grad)
== tree_util.tree_structure(latents) # type: ignore[operator]
):
raise ValueError(
f"Tree structure of `latent_grad` was different than expected, got \n"
f"{tree_util.tree_structure(latent_grad)} but expected \n"
f"{tree_util.tree_structure(latent_params)}."
f"Tree structure of `latents_grad` was different than expected, got \n"
f"{tree_util.tree_structure(latents_grad)} but expected \n"
f"{tree_util.tree_structure(latents)}."
)

constraint_loss_grad = jax.grad(_constraint_loss)(latent_params)
latent_grad = tree_util.tree_map(
lambda a, b: a + b, latent_grad, constraint_loss_grad
constraint_loss_grad = jax.grad(_constraint_loss_latents)(latents)
latents_grad = tree_util.tree_map(
lambda a, b: a + b, latents_grad, constraint_loss_grad
)

updates_leaves, opt_state = opt.update(
updates=tree_util.tree_leaves(latent_grad),
state=opt_state,
params=tree_util.tree_leaves(latent_params),
)
latent_params_leaves = optax.apply_updates(
params=tree_util.tree_leaves(latent_params),
updates=updates_leaves,
)
latent_params = tree_util.tree_unflatten(
treedef=tree_util.tree_structure(latent_params),
leaves=latent_params_leaves,
)
updates, opt_state = opt.update(latents_grad, state=opt_state, params=latents)
latents = optax.apply_updates(params=latents, updates=updates)

latent_params = param_base.combine_density_metadata(metadata, latents)
latent_params = _clip(latent_params)
latent_params = _update_parameterized_densities(latent_params, step + 1)
params = _params_from_latents(latent_params)
params = _params_from_latent_params(latent_params)
return (step + 1, params, latent_params, opt_state)

# -------------------------------------------------------------------------
Expand All @@ -255,7 +255,7 @@ def _leaf_init_latents(leaf: Any) -> Any:

return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)

def _params_from_latents(params: PyTree) -> PyTree:
def _params_from_latent_params(params: PyTree) -> PyTree:
def _leaf_params_from_latents(leaf: Any) -> Any:
if not _is_parameterized_density(leaf):
return leaf
Expand Down Expand Up @@ -285,9 +285,9 @@ def _update_leaf(leaf: Any) -> Any:

def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
def _constraint_loss_leaf(
params: parameterization_base.ParameterizedDensity2DArrayBase,
leaf: param_base.ParameterizedDensity2DArray,
) -> jnp.ndarray:
constraints = density_parameterization.constraints(params)
constraints = density_parameterization.constraints(leaf)
constraints = tree_util.tree_map(
lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
constraints,
@@ -313,7 +313,7 @@ def _is_density(leaf: Any) -> Any:

def _is_parameterized_density(leaf: Any) -> Any:
"""Return `True` if `leaf` is a parameterized density array."""
return isinstance(leaf, parameterization_base.ParameterizedDensity2DArrayBase)
return isinstance(leaf, param_base.ParameterizedDensity2DArray)


def _is_custom_type(leaf: Any) -> bool:
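With metadata partitioned out, the latents form an ordinary pytree of arrays, so the new code calls `opt.update` and `optax.apply_updates` on it directly rather than round-tripping through `tree_leaves`/`tree_unflatten` as before. A standalone sketch of that pattern (toy params and optimizer choice, not from this commit):

    import jax.numpy as jnp
    import optax

    params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
    grads = {"w": jnp.ones(3), "b": jnp.ones(())}

    opt = optax.adam(1e-2)
    opt_state = opt.init(params)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)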