Skip to content

Commit

Permalink
Merge pull request #23 from invrs-io/optax
Browse files Browse the repository at this point in the history
Wrapped optax optimizer
  • Loading branch information
mfschubert authored Jun 10, 2024
2 parents b1301c4 + 3b94e93 commit 00b4f87
Show file tree
Hide file tree
Showing 13 changed files with 559 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "v0.5.2"
current_version = "v0.6.0"
commit = true
commit_args = "--no-verify"
tag = true
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# invrs-opt - Optimization algorithms for inverse design
`v0.5.2`
`v0.6.0`

## Overview

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]

name = "invrs_opt"
version = "v0.5.2"
version = "v0.6.0"
description = "Algorithms for inverse design"
keywords = ["topology", "optimization", "jax", "inverse design"]
readme = "README.md"
Expand All @@ -20,6 +20,7 @@ dependencies = [
"jaxlib",
"numpy",
"requests",
"optax",
"scipy",
"totypes",
"types-requests",
Expand Down
6 changes: 5 additions & 1 deletion src/invrs_opt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
Copyright (c) 2023 The INVRS-IO authors.
"""

__version__ = "v0.5.2"
__version__ = "v0.6.0"
__author__ = "Martin F. Schubert <[email protected]>"

from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
from invrs_opt.wrapped_optax.wrapped_optax import (
density_wrapped_optax as density_wrapped_optax,
)
from invrs_opt.wrapped_optax.wrapped_optax import wrapped_optax as wrapped_optax
8 changes: 8 additions & 0 deletions src/invrs_opt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import dataclasses
from typing import Any, Protocol

import optax # type: ignore[import-untyped]
from totypes import json_utils

PyTree = Any


Expand Down Expand Up @@ -44,3 +47,8 @@ class Optimizer:
init: InitFn
params: ParamsFn
update: UpdateFn


# TODO: consider programatically registering all optax states here.
json_utils.register_custom_type(optax.EmptyState)
json_utils.register_custom_type(optax.ScaleByAdamState)
3 changes: 1 addition & 2 deletions src/invrs_opt/experimental/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@
"""

import json
import requests
import time
from typing import Any, Dict, Optional

import requests
from totypes import json_utils

from invrs_opt import base
from invrs_opt.experimental import labels


PyTree = Any
StateToken = str

Expand Down
3 changes: 1 addition & 2 deletions src/invrs_opt/lbfgsb/lbfgsb.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
)
from totypes import types

from invrs_opt import base
from invrs_opt.lbfgsb import transform
from invrs_opt import base, transform

NDArray = onp.ndarray[Any, Any]
PyTree = Any
Expand Down
File renamed without changes.
Empty file.
150 changes: 150 additions & 0 deletions src/invrs_opt/wrapped_optax/wrapped_optax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import dataclasses
from typing import Any, Callable, Tuple

import jax
import jax.numpy as jnp
import optax # type: ignore[import-untyped]
from jax import tree_util
from totypes import types

from invrs_opt import base, transform

PyTree = Any
WrappedOptaxState = Tuple[PyTree, PyTree, PyTree]


def wrapped_optax(opt: optax.GradientTransformation) -> base.Optimizer:
"""Return a wrapped optax optimizer."""
return transformed_wrapped_optax(
opt=opt,
transform_fn=lambda x: x,
initialize_latent_fn=lambda x: x,
)


def density_wrapped_optax(
opt: optax.GradientTransformation,
beta: float,
) -> base.Optimizer:
"""Return a wrapped optax optimizer with transforms for density arrays."""

def transform_fn(tree: PyTree) -> PyTree:
return tree_util.tree_map(
lambda x: transform_density(x) if _is_density(x) else x,
tree,
is_leaf=_is_density,
)

def initialize_latent_fn(tree: PyTree) -> PyTree:
return tree_util.tree_map(
lambda x: initialize_latent_density(x) if _is_density(x) else x,
tree,
is_leaf=_is_density,
)

def transform_density(density: types.Density2DArray) -> types.Density2DArray:
transformed = types.symmetrize_density(density)
transformed = transform.density_gaussian_filter_and_tanh(transformed, beta=beta)
# Scale to ensure that the full valid range of the density array is reachable.
mid_value = (density.lower_bound + density.upper_bound) / 2
transformed = tree_util.tree_map(
lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed
)
return transform.apply_fixed_pixels(transformed)

def initialize_latent_density(
density: types.Density2DArray,
) -> types.Density2DArray:
array = transform.normalized_array_from_density(density)
array = jnp.clip(array, -1, 1)
array *= jnp.tanh(beta)
latent_array = jnp.arctanh(array) / beta
latent_array = transform.rescale_array_for_density(latent_array, density)
return dataclasses.replace(density, array=latent_array)

return transformed_wrapped_optax(
opt=opt,
transform_fn=transform_fn,
initialize_latent_fn=initialize_latent_fn,
)


def transformed_wrapped_optax(
opt: optax.GradientTransformation,
transform_fn: Callable[[PyTree], PyTree],
initialize_latent_fn: Callable[[PyTree], PyTree],
) -> base.Optimizer:
"""Return a wrapped optax optimizer for transformed latent parameters.
Args:
opt: The optax `GradientTransformation` to be wrapped.
transform_fn: Function which transforms the internal latent parameters to
the parameters returned by the optimizer.
initialize_latent_fn: Function which computes the initial latent parameters
given the initial parameters.
Returns:
The `base.Optimizer`.
"""

def init_fn(params: PyTree) -> WrappedOptaxState:
"""Initializes the optimization state."""
latent_params = initialize_latent_fn(_clip(params))
params = transform_fn(latent_params)
return params, latent_params, opt.init(latent_params)

def params_fn(state: WrappedOptaxState) -> PyTree:
"""Returns the parameters for the given `state`."""
params, _, _ = state
return params

def update_fn(
*,
grad: PyTree,
value: float,
params: PyTree,
state: WrappedOptaxState,
) -> WrappedOptaxState:
"""Updates the state."""
del value

_, latent_params, opt_state = state
_, vjp_fn = jax.vjp(transform_fn, latent_params)
(latent_grad,) = vjp_fn(grad)

updates, opt_state = opt.update(latent_grad, opt_state)
latent_params = optax.apply_updates(params=latent_params, updates=updates)
latent_params = _clip(latent_params)
params = transform_fn(latent_params)
return params, latent_params, opt_state

return base.Optimizer(
init=init_fn,
params=params_fn,
update=update_fn,
)


def _is_density(leaf: Any) -> Any:
"""Return `True` if `leaf` is a density array."""
return isinstance(leaf, types.Density2DArray)


def _is_custom_type(leaf: Any) -> bool:
"""Return `True` if `leaf` is a recognized custom type."""
return isinstance(leaf, (types.BoundedArray, types.Density2DArray))


def _clip(pytree: PyTree) -> PyTree:
"""Clips leaves on `pytree` to their bounds."""

def _clip_fn(leaf: Any) -> Any:
if not _is_custom_type(leaf):
return leaf
if leaf.lower_bound is None and leaf.upper_bound is None:
return leaf
return tree_util.tree_map(
lambda x: jnp.clip(x, leaf.lower_bound, leaf.upper_bound), leaf
)

return tree_util.tree_map(_clip_fn, pytree, is_leaf=_is_custom_type)
3 changes: 3 additions & 0 deletions tests/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import jax
import jax.numpy as jnp
import numpy as onp
import optax
import parameterized
from totypes import json_utils, symmetry, types

Expand All @@ -21,6 +22,8 @@
OPTIMIZERS = [
invrs_opt.lbfgsb(maxcor=20, line_search_max_steps=100),
invrs_opt.density_lbfgsb(maxcor=20, line_search_max_steps=100, beta=2.0),
invrs_opt.wrapped_optax(optax.adam(1e-2)),
invrs_opt.density_wrapped_optax(optax.adam(1e-2), beta=2.0),
]

# Various parameter combinations tested in this module.
Expand Down
4 changes: 2 additions & 2 deletions tests/lbfgsb/test_transform.py → tests/test_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Tests for `lbfgsb.transforms`.
"""Tests for `transforms`.
Copyright (c) 2023 The INVRS-IO authors.
"""
Expand All @@ -12,7 +12,7 @@
from parameterized import parameterized
from totypes import types

from invrs_opt.lbfgsb import transform
from invrs_opt import transform


class GaussianFilterTest(unittest.TestCase):
Expand Down
Loading

0 comments on commit 00b4f87

Please sign in to comment.