From 8c44b0c5256301a989d6bdf5d32e3414cfb0b832 Mon Sep 17 00:00:00 2001
From: Martin Schubert
Date: Mon, 10 Jun 2024 13:50:37 -0700
Subject: [PATCH 1/4] Wrapped optax optimizer

---
 src/invrs_opt/__init__.py                    |   4 +
 src/invrs_opt/base.py                        |   8 +
 src/invrs_opt/experimental/client.py         |   3 +-
 src/invrs_opt/lbfgsb/lbfgsb.py               |   7 +-
 src/invrs_opt/{lbfgsb => }/transform.py      |   0
 src/invrs_opt/wrapped_optax/__init__.py      |   0
 src/invrs_opt/wrapped_optax/wrapped_optax.py | 150 ++++++++
 tests/test_algos.py                          |   3 +
 tests/{lbfgsb => }/test_transform.py         |   4 +-
 tests/wrapped_optax/test_wrapped_optax.py    | 385 +++++++++++++++++++
 10 files changed, 556 insertions(+), 8 deletions(-)
 rename src/invrs_opt/{lbfgsb => }/transform.py (100%)
 create mode 100644 src/invrs_opt/wrapped_optax/__init__.py
 create mode 100644 src/invrs_opt/wrapped_optax/wrapped_optax.py
 rename tests/{lbfgsb => }/test_transform.py (99%)
 create mode 100644 tests/wrapped_optax/test_wrapped_optax.py

diff --git a/src/invrs_opt/__init__.py b/src/invrs_opt/__init__.py
index 5f0cc44..f091cae 100644
--- a/src/invrs_opt/__init__.py
+++ b/src/invrs_opt/__init__.py
@@ -8,3 +8,7 @@
 
 from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
 from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
+from invrs_opt.wrapped_optax.wrapped_optax import (
+    density_wrapped_optax as density_wrapped_optax,
+)
+from invrs_opt.wrapped_optax.wrapped_optax import wrapped_optax as wrapped_optax
diff --git a/src/invrs_opt/base.py b/src/invrs_opt/base.py
index 1c12cb4..1e6c975 100644
--- a/src/invrs_opt/base.py
+++ b/src/invrs_opt/base.py
@@ -6,6 +6,9 @@
 import dataclasses
 from typing import Any, Protocol
 
+import optax
+from totypes import json_utils
+
 PyTree = Any
 
 
@@ -44,3 +47,8 @@ class Optimizer:
     init: InitFn
     params: ParamsFn
     update: UpdateFn
+
+
+# TODO: consider programmatically registering all optax states here.
+json_utils.register_custom_type(optax.EmptyState)
+json_utils.register_custom_type(optax.ScaleByAdamState)
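
Note: the two explicit registrations above cover only the optax state types
exercised by the new tests. If broader coverage is wanted, the registration
could plausibly be made programmatic, along the lines of the following
untested sketch; it assumes that every optax state container is a NamedTuple
subclass exposed at the top level of the optax namespace with a name ending
in "State":

    import inspect

    for _name, _obj in vars(optax).items():
        # NamedTuple state containers (e.g. EmptyState, ScaleByAdamState)
        # subclass tuple; register each so it can be serialized by json_utils.
        if inspect.isclass(_obj) and issubclass(_obj, tuple) and _name.endswith("State"):
            json_utils.register_custom_type(_obj)
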
diff --git a/src/invrs_opt/experimental/client.py b/src/invrs_opt/experimental/client.py
index a8b7066..18bf790 100644
--- a/src/invrs_opt/experimental/client.py
+++ b/src/invrs_opt/experimental/client.py
@@ -4,16 +4,15 @@
 """
 
 import json
-import requests
 import time
 from typing import Any, Dict, Optional
 
+import requests
 from totypes import json_utils
 
 from invrs_opt import base
 from invrs_opt.experimental import labels
 
-
 PyTree = Any
 
 StateToken = str
diff --git a/src/invrs_opt/lbfgsb/lbfgsb.py b/src/invrs_opt/lbfgsb/lbfgsb.py
index 871ccd1..5ec2f1b 100644
--- a/src/invrs_opt/lbfgsb/lbfgsb.py
+++ b/src/invrs_opt/lbfgsb/lbfgsb.py
@@ -11,13 +11,12 @@
 import jax.numpy as jnp
 import numpy as onp
 from jax import flatten_util, tree_util
-from scipy.optimize._lbfgsb_py import (  # type: ignore[import-untyped]
-    _lbfgsb as scipy_lbfgsb,
+from scipy.optimize._lbfgsb_py import (
+    _lbfgsb as scipy_lbfgsb,  # type: ignore[import-untyped]
 )
 from totypes import types
 
-from invrs_opt import base
-from invrs_opt.lbfgsb import transform
+from invrs_opt import base, transform
 
 NDArray = onp.ndarray[Any, Any]
 PyTree = Any
diff --git a/src/invrs_opt/lbfgsb/transform.py b/src/invrs_opt/transform.py
similarity index 100%
rename from src/invrs_opt/lbfgsb/transform.py
rename to src/invrs_opt/transform.py
diff --git a/src/invrs_opt/wrapped_optax/__init__.py b/src/invrs_opt/wrapped_optax/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/invrs_opt/wrapped_optax/wrapped_optax.py b/src/invrs_opt/wrapped_optax/wrapped_optax.py
new file mode 100644
index 0000000..63ee6a9
--- /dev/null
+++ b/src/invrs_opt/wrapped_optax/wrapped_optax.py
@@ -0,0 +1,150 @@
+import dataclasses
+from typing import Any, Callable, Tuple
+
+import jax
+import jax.numpy as jnp
+import optax
+from jax import tree_util
+from totypes import types
+
+from invrs_opt import base, transform
+
+PyTree = Any
+WrappedOptaxState = Tuple[PyTree, PyTree, PyTree]
+
+
+def wrapped_optax(opt: optax.GradientTransformation) -> base.Optimizer:
+    """Return a wrapped optax optimizer."""
+    return transformed_wrapped_optax(
+        opt=opt,
+        transform_fn=lambda x: x,
+        initialize_latent_fn=lambda x: x,
+    )
+
+
+def density_wrapped_optax(
+    opt: optax.GradientTransformation,
+    beta: float,
+) -> base.Optimizer:
+    """Return a wrapped optax optimizer with transforms for density arrays."""
+
+    def transform_fn(tree: PyTree) -> PyTree:
+        return tree_util.tree_map(
+            lambda x: transform_density(x) if _is_density(x) else x,
+            tree,
+            is_leaf=_is_density,
+        )
+
+    def initialize_latent_fn(tree: PyTree) -> PyTree:
+        return tree_util.tree_map(
+            lambda x: initialize_latent_density(x) if _is_density(x) else x,
+            tree,
+            is_leaf=_is_density,
+        )
+
+    def transform_density(density: types.Density2DArray) -> types.Density2DArray:
+        transformed = types.symmetrize_density(density)
+        transformed = transform.density_gaussian_filter_and_tanh(transformed, beta=beta)
+        # Scale to ensure that the full valid range of the density array is reachable.
+        mid_value = (density.lower_bound + density.upper_bound) / 2
+        transformed = tree_util.tree_map(
+            lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed
+        )
+        return transform.apply_fixed_pixels(transformed)
+
+    def initialize_latent_density(
+        density: types.Density2DArray,
+    ) -> types.Density2DArray:
+        array = transform.normalized_array_from_density(density)
+        array = jnp.clip(array, -1, 1)
+        array *= jnp.tanh(beta)
+        latent_array = jnp.arctanh(array) / beta
+        latent_array = transform.rescale_array_for_density(latent_array, density)
+        return dataclasses.replace(density, array=latent_array)
+
+    return transformed_wrapped_optax(
+        opt=opt,
+        transform_fn=transform_fn,
+        initialize_latent_fn=initialize_latent_fn,
+    )
+
+
+def transformed_wrapped_optax(
+    opt: optax.GradientTransformation,
+    transform_fn: Callable[[PyTree], PyTree],
+    initialize_latent_fn: Callable[[PyTree], PyTree],
+) -> base.Optimizer:
+    """Return a wrapped optax optimizer for transformed latent parameters.
+
+    Args:
+        opt: The optax `GradientTransformation` to be wrapped.
+        transform_fn: Function which transforms the internal latent parameters to
+            the parameters returned by the optimizer.
+        initialize_latent_fn: Function which computes the initial latent parameters
+            given the initial parameters.
+
+    Returns:
+        The `base.Optimizer`.
+    """
+
+    def init_fn(params: PyTree) -> WrappedOptaxState:
+        """Initializes the optimization state."""
+        latent_params = initialize_latent_fn(_clip(params))
+        params = transform_fn(latent_params)
+        return params, latent_params, opt.init(latent_params)
+
+    def params_fn(state: WrappedOptaxState) -> PyTree:
+        """Returns the parameters for the given `state`."""
+        params, _, _ = state
+        return params
+
+    def update_fn(
+        *,
+        grad: PyTree,
+        value: float,
+        params: PyTree,
+        state: WrappedOptaxState,
+    ) -> WrappedOptaxState:
+        """Updates the state."""
+        del value
+
+        _, latent_params, opt_state = state
+        _, vjp_fn = jax.vjp(transform_fn, latent_params)
+        (latent_grad,) = vjp_fn(grad)
+
+        updates, opt_state = opt.update(latent_grad, opt_state)
+        latent_params = optax.apply_updates(params=latent_params, updates=updates)
+        latent_params = _clip(latent_params)
+        params = transform_fn(latent_params)
+        return params, latent_params, opt_state
+
+    return base.Optimizer(
+        init=init_fn,
+        params=params_fn,
+        update=update_fn,
+    )
+
+
+def _is_density(leaf: Any) -> Any:
+    """Return `True` if `leaf` is a density array."""
+    return isinstance(leaf, types.Density2DArray)
+
+
+def _is_custom_type(leaf: Any) -> bool:
+    """Return `True` if `leaf` is a recognized custom type."""
+    return isinstance(leaf, (types.BoundedArray, types.Density2DArray))
+
+
+def _clip(pytree: PyTree) -> PyTree:
+    """Clips leaves on `pytree` to their bounds."""
+
+    def _clip_fn(leaf: Any) -> Any:
+        if not _is_custom_type(leaf):
+            return leaf
+        if leaf.lower_bound is None and leaf.upper_bound is None:
+            return leaf
+        return tree_util.tree_map(
+            lambda x: jnp.clip(x, leaf.lower_bound, leaf.upper_bound), leaf
+        )
+
+    return tree_util.tree_map(_clip_fn, pytree, is_leaf=_is_custom_type)
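
The wrapped optimizer follows the same init/params/update protocol as the
existing lbfgsb optimizers. A minimal usage sketch (the quadratic loss,
learning rate, and step count are illustrative only):

    import jax
    import jax.numpy as jnp
    import optax

    import invrs_opt

    def loss_fn(params):
        return jnp.sum(params["x"] ** 2)

    opt = invrs_opt.wrapped_optax(optax.adam(1e-2))
    state = opt.init({"x": jnp.ones((3,))})
    for _ in range(10):
        params = opt.params(state)
        value, grad = jax.value_and_grad(loss_fn)(params)
        state = opt.update(grad=grad, value=value, params=params, state=state)
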
diff --git a/tests/test_algos.py b/tests/test_algos.py
index 5fb8c8c..b0c6ff4 100644
--- a/tests/test_algos.py
+++ b/tests/test_algos.py
@@ -9,6 +9,7 @@
 import jax
 import jax.numpy as jnp
 import numpy as onp
+import optax
 import parameterized
 from totypes import json_utils, symmetry, types
 
@@ -21,6 +22,8 @@
 OPTIMIZERS = [
     invrs_opt.lbfgsb(maxcor=20, line_search_max_steps=100),
     invrs_opt.density_lbfgsb(maxcor=20, line_search_max_steps=100, beta=2.0),
+    invrs_opt.wrapped_optax(optax.adam(1e-2)),
+    invrs_opt.density_wrapped_optax(optax.adam(1e-2), beta=2.0),
 ]
 
 # Various parameter combinations tested in this module.
diff --git a/tests/lbfgsb/test_transform.py b/tests/test_transform.py
similarity index 99%
rename from tests/lbfgsb/test_transform.py
rename to tests/test_transform.py
index 6ea002c..10b5308 100644
--- a/tests/lbfgsb/test_transform.py
+++ b/tests/test_transform.py
@@ -1,4 +1,4 @@
-"""Tests for `lbfgsb.transforms`.
+"""Tests for `transforms`.
 
 Copyright (c) 2023 The INVRS-IO authors.
 """
@@ -12,7 +12,7 @@
 from parameterized import parameterized
 from totypes import types
 
-from invrs_opt.lbfgsb import transform
+from invrs_opt import transform
 
 
 class GaussianFilterTest(unittest.TestCase):
diff --git a/tests/wrapped_optax/test_wrapped_optax.py b/tests/wrapped_optax/test_wrapped_optax.py
new file mode 100644
index 0000000..829a3d8
--- /dev/null
+++ b/tests/wrapped_optax/test_wrapped_optax.py
@@ -0,0 +1,385 @@
+"""Defines tests for the `wrapped_optax.wrapped_optax` module.
+
+Copyright (c) 2023 The INVRS-IO authors.
+"""
+
+import dataclasses
+import unittest
+
+import jax
+import jax.numpy as jnp
+import numpy as onp
+import optax
+from jax import flatten_util, tree_util
+from parameterized import parameterized
+from totypes import types
+
+from invrs_opt import transform
+from invrs_opt.wrapped_optax import wrapped_optax
+
+
+class DensityWrappedOptaxBoundsTest(unittest.TestCase):
+    @parameterized.expand([[-1, 1, 1], [-1, 1, -1], [0, 1, 1], [0, 1, -1]])
+    def test_respects_bounds(self, lower_bound, upper_bound, sign):
+        def loss_fn(density):
+            return sign * jnp.sum(density.array)
+
+        params = types.Density2DArray(
+            array=jnp.ones((5, 5)) * (lower_bound + upper_bound) / 2,
+            lower_bound=lower_bound,
+            upper_bound=upper_bound,
+        )
+        opt = wrapped_optax.density_wrapped_optax(opt=optax.adam(0.1), beta=2)
+        state = opt.init(params)
+        for _ in range(20):
+            params = opt.params(state)
+            value, grad = jax.value_and_grad(loss_fn)(params)
+            state = opt.update(grad=grad, value=value, params=params, state=state)
+
+        params = opt.params(state)
+        expected = upper_bound if sign < 0 else lower_bound
+        onp.testing.assert_allclose(params.array, expected, atol=1e-5)
+
+
+class DensityWrappedOptaxInitializeTest(unittest.TestCase):
+    @parameterized.expand(
+        [
+            [-1, 1, -0.95],
+            [-1, 1, -0.50],
+            [-1, 1, 0.00],
+            [-1, 1, 0.50],
+            [-1, 1, 0.95],
+            [0, 1, 0.05],
+            [0, 1, 0.25],
+            [0, 1, 0.00],
+            [0, 1, 0.25],
+            [0, 1, 0.95],
+        ]
+    )
+    def test_initial_params_match_expected(self, lb, ub, value):
+        density = types.Density2DArray(
+            array=jnp.full((10, 10), value),
+            lower_bound=lb,
+            upper_bound=ub,
+        )
+        opt = wrapped_optax.density_wrapped_optax(optax.adam(0.01), beta=4)
+        state = opt.init(density)
+        params = opt.params(state)
+        onp.testing.assert_allclose(density.array, params.array, atol=1e-2)
+
+    def test_initial_params_out_of_bounds(self):
+        density = types.Density2DArray(
+            array=jnp.full((10, 10), 10),
+            lower_bound=-1,
+            upper_bound=1,
+        )
+        opt = wrapped_optax.density_wrapped_optax(optax.adam(0.01), beta=4)
+        state = opt.init(density)
+        params = opt.params(state)
+        onp.testing.assert_allclose(params.array, onp.ones_like(params.array))
+
+
+class WrappedOptaxBoundsTest(unittest.TestCase):
+    @parameterized.expand([[-1, 1, 1], [-1, 1, -1], [0, 1, 1], [0, 1, -1]])
+    def test_respects_bounds(self, lower_bound, upper_bound, sign):
+        def loss_fn(density):
+            return sign * jnp.sum(density.array)
+
+        params = types.Density2DArray(
+            array=jnp.ones((5, 5)) * (lower_bound + upper_bound) / 2,
+            lower_bound=lower_bound,
+            upper_bound=upper_bound,
+        )
+        opt = wrapped_optax.wrapped_optax(optax.adam(0.1))
+        state = opt.init(params)
+        for _ in range(10):
+            params = opt.params(state)
+            value, grad = jax.value_and_grad(loss_fn)(params)
+            state = opt.update(grad=grad, value=value, params=params, state=state)
+
+        params = opt.params(state)
+        expected = upper_bound if sign < 0 else lower_bound
+        onp.testing.assert_allclose(params.array, expected, atol=1e-5)
+
+
+class WrappedOptaxTest(unittest.TestCase):
+    def test_trajectory_matches_optax_bounded_array(self):
+        initial_params = {
+            "a": jnp.asarray([1.0, 2.0]),
+            "b": types.BoundedArray(
+                jnp.asarray([3.0, 4.0]), lower_bound=2.0, upper_bound=None
+            ),
+            "c": types.BoundedArray(
+                jnp.asarray([3.0, 4.0]), lower_bound=2.0, upper_bound=5.0
+            ),
+            "d": types.BoundedArray(
+                jnp.asarray([3.0, 4.0]), lower_bound=None, upper_bound=5.0
+            ),
+        }
+
+        def loss_fn(params):
+            x, _ = flatten_util.ravel_pytree(params)
+            return jnp.sum(x**2)
+
+        # Carry out optimization directly
+        opt = optax.adam(1e-2)
+        params = wrapped_optax._clip(initial_params)
+        state = opt.init(params)
+
+        expected_values = []
+        for _ in range(10):
+            value, grad = jax.value_and_grad(loss_fn)(params)
+            updates, state = opt.update(grad, state, params=params)
+            params = wrapped_optax._clip(optax.apply_updates(params, updates=updates))
+            expected_values.append(value)
+
+        # Carry out optimization using the wrapped optimizer.
+        wrapped_opt = wrapped_optax.wrapped_optax(opt)
+        state = wrapped_opt.init(initial_params)
+
+        values = []
+        for _ in range(10):
+            params = wrapped_opt.params(state)
+            value, grad = jax.value_and_grad(loss_fn)(params)
+            state = wrapped_opt.update(
+                grad=grad, value=value, state=state, params=params
+            )
+            values.append(value)
+
+        onp.testing.assert_array_equal(values, expected_values)
+
+    def test_trajectory_matches_scipy_density_2d(self):
+        initial_params = {
+            "a": jnp.asarray([1.0, 2.0]),
+            "b": types.BoundedArray(
+                jnp.asarray([3.0, 4.0]), lower_bound=2.0, upper_bound=None
+            ),
+            "c": types.BoundedArray(
+                jnp.asarray([3.0, 4.0]), lower_bound=2.0, upper_bound=5.0
+            ),
+            "d": types.BoundedArray(
+                jnp.asarray([3.0, 4.0]), lower_bound=None, upper_bound=5.0
+            ),
+            "density": types.Density2DArray(jnp.arange(20, dtype=float).reshape(4, 5)),
+        }
+        beta = 2.0
+
+        def loss_fn(params):
+            x, _ = flatten_util.ravel_pytree(params)
+            return jnp.sum(x**2)
+
+        def latent_loss_fn(params):
+            params["density"] = transform_density(params["density"])
+            return loss_fn(params)
+
+        def transform_density(density):
+            transformed = types.symmetrize_density(density)
+            transformed = transform.density_gaussian_filter_and_tanh(
+                transformed, beta=beta
+            )
+            # Scale to ensure that full valid range of the density array is reachable.
+            mid_value = (density.lower_bound + density.upper_bound) / 2
+            transformed = tree_util.tree_map(
+                lambda array: mid_value + (array - mid_value) / jnp.tanh(beta),
+                transformed,
+            )
+            return transform.apply_fixed_pixels(transformed)
+
+        def initialize_latent_density(density) -> types.Density2DArray:
+            array = transform.normalized_array_from_density(density)
+            array = jnp.clip(array, -1, 1)
+            array *= jnp.tanh(beta)
+            latent_array = jnp.arctanh(array) / beta
+            latent_array = transform.rescale_array_for_density(latent_array, density)
+            return dataclasses.replace(density, array=latent_array)
+
+        # Carry out optimization directly
+        opt = optax.adam(1e-2)
+        params = wrapped_optax._clip(initial_params)
+        params["density"] = initialize_latent_density(params["density"])
+        state = opt.init(params)
+
+        expected_values = []
+        for _ in range(10):
+            value, grad = jax.value_and_grad(latent_loss_fn)(params)
+            updates, state = opt.update(grad, state, params=params)
+            params = wrapped_optax._clip(optax.apply_updates(params, updates=updates))
+            expected_values.append(value)
+
+        # Carry out optimization using the wrapped optimizer.
+        wrapped_opt = wrapped_optax.density_wrapped_optax(opt, beta=beta)
+        state = wrapped_opt.init(initial_params)
+
+        values = []
+        for _ in range(10):
+            params = wrapped_opt.params(state)
+            value, grad = jax.value_and_grad(loss_fn)(params)
+            state = wrapped_opt.update(
+                grad=grad, value=value, state=state, params=params
+            )
+            values.append(value)
+
+        onp.testing.assert_array_equal(values, expected_values)
+
+    @parameterized.expand(
+        [
+            [2.0],
+            [jnp.ones((3,))],
+            [
+                types.BoundedArray(
+                    array=jnp.ones((3,)),
+                    lower_bound=0.0,
+                    upper_bound=1.0,
+                )
+            ],
+            [
+                types.BoundedArray(
+                    array=jnp.ones((3,)),
+                    lower_bound=None,
+                    upper_bound=1.0,
+                )
+            ],
+            [
+                types.BoundedArray(
+                    array=jnp.ones((3,)),
+                    lower_bound=None,
+                    upper_bound=None,
+                )
+            ],
+            [
+                types.BoundedArray(
+                    array=jnp.ones((3,)),
+                    lower_bound=None,
+                    upper_bound=None,
+                )
+            ],
+            [
+                types.BoundedArray(
+                    array=jnp.ones((3,)),
+                    lower_bound=jnp.zeros((3,)),
+                    upper_bound=jnp.ones((3,)),
+                )
+            ],
+            [
+                types.Density2DArray(
+                    array=jnp.ones((3, 3)),
+                    lower_bound=0.0,
+                    upper_bound=1.0,
+                    fixed_solid=None,
+                    fixed_void=None,
+                    minimum_width=1,
+                    minimum_spacing=1,
+                )
+            ],
+            [
+                {
+                    "a": types.Density2DArray(
+                        array=jnp.ones((3, 3)),
+                        lower_bound=0.0,
+                        upper_bound=1.0,
+                        fixed_solid=None,
+                        fixed_void=None,
+                        minimum_width=1,
+                        minimum_spacing=1,
+                    ),
+                    "b": types.BoundedArray(
+                        array=jnp.ones((3,)),
+                        lower_bound=jnp.zeros((3,)),
+                        upper_bound=jnp.ones((3,)),
+                    ),
+                }
+            ],
+            [
+                {
+                    "a": types.Density2DArray(
+                        array=jnp.ones((3, 3)),
+                        lower_bound=0.0,
+                        upper_bound=1.0,
+                        fixed_solid=None,
+                        fixed_void=None,
+                        minimum_width=1,
+                        minimum_spacing=1,
+                    ),
+                    "b": None,
+                }
+            ],
+        ]
+    )
+    def test_initialize(self, params):
+        opt = wrapped_optax.density_wrapped_optax(optax.adam(0.1), beta=2.0)
+        state = opt.init(params)
+        params = opt.params(state)
+        dummy_grad = jax.tree_util.tree_map(jnp.zeros_like, params)
+        state = opt.update(value=0.0, params=params, grad=dummy_grad, state=state)
+
+    def test_density_wrapped_optax_reaches_bounds(self):
+        def loss_fn(density):
+            return jnp.sum(jnp.abs(density.array - 1) ** 2)
+
+        opt = wrapped_optax.density_wrapped_optax(optax.adam(0.1), beta=2.0)
+
+        density = types.Density2DArray(
+            array=jnp.zeros((3, 3)),
+            lower_bound=0.0,
+            upper_bound=1.0,
+            fixed_solid=None,
+            fixed_void=None,
+            minimum_width=1,
+            minimum_spacing=1,
+        )
+        state = opt.init(density)
+        for _ in range(20):
+            density = opt.params(state)
+            value, grad = jax.value_and_grad(loss_fn)(density)
+            state = opt.update(value=value, grad=grad, params=density, state=state)
+
+        onp.testing.assert_allclose(density.array, 1.0)
+
+    def test_optimization_with_vmap(self):
+        def initial_params_fn(key):
+            ka, kb = jax.random.split(key)
+            return {
+                "a": jax.random.normal(ka, (10,)),
+                "b": jax.random.normal(kb, (10,)),
+                "c": types.Density2DArray(array=jnp.ones((3, 3))),
+            }
+
+        def loss_fn(params):
+            flat, _ = flatten_util.ravel_pytree(params)
+            return jnp.sum(jnp.abs(flat**2))
+
+        keys = jax.random.split(jax.random.PRNGKey(0))
+        opt = wrapped_optax.density_wrapped_optax(optax.adam(0.1), beta=2.0)
+
+        # Test batch optimization
+        params = jax.vmap(initial_params_fn)(keys)
+        state = jax.vmap(opt.init)(params)
+
+        @jax.jit
+        @jax.vmap
+        def step_fn(state):
+            params = opt.params(state)
+            value, grad = jax.value_and_grad(loss_fn)(params)
+            state = opt.update(grad=grad, value=value, params=params, state=state)
+            return state, value
+
+        batch_values = []
+        for i in range(10):
+            state, value = step_fn(state)
+            batch_values.append(value)
+
+        # Test one-at-a-time optimization.
+        no_batch_values = []
+        for k in keys:
+            no_batch_values.append([])
+            params = initial_params_fn(k)
+            state = opt.init(params)
+            for _ in range(10):
+                params = opt.params(state)
+                value, grad = jax.jit(jax.value_and_grad(loss_fn))(params)
+                state = opt.update(grad=grad, value=value, params=params, state=state)
+                no_batch_values[-1].append(value)
+
+        onp.testing.assert_allclose(
+            batch_values, onp.transpose(no_batch_values, (1, 0)), atol=1e-4
+        )
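
For density parameters, `density_wrapped_optax` optimizes a latent density:
each step symmetrizes the latent, applies the Gaussian-filter-and-tanh
projection with the given `beta`, and rescales the result about the bounds
midpoint by `1 / tanh(beta)` so that the lower and upper bounds remain
reachable even though `tanh` itself never saturates. A short sketch of the
intended use (array shape, learning rate, and step count are illustrative):

    import jax
    import jax.numpy as jnp
    import optax
    from totypes import types

    import invrs_opt

    def loss_fn(density):
        # Illustrative objective which favors the upper bound everywhere.
        return jnp.sum((density.array - 1.0) ** 2)

    density = types.Density2DArray(
        array=jnp.full((16, 16), 0.5), lower_bound=0.0, upper_bound=1.0
    )
    opt = invrs_opt.density_wrapped_optax(optax.adam(1e-1), beta=2.0)
    state = opt.init(density)
    for _ in range(50):
        density = opt.params(state)
        value, grad = jax.value_and_grad(loss_fn)(density)
        state = opt.update(grad=grad, value=value, params=density, state=state)
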
From e81da937adcb4a074ef756b1b65612caa3260685 Mon Sep 17 00:00:00 2001
From: Martin Schubert
Date: Mon, 10 Jun 2024 13:53:00 -0700
Subject: [PATCH 2/4] Add optax dep

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index bd55b78..45df90f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,6 +20,7 @@ dependencies = [
     "jaxlib",
     "numpy",
     "requests",
+    "optax",
    "scipy",
     "totypes",
     "types-requests",
From cfee6329c079246e9cde857e9520aaa41fd09fab Mon Sep 17 00:00:00 2001
From: Martin Schubert
Date: Mon, 10 Jun 2024 13:56:12 -0700
Subject: [PATCH 3/4] mypy

---
 src/invrs_opt/base.py                        | 2 +-
 src/invrs_opt/lbfgsb/lbfgsb.py               | 4 ++--
 src/invrs_opt/wrapped_optax/wrapped_optax.py | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/invrs_opt/base.py b/src/invrs_opt/base.py
index 1e6c975..edc408d 100644
--- a/src/invrs_opt/base.py
+++ b/src/invrs_opt/base.py
@@ -6,7 +6,7 @@
 import dataclasses
 from typing import Any, Protocol
 
-import optax
+import optax  # type: ignore[import-untyped]
 from totypes import json_utils
 
 PyTree = Any
diff --git a/src/invrs_opt/lbfgsb/lbfgsb.py b/src/invrs_opt/lbfgsb/lbfgsb.py
index 5ec2f1b..352944e 100644
--- a/src/invrs_opt/lbfgsb/lbfgsb.py
+++ b/src/invrs_opt/lbfgsb/lbfgsb.py
@@ -11,8 +11,8 @@
 import jax.numpy as jnp
 import numpy as onp
 from jax import flatten_util, tree_util
-from scipy.optimize._lbfgsb_py import (
-    _lbfgsb as scipy_lbfgsb,  # type: ignore[import-untyped]
+from scipy.optimize._lbfgsb_py import (  # type: ignore[import-untyped]
+    _lbfgsb as scipy_lbfgsb,
 )
 from totypes import types
 
diff --git a/src/invrs_opt/wrapped_optax/wrapped_optax.py b/src/invrs_opt/wrapped_optax/wrapped_optax.py
index 63ee6a9..440a942 100644
--- a/src/invrs_opt/wrapped_optax/wrapped_optax.py
+++ b/src/invrs_opt/wrapped_optax/wrapped_optax.py
@@ -3,7 +3,7 @@
 
 import jax
 import jax.numpy as jnp
-import optax
+import optax  # type: ignore[import-untyped]
 from jax import tree_util
 from totypes import types
 
From 3b94e937c5cd5e3fcb59de9eee111b2647b52560 Mon Sep 17 00:00:00 2001
From: Martin Schubert
Date: Mon, 10 Jun 2024 13:58:33 -0700
Subject: [PATCH 4/4] Version updated from v0.5.2 to v0.6.0

---
 .bumpversion.toml         | 2 +-
 README.md                 | 2 +-
 pyproject.toml            | 2 +-
 src/invrs_opt/__init__.py | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/.bumpversion.toml b/.bumpversion.toml
index 4ded5eb..b54cf8f 100644
--- a/.bumpversion.toml
+++ b/.bumpversion.toml
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "v0.5.2"
+current_version = "v0.6.0"
 commit = true
 commit_args = "--no-verify"
 tag = true
diff --git a/README.md b/README.md
index 5c0ad80..9b96159 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
 # invrs-opt - Optimization algorithms for inverse design
-`v0.5.2`
+`v0.6.0`
 
 ## Overview
 
diff --git a/pyproject.toml b/pyproject.toml
index 45df90f..d196ed1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 
 name = "invrs_opt"
-version = "v0.5.2"
+version = "v0.6.0"
 description = "Algorithms for inverse design"
 keywords = ["topology", "optimization", "jax", "inverse design"]
 readme = "README.md"
diff --git a/src/invrs_opt/__init__.py b/src/invrs_opt/__init__.py
index f091cae..2d835dd 100644
--- a/src/invrs_opt/__init__.py
+++ b/src/invrs_opt/__init__.py
@@ -3,7 +3,7 @@
 Copyright (c) 2023 The INVRS-IO authors.
 """
 
-__version__ = "v0.5.2"
+__version__ = "v0.6.0"
 __author__ = "Martin F. Schubert "
 
 from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb