diff --git a/src/invrs_opt/optimizers/lbfgsb.py b/src/invrs_opt/optimizers/lbfgsb.py index 156f9f5..e52e071 100644 --- a/src/invrs_opt/optimizers/lbfgsb.py +++ b/src/invrs_opt/optimizers/lbfgsb.py @@ -19,7 +19,7 @@ from invrs_opt.optimizers import base from invrs_opt.parameterization import ( - base as parameterization_base, + base as param_base, filter_project, gaussian_levelset, pixel, @@ -252,7 +252,7 @@ def levelset_lbfgsb( def parameterized_lbfgsb( - density_parameterization: Optional[parameterization_base.Density2DParameterization], + density_parameterization: Optional[param_base.Density2DParameterization], penalty: float, maxcor: int = DEFAULT_MAXCOR, line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS, @@ -296,59 +296,6 @@ def parameterized_lbfgsb( if density_parameterization is None: density_parameterization = pixel.pixel() - def _init_latents(params: PyTree) -> PyTree: - def _leaf_init_latents(leaf: Any) -> Any: - leaf = _clip(leaf) - if not _is_density(leaf) or density_parameterization is None: - return leaf - return density_parameterization.from_density(leaf) - - return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type) - - def _params_from_latents(latent_params: PyTree) -> PyTree: - def _leaf_params_from_latents(leaf: Any) -> Any: - if not _is_parameterized_density(leaf) or density_parameterization is None: - return leaf - return density_parameterization.to_density(leaf) - - return tree_util.tree_map( - _leaf_params_from_latents, - latent_params, - is_leaf=_is_parameterized_density, - ) - - def _constraint_loss(latent_params: PyTree) -> jnp.ndarray: - def _constraint_loss_leaf( - params: parameterization_base.ParameterizedDensity2DArrayBase, - ) -> jnp.ndarray: - constraints = density_parameterization.constraints(params) - constraints = tree_util.tree_map( - lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2), - constraints, - ) - return jnp.sum(jnp.asarray(constraints)) - - losses = [0.0] + [ - _constraint_loss_leaf(p) - for p in tree_util.tree_leaves( - latent_params, is_leaf=_is_parameterized_density - ) - if _is_parameterized_density(p) - ] - return penalty * jnp.sum(jnp.asarray(losses)) - - def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree: - def _update_leaf(leaf: Any) -> Any: - if not _is_parameterized_density(leaf): - return leaf - return density_parameterization.update(leaf, step) - - return tree_util.tree_map( - _update_leaf, - latent_params, - is_leaf=_is_parameterized_density, - ) - def init_fn(params: PyTree) -> LbfgsbState: """Initializes the optimization state.""" @@ -368,10 +315,17 @@ def _init_state_pure(latent_params: PyTree) -> Tuple[PyTree, JaxLbfgsbDict]: return latent_params, scipy_lbfgsb_state.to_jax() latent_params = _init_latents(params) - latent_params, jax_lbfgsb_state = jax.pure_callback( - _init_state_pure, _example_state(latent_params, maxcor), latent_params + metadata, latents = param_base.partition_density_metadata(latent_params) + latents, jax_lbfgsb_state = jax.pure_callback( + _init_state_pure, _example_state(latents, maxcor), latents + ) + latent_params = param_base.combine_density_metadata(metadata, latents) + return ( + 0, # step + _params_from_latent_params(latent_params), # params + latent_params, # latent params + jax_lbfgsb_state, # opt state ) - return 0, _params_from_latents(latent_params), latent_params, jax_lbfgsb_state def params_fn(state: LbfgsbState) -> PyTree: """Returns the parameters for the given `state`.""" @@ -403,46 +357,118 @@ def _update_pure( return flat_latent_params, scipy_lbfgsb_state.to_jax() step, _, latent_params, jax_lbfgsb_state = state - _, vjp_fn = jax.vjp(_params_from_latents, latent_params) - (latent_grad,) = vjp_fn(grad) + metadata, latents = param_base.partition_density_metadata(latent_params) + + def _params_from_latents(latents: PyTree) -> PyTree: + latent_params = param_base.combine_density_metadata(metadata, latents) + return _params_from_latent_params(latent_params) + + def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray: + latent_params = param_base.combine_density_metadata(metadata, latents) + return _constraint_loss(latent_params) + + _, vjp_fn = jax.vjp(_params_from_latents, latents) + (latents_grad,) = vjp_fn(grad) if not ( - tree_util.tree_structure(latent_grad) - == tree_util.tree_structure(latent_params) # type: ignore[operator] + tree_util.tree_structure(latents_grad) + == tree_util.tree_structure(latents) # type: ignore[operator] ): raise ValueError( - f"Tree structure of `latent_grad` was different than expected, got \n" - f"{tree_util.tree_structure(latent_grad)} but expected \n" - f"{tree_util.tree_structure(latent_params)}." + f"Tree structure of `latents_grad` was different than expected, got \n" + f"{tree_util.tree_structure(latents_grad)} but expected \n" + f"{tree_util.tree_structure(latents)}." ) ( constraint_loss_value, constraint_loss_grad, ) = jax.value_and_grad( - _constraint_loss - )(latent_params) + _constraint_loss_latents + )(latents) value += constraint_loss_value - latent_grad = tree_util.tree_map( - lambda a, b: a + b, latent_grad, constraint_loss_grad + latents_grad = tree_util.tree_map( + lambda a, b: a + b, latents_grad, constraint_loss_grad ) - flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree( - latent_grad + flat_latents_grad, unflatten_fn = flatten_util.ravel_pytree( + latents_grad ) # type: ignore[no-untyped-call] - flat_latent_params, jax_lbfgsb_state = jax.pure_callback( + flat_latents, jax_lbfgsb_state = jax.pure_callback( _update_pure, - (flat_latent_grad, jax_lbfgsb_state), - flat_latent_grad, + (flat_latents_grad, jax_lbfgsb_state), + flat_latents_grad, value, jax_lbfgsb_state, ) - latent_params = unflatten_fn(flat_latent_params) + latents = unflatten_fn(flat_latents) + latent_params = param_base.combine_density_metadata(metadata, latents) latent_params = _update_parameterized_densities(latent_params, step) - params = _params_from_latents(latent_params) + params = _params_from_latent_params(latent_params) return step + 1, params, latent_params, jax_lbfgsb_state + # ------------------------------------------------------------------------- + # Functions related to the density parameterization. + # ------------------------------------------------------------------------- + + def _init_latents(params: PyTree) -> PyTree: + def _leaf_init_latents(leaf: Any) -> Any: + leaf = _clip(leaf) + if not _is_density(leaf) or density_parameterization is None: + return leaf + return density_parameterization.from_density(leaf) + + return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type) + + def _params_from_latent_params(latent_params: PyTree) -> PyTree: + def _leaf_params_from_latents(leaf: Any) -> Any: + if not _is_parameterized_density(leaf) or density_parameterization is None: + return leaf + return density_parameterization.to_density(leaf) + + return tree_util.tree_map( + _leaf_params_from_latents, + latent_params, + is_leaf=_is_parameterized_density, + ) + + def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree: + def _update_leaf(leaf: Any) -> Any: + if not _is_parameterized_density(leaf): + return leaf + return density_parameterization.update(leaf, step) + + return tree_util.tree_map( + _update_leaf, + latent_params, + is_leaf=_is_parameterized_density, + ) + + # ------------------------------------------------------------------------- + # Functions related to the constraints to be minimized. + # ------------------------------------------------------------------------- + + def _constraint_loss(latent_params: PyTree) -> jnp.ndarray: + def _constraint_loss_leaf( + leaf: param_base.ParameterizedDensity2DArray, + ) -> jnp.ndarray: + constraints = density_parameterization.constraints(leaf) + constraints = tree_util.tree_map( + lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2), + constraints, + ) + return jnp.sum(jnp.asarray(constraints)) + + losses = [0.0] + [ + _constraint_loss_leaf(p) + for p in tree_util.tree_leaves( + latent_params, is_leaf=_is_parameterized_density + ) + if _is_parameterized_density(p) + ] + return penalty * jnp.sum(jnp.asarray(losses)) + return base.Optimizer( init=init_fn, params=params_fn, @@ -467,7 +493,7 @@ def _is_density(leaf: Any) -> Any: def _is_parameterized_density(leaf: Any) -> Any: """Return `True` if `leaf` is a parameterized density array.""" - return isinstance(leaf, parameterization_base.ParameterizedDensity2DArrayBase) + return isinstance(leaf, param_base.ParameterizedDensity2DArray) def _is_custom_type(leaf: Any) -> bool: diff --git a/src/invrs_opt/optimizers/wrapped_optax.py b/src/invrs_opt/optimizers/wrapped_optax.py index 0c8fc9e..48606e4 100644 --- a/src/invrs_opt/optimizers/wrapped_optax.py +++ b/src/invrs_opt/optimizers/wrapped_optax.py @@ -13,7 +13,7 @@ from invrs_opt.optimizers import base from invrs_opt.parameterization import ( - base as parameterization_base, + base as param_base, filter_project, gaussian_levelset, pixel, @@ -158,7 +158,7 @@ def levelset_wrapped_optax( def parameterized_wrapped_optax( opt: optax.GradientTransformation, - density_parameterization: Optional[parameterization_base.Density2DParameterization], + density_parameterization: Optional[param_base.Density2DParameterization], penalty: float, ) -> base.Optimizer: """Wrapped optax optimizer with specified density parameterization. @@ -181,11 +181,12 @@ def parameterized_wrapped_optax( def init_fn(params: PyTree) -> WrappedOptaxState: """Initializes the optimization state.""" latent_params = _init_latents(params) + _, latents = param_base.partition_density_metadata(latent_params) return ( 0, # step - _params_from_latents(latent_params), # params + _params_from_latent_params(latent_params), # params latent_params, # latent params - opt.init(tree_util.tree_leaves(latent_params)), # opt state + opt.init(latents), # opt state ) def params_fn(state: WrappedOptaxState) -> PyTree: @@ -204,42 +205,41 @@ def update_fn( del value, params step, params, latent_params, opt_state = state + metadata, latents = param_base.partition_density_metadata(latent_params) - _, vjp_fn = jax.vjp(_params_from_latents, latent_params) - (latent_grad,) = vjp_fn(grad) + def _params_from_latents(latents: PyTree) -> PyTree: + latent_params = param_base.combine_density_metadata(metadata, latents) + return _params_from_latent_params(latent_params) + + def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray: + latent_params = param_base.combine_density_metadata(metadata, latents) + return _constraint_loss(latent_params) + + _, vjp_fn = jax.vjp(_params_from_latents, latents) + (latents_grad,) = vjp_fn(grad) if not ( - tree_util.tree_structure(latent_grad) - == tree_util.tree_structure(latent_params) # type: ignore[operator] + tree_util.tree_structure(latents_grad) + == tree_util.tree_structure(latents) # type: ignore[operator] ): raise ValueError( - f"Tree structure of `latent_grad` was different than expected, got \n" - f"{tree_util.tree_structure(latent_grad)} but expected \n" - f"{tree_util.tree_structure(latent_params)}." + f"Tree structure of `latents_grad` was different than expected, got \n" + f"{tree_util.tree_structure(latents_grad)} but expected \n" + f"{tree_util.tree_structure(latents)}." ) - constraint_loss_grad = jax.grad(_constraint_loss)(latent_params) - latent_grad = tree_util.tree_map( - lambda a, b: a + b, latent_grad, constraint_loss_grad + constraint_loss_grad = jax.grad(_constraint_loss_latents)(latents) + latents_grad = tree_util.tree_map( + lambda a, b: a + b, latents_grad, constraint_loss_grad ) - updates_leaves, opt_state = opt.update( - updates=tree_util.tree_leaves(latent_grad), - state=opt_state, - params=tree_util.tree_leaves(latent_params), - ) - latent_params_leaves = optax.apply_updates( - params=tree_util.tree_leaves(latent_params), - updates=updates_leaves, - ) - latent_params = tree_util.tree_unflatten( - treedef=tree_util.tree_structure(latent_params), - leaves=latent_params_leaves, - ) + updates, opt_state = opt.update(latents_grad, state=opt_state, params=latents) + latents = optax.apply_updates(params=latents, updates=updates) + latent_params = param_base.combine_density_metadata(metadata, latents) latent_params = _clip(latent_params) latent_params = _update_parameterized_densities(latent_params, step + 1) - params = _params_from_latents(latent_params) + params = _params_from_latent_params(latent_params) return (step + 1, params, latent_params, opt_state) # ------------------------------------------------------------------------- @@ -255,7 +255,7 @@ def _leaf_init_latents(leaf: Any) -> Any: return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type) - def _params_from_latents(params: PyTree) -> PyTree: + def _params_from_latent_params(params: PyTree) -> PyTree: def _leaf_params_from_latents(leaf: Any) -> Any: if not _is_parameterized_density(leaf): return leaf @@ -285,9 +285,9 @@ def _update_leaf(leaf: Any) -> Any: def _constraint_loss(latent_params: PyTree) -> jnp.ndarray: def _constraint_loss_leaf( - params: parameterization_base.ParameterizedDensity2DArrayBase, + leaf: param_base.ParameterizedDensity2DArray, ) -> jnp.ndarray: - constraints = density_parameterization.constraints(params) + constraints = density_parameterization.constraints(leaf) constraints = tree_util.tree_map( lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2), constraints, @@ -313,7 +313,7 @@ def _is_density(leaf: Any) -> Any: def _is_parameterized_density(leaf: Any) -> Any: """Return `True` if `leaf` is a parameterized density array.""" - return isinstance(leaf, parameterization_base.ParameterizedDensity2DArrayBase) + return isinstance(leaf, param_base.ParameterizedDensity2DArray) def _is_custom_type(leaf: Any) -> bool: diff --git a/src/invrs_opt/parameterization/base.py b/src/invrs_opt/parameterization/base.py index 6d89571..28c2993 100644 --- a/src/invrs_opt/parameterization/base.py +++ b/src/invrs_opt/parameterization/base.py @@ -9,30 +9,30 @@ import jax.numpy as jnp import numpy as onp from jax import tree_util -from totypes import json_utils, types +from totypes import json_utils, partition_utils, types Array = jnp.ndarray | onp.ndarray[Any, Any] PyTree = Any -class Density2DLatentsBase: - """Base class for latents of a parameterized density array.""" +@dataclasses.dataclass +class ParameterizedDensity2DArray: + """Stores latents and metadata for a parameterized density array.""" - pass + latents: "LatentsBase" + metadata: Optional["MetadataBase"] -class Density2DMetadataBase: - """Base class for metadata of a parameterized density array.""" +class LatentsBase: + """Base class for latents of a parameterized density array.""" pass -@dataclasses.dataclass -class ParameterizedDensity2DArray: - """Stores latents and metadata for a parameterized density array.""" +class MetadataBase: + """Base class for metadata of a parameterized density array.""" - latents: Density2DLatentsBase - metadata: Optional[Density2DMetadataBase] + pass tree_util.register_dataclass( @@ -40,14 +40,27 @@ class ParameterizedDensity2DArray: data_fields=["latents", "metadata"], meta_fields=[], ) - json_utils.register_custom_type(ParameterizedDensity2DArray) -class ParameterizedDensity2DArrayBase: - """Base class for parameterized density arrays.""" +def partition_density_metadata(tree: PyTree) -> Tuple[PyTree, PyTree]: + """Splits a pytree with parameterized densities into metadata from latents.""" + metadata, latents = partition_utils.partition( + tree, + select_fn=lambda x: isinstance(x, MetadataBase), + is_leaf=_is_metadata_or_none, + ) + return metadata, latents - pass + +def combine_density_metadata(metadata: PyTree, latents: PyTree) -> PyTree: + """Combines pytrees containing metadata and latents.""" + return partition_utils.combine(metadata, latents, is_leaf=_is_metadata_or_none) + + +def _is_metadata_or_none(leaf: Any) -> bool: + """Return `True` if `leaf` is `None` or density metadata.""" + return leaf is None or isinstance(leaf, MetadataBase) @dataclasses.dataclass @@ -63,9 +76,7 @@ class Density2DParameterization: class FromDensityFn(Protocol): """Generate the latent representation of a density array.""" - def __call__( - self, density: types.Density2DArray - ) -> ParameterizedDensity2DArrayBase: + def __call__(self, density: types.Density2DArray) -> ParameterizedDensity2DArray: ... @@ -107,6 +118,12 @@ def __post_init__(self) -> None: self.periodic = tuple(self.periodic) self.symmetries = tuple(self.symmetries) + @classmethod + def from_density(self, density: types.Density2DArray) -> "Density2DMetadata": + density_metadata_dict = dataclasses.asdict(density) + del density_metadata_dict["array"] + return Density2DMetadata(**density_metadata_dict) + def _flatten_density_2d_metadata( metadata: Density2DMetadata, diff --git a/src/invrs_opt/parameterization/filter_project.py b/src/invrs_opt/parameterization/filter_project.py index 91ceb8a..58d7a9d 100644 --- a/src/invrs_opt/parameterization/filter_project.py +++ b/src/invrs_opt/parameterization/filter_project.py @@ -13,25 +13,53 @@ @dataclasses.dataclass -class FilterAndProjectParams(base.ParameterizedDensity2DArrayBase): - """Stores the latent parameters of the pixel parameterization. +class FilterProjectParams(base.ParameterizedDensity2DArray): + """Stores parameters for the filter-project parameterization.""" - Attributes: + latents: "FilterProjectLatents" + metadata: "FilterProjectMetadata" + + +@dataclasses.dataclass +class FilterProjectLatents(base.LatentsBase): + """Stores latent parameters for the filter-project parameterization. + + Attributes:s latent_density: The latent variable from which the density is obtained. - beta: Determines the sharpness of the thresholding operation. """ latent_density: types.Density2DArray + + +@dataclasses.dataclass +class FilterProjectMetadata(base.MetadataBase): + """Stores metadata for the filter-project parameterization. + + Attributes: + beta: Determines the sharpness of the thresholding operation. + """ + beta: float tree_util.register_dataclass( - FilterAndProjectParams, + FilterProjectParams, + data_fields=["latents", "metadata"], + meta_fields=[], +) +tree_util.register_dataclass( + FilterProjectLatents, data_fields=["latent_density"], + meta_fields=[], +) +tree_util.register_dataclass( + FilterProjectMetadata, + data_fields=[], meta_fields=["beta"], ) - -json_utils.register_custom_type(FilterAndProjectParams) +json_utils.register_custom_type(FilterProjectParams) +json_utils.register_custom_type(FilterProjectLatents) +json_utils.register_custom_type(FilterProjectMetadata) def filter_project(beta: float) -> base.Density2DParameterization: @@ -55,24 +83,26 @@ def filter_project(beta: float) -> base.Density2DParameterization: The `Density2DParameterization`. """ - def from_density_fn(density: types.Density2DArray) -> FilterAndProjectParams: + def from_density_fn(density: types.Density2DArray) -> FilterProjectParams: """Return latent parameters for the given `density`.""" array = transforms.normalized_array_from_density(density) array = jnp.clip(array, -1, 1) array *= jnp.tanh(beta) latent_array = jnp.arctanh(array) / beta latent_array = transforms.rescale_array_for_density(latent_array, density) - return FilterAndProjectParams( - latent_density=dataclasses.replace(density, array=latent_array), - beta=beta, + latent_density = density = dataclasses.replace(density, array=latent_array) + return FilterProjectParams( + latents=FilterProjectLatents(latent_density=latent_density), + metadata=FilterProjectMetadata(beta=beta), ) - def to_density_fn(params: FilterAndProjectParams) -> types.Density2DArray: + def to_density_fn(params: FilterProjectParams) -> types.Density2DArray: """Return a density from the latent parameters.""" - transformed = types.symmetrize_density(params.latent_density) - transformed = transforms.density_gaussian_filter_and_tanh( - transformed, beta=params.beta - ) + latent_density = params.latents.latent_density + beta = params.metadata.beta + + transformed = types.symmetrize_density(latent_density) + transformed = transforms.density_gaussian_filter_and_tanh(transformed, beta) # Scale to ensure that the full valid range of the density array is reachable. mid_value = (transformed.lower_bound + transformed.upper_bound) / 2 transformed = tree_util.tree_map( @@ -80,12 +110,12 @@ def to_density_fn(params: FilterAndProjectParams) -> types.Density2DArray: ) return transforms.apply_fixed_pixels(transformed) - def constraints_fn(params: FilterAndProjectParams) -> jnp.ndarray: + def constraints_fn(params: FilterProjectParams) -> jnp.ndarray: """Computes constraints associated with the params.""" del params return jnp.asarray(0.0) - def update_fn(params: FilterAndProjectParams, step: int) -> FilterAndProjectParams: + def update_fn(params: FilterProjectParams, step: int) -> FilterProjectParams: """Perform updates to `params` required for the given `step`.""" del step return params diff --git a/src/invrs_opt/parameterization/gaussian_levelset.py b/src/invrs_opt/parameterization/gaussian_levelset.py index 406b761..ac05979 100644 --- a/src/invrs_opt/parameterization/gaussian_levelset.py +++ b/src/invrs_opt/parameterization/gaussian_levelset.py @@ -29,12 +29,30 @@ @dataclasses.dataclass -class GaussianLevelsetParams(base.ParameterizedDensity2DArrayBase): - """Parameters of a density represented by a Gaussian levelset. +class GaussianLevelsetParams(base.ParameterizedDensity2DArray): + """Stores parameters for the Gaussian levelset parameterization.""" + + latents: "GaussianLevelsetLatents" + metadata: "GaussianLevelsetMetadata" + + +@dataclasses.dataclass +class GaussianLevelsetLatents(base.LatentsBase): + """Stores latent parameters for the Gaussian levelset parameterization. Attributes: amplitude: Array giving the amplitude of the Gaussian basis function at levelset control points. + """ + + amplitude: jnp.ndarray + + +@dataclasses.dataclass +class GaussianLevelsetMetadata(base.MetadataBase): + """Stores metadata for the Gaussian levelset parameterization. + + Attributes: length_scale_spacing_factor: The number of levelset control points per unit of minimum length scale (mean of density minimum width and minimum spacing). length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to @@ -45,7 +63,6 @@ class GaussianLevelsetParams(base.ParameterizedDensity2DArrayBase): density_metadata: Metadata for the density array obtained from the parameters. """ - amplitude: jnp.ndarray length_scale_spacing_factor: float length_scale_fwhm_factor: float smoothing_factor: int @@ -55,68 +72,31 @@ class GaussianLevelsetParams(base.ParameterizedDensity2DArrayBase): def __post_init__(self) -> None: self.density_shape = tuple(self.density_shape) - def example_density(self) -> types.Density2DArray: - """Returns an example density with appropriate shape and metadata.""" - with jax.ensure_compile_time_eval(): - return types.Density2DArray( - array=jnp.zeros(self.density_shape), - **dataclasses.asdict(self.density_metadata), - ) - - -_GaussianLevelsetParamsAux = Tuple[ - float, float, int, Tuple[int, ...], tree_util.PyTreeDef -] - -def _flatten_gaussian_levelset_params( - params: GaussianLevelsetParams, -) -> Tuple[Tuple[jnp.ndarray], _GaussianLevelsetParamsAux]: - _, flat_metadata = tree_util.tree_flatten(params.density_metadata) - return ( - (params.amplitude,), - ( - params.length_scale_spacing_factor, - params.length_scale_fwhm_factor, - params.smoothing_factor, - params.density_shape, - flat_metadata, - ), - ) - - -def _unflatten_gaussian_levelset_params( - aux: _GaussianLevelsetParamsAux, - children: Tuple[jnp.ndarray], -) -> GaussianLevelsetParams: - (amplitude,) = children - ( - length_scale_spacing_factor, - length_scale_fwhm_factor, - smoothing_factor, - density_shape, - flat_metadata, - ) = aux - - density_metadata = tree_util.tree_unflatten(flat_metadata, ()) - return GaussianLevelsetParams( - amplitude=amplitude, - length_scale_spacing_factor=length_scale_spacing_factor, - length_scale_fwhm_factor=length_scale_fwhm_factor, - smoothing_factor=smoothing_factor, - density_shape=tuple(density_shape), - density_metadata=density_metadata, - ) - - -tree_util.register_pytree_node( +tree_util.register_dataclass( GaussianLevelsetParams, - flatten_func=_flatten_gaussian_levelset_params, - unflatten_func=_unflatten_gaussian_levelset_params, + data_fields=["latents", "metadata"], + meta_fields=[], +) +tree_util.register_dataclass( + GaussianLevelsetLatents, + data_fields=["amplitude"], + meta_fields=[], +) +tree_util.register_dataclass( + GaussianLevelsetMetadata, + data_fields=[ + "length_scale_spacing_factor", + "length_scale_fwhm_factor", + "smoothing_factor", + "density_shape", + "density_metadata", + ], + meta_fields=[], ) - - json_utils.register_custom_type(GaussianLevelsetParams) +json_utils.register_custom_type(GaussianLevelsetLatents) +json_utils.register_custom_type(GaussianLevelsetMetadata) def gaussian_levelset( @@ -187,23 +167,21 @@ def from_density_fn(density: types.Density2DArray) -> GaussianLevelsetParams: pad_width += ((0, 0),) if density.periodic[1] else ((1, 1),) amplitude = jnp.pad(amplitude, pad_width, mode="edge") - density_metadata_dict = dataclasses.asdict(density) - del density_metadata_dict["array"] - density_metadata = base.Density2DMetadata(**density_metadata_dict) - params = GaussianLevelsetParams( - amplitude=amplitude, + latents = GaussianLevelsetLatents(amplitude=amplitude) + metadata = GaussianLevelsetMetadata( length_scale_spacing_factor=length_scale_spacing_factor, length_scale_fwhm_factor=length_scale_fwhm_factor, smoothing_factor=smoothing_factor, density_shape=density.shape, - density_metadata=density_metadata, + density_metadata=base.Density2DMetadata.from_density(density), ) def step_fn( _: int, params_and_state: Tuple[PyTree, PyTree], ) -> Tuple[PyTree, PyTree]: - def loss_fn(params: GaussianLevelsetParams) -> jnp.ndarray: + def loss_fn(latents: GaussianLevelsetLatents) -> jnp.ndarray: + params = GaussianLevelsetParams(latents, metadata=metadata) density_from_params = to_density_fn(params, mask_gradient=False) return jnp.mean((density_from_params.array - target_array) ** 2) @@ -213,13 +191,14 @@ def loss_fn(params: GaussianLevelsetParams) -> jnp.ndarray: params = optax.apply_updates(params, updates) return params, state - state = init_optimizer.init(params) - params, _ = jax.lax.fori_loop( - 0, init_steps, body_fun=step_fn, init_val=(params, state) + state = init_optimizer.init(latents) + latents, _ = jax.lax.fori_loop( + 0, init_steps, body_fun=step_fn, init_val=(latents, state) ) - maxval = jnp.amax(jnp.abs(params.amplitude), axis=(-2, -1), keepdims=True) - return dataclasses.replace(params, amplitude=params.amplitude / maxval) + maxval = jnp.amax(jnp.abs(latents.amplitude), axis=(-2, -1), keepdims=True) + latents = dataclasses.replace(latents, amplitude=latents.amplitude / maxval) + return GaussianLevelsetParams(latents=latents, metadata=metadata) def to_density_fn( params: GaussianLevelsetParams, @@ -228,7 +207,7 @@ def to_density_fn( """Return a density from the latent parameters.""" array = _to_array(params, mask_gradient=mask_gradient, pad_pixels=0) - example_density = params.example_density() + example_density = _example_density(params) lb = example_density.lower_bound ub = example_density.upper_bound array = lb + array * (ub - lb) @@ -263,7 +242,7 @@ def constraints_fn( ) # Normalize constraints to make them (somewhat) resolution-independent. - example_density = params.example_density() + example_density = _example_density(params) length_scale = 0.5 * ( example_density.minimum_spacing + example_density.minimum_width ) @@ -287,6 +266,15 @@ def update_fn(params: GaussianLevelsetParams, step: int) -> GaussianLevelsetPara # ----------------------------------------------------------------------------- +def _example_density(params: GaussianLevelsetParams) -> types.Density2DArray: + """Returns an example density with appropriate shape and metadata.""" + with jax.ensure_compile_time_eval(): + return types.Density2DArray( + array=jnp.zeros(params.metadata.density_shape), + **dataclasses.asdict(params.metadata.density_metadata), + ) + + def _to_array( params: GaussianLevelsetParams, mask_gradient: bool, @@ -308,7 +296,7 @@ def _to_array( Returns: The array. """ - example_density = params.example_density() + example_density = _example_density(params) periodic: Tuple[bool, bool] = example_density.periodic phi = _phi_from_params( params=params, @@ -319,7 +307,7 @@ def _to_array( periodic=periodic, mask_gradient=mask_gradient, ) - return _downsample_spatial_dims(array, params.smoothing_factor) + return _downsample_spatial_dims(array, params.metadata.smoothing_factor) def _phi_from_params( @@ -337,32 +325,35 @@ def _phi_from_params( The levelset array `phi`. """ with jax.ensure_compile_time_eval(): - example_density = params.example_density() + example_density = _example_density(params) length_scale = 0.5 * ( example_density.minimum_width + example_density.minimum_spacing ) - fwhm = length_scale * params.length_scale_fwhm_factor + fwhm = length_scale * params.metadata.length_scale_fwhm_factor sigma = fwhm / (2 * jnp.sqrt(2 * jnp.log(2))) + s_factor = params.metadata.smoothing_factor highres_i = ( 0.5 + jnp.arange( - params.smoothing_factor * (-pad_pixels), - params.smoothing_factor * (pad_pixels + example_density.shape[-2]), + s_factor * (-pad_pixels), + s_factor * (pad_pixels + example_density.shape[-2]), ) - ) / params.smoothing_factor + ) / s_factor highres_j = ( 0.5 + jnp.arange( - params.smoothing_factor * (-pad_pixels), - params.smoothing_factor * (pad_pixels + example_density.shape[-1]), + s_factor * (-pad_pixels), + s_factor * (pad_pixels + example_density.shape[-1]), ) - ) / params.smoothing_factor + ) / s_factor # Coordinates for the control points of the Gaussian radial basis functions. levelset_i, levelset_j = _control_point_coords( - density_shape=params.density_shape[-2:], # type: ignore[arg-type] - levelset_shape=params.amplitude.shape[-2:], # type: ignore[arg-type] + density_shape=params.metadata.density_shape[-2:], # type: ignore[arg-type] + levelset_shape=( + params.latents.amplitude.shape[-2:] # type: ignore[arg-type] + ), periodic=example_density.periodic, ) @@ -391,7 +382,7 @@ def _phi_from_params( levelset_i = levelset_i.flatten() levelset_j = levelset_j.flatten() - amplitude = params.amplitude + amplitude = params.latents.amplitude if example_density.periodic[0]: amplitude = jnp.concat([amplitude] * 3, axis=-2) if example_density.periodic[1]: @@ -410,8 +401,8 @@ def scan_fn(_: Tuple[()], i: jnp.ndarray) -> Tuple[Tuple[()], jnp.ndarray]: _, array = jax.lax.scan(scan_fn, (), xs=highres_i) array = jnp.moveaxis(array, 0, -2) - assert array.shape[-2] % params.smoothing_factor == 0 - assert array.shape[-1] % params.smoothing_factor == 0 + assert array.shape[-2] % s_factor == 0 + assert array.shape[-1] % s_factor == 0 array = symmetry.symmetrize(array, tuple(example_density.symmetries)) return array @@ -443,7 +434,7 @@ def _fixed_pixel_constraint( """ array = _to_array(params, mask_gradient=mask_gradient, pad_pixels=pad_pixels) - example_density = params.example_density() + example_density = _example_density(params) fixed_solid = jnp.zeros(example_density.shape[-2:], dtype=bool) fixed_void = jnp.zeros(example_density.shape[-2:], dtype=bool) if example_density.fixed_solid is not None: @@ -491,9 +482,9 @@ def _levelset_constraints( beyond the boundaries of the parameterized density. Returns: - The minimum length scale and minimum curvature constraint arrays. + The minimum length scale and minimum curvature constraint arrays.s """ - example_density = params.example_density() + example_density = _example_density(params) minimum_length_scale = 0.5 * ( example_density.minimum_width + example_density.minimum_spacing ) @@ -512,9 +503,10 @@ def _levelset_constraints( ) # Downsample so that constraints shape matches the density shape. + factor = params.metadata.smoothing_factor return ( - _downsample_spatial_dims(length_scale_constraint, params.smoothing_factor), - _downsample_spatial_dims(curvature_constraint, params.smoothing_factor), + _downsample_spatial_dims(length_scale_constraint, factor), + _downsample_spatial_dims(curvature_constraint, factor), ) @@ -529,7 +521,7 @@ def _phi_derivatives_and_inverse_radius( pad_pixels=pad_pixels, ) - d = 1 / params.smoothing_factor + d = 1 / params.metadata.smoothing_factor phi_x, phi_y = jnp.gradient(phi, d, axis=(-2, -1)) phi_xx, phi_yx = jnp.gradient(phi_x, d, axis=(-2, -1)) phi_xy, phi_yy = jnp.gradient(phi_y, d, axis=(-2, -1)) diff --git a/src/invrs_opt/parameterization/pixel.py b/src/invrs_opt/parameterization/pixel.py index 2db2890..21faaf9 100644 --- a/src/invrs_opt/parameterization/pixel.py +++ b/src/invrs_opt/parameterization/pixel.py @@ -13,26 +13,40 @@ @dataclasses.dataclass -class PixelParams(base.ParameterizedDensity2DArrayBase): - """Stores latent parameters of the direct pixel parameterization.""" +class PixelParams(base.ParameterizedDensity2DArray): + latents: "PixelLatents" + metadata: None = None - density: types.Density2DArray +@dataclasses.dataclass +class PixelLatents(base.LatentsBase): + """Stores latent parameters for the direct pixel parameterization.""" -tree_util.register_dataclass(PixelParams, data_fields=["density"], meta_fields=[]) + density: types.Density2DArray +tree_util.register_dataclass( + PixelParams, + data_fields=["latents"], + meta_fields=[], +) +tree_util.register_dataclass( + PixelLatents, + data_fields=["density"], + meta_fields=[], +) json_utils.register_custom_type(PixelParams) +json_utils.register_custom_type(PixelLatents) def pixel() -> base.Density2DParameterization: """Return the direct pixel parameterization.""" def from_density_fn(density: types.Density2DArray) -> PixelParams: - return PixelParams(density=density) + return PixelParams(latents=PixelLatents(density=density)) def to_density_fn(params: PixelParams) -> types.Density2DArray: - return params.density + return params.latents.density def constraints_fn(params: PixelParams) -> jnp.ndarray: del params diff --git a/tests/optimizers/test_lbfgsb.py b/tests/optimizers/test_lbfgsb.py index d5bd4cd..614cf12 100644 --- a/tests/optimizers/test_lbfgsb.py +++ b/tests/optimizers/test_lbfgsb.py @@ -3,7 +3,6 @@ Copyright (c) 2023 The INVRS-IO authors. """ -import dataclasses import unittest import jax @@ -691,7 +690,13 @@ def test_variable_parameterization(self): # Create a custom parameterization whose update method increments `beta` by 1 # at each step. p = filter_project.filter_project(beta=1) - p.update = lambda x, step: dataclasses.replace(x, beta=x.beta + 1) + + def update_fn(params, step): + del step + params.metadata.beta += 1 + return params + + p.update = update_fn opt = lbfgsb.parameterized_lbfgsb( density_parameterization=p, @@ -723,4 +728,4 @@ def loss_fn(density): state = step_fn(state) # Check that beta has actually been incremented. - self.assertEqual(state[2].beta, 11) + self.assertEqual(state[2].metadata.beta, 11) diff --git a/tests/optimizers/test_wrapped_optax.py b/tests/optimizers/test_wrapped_optax.py index 99cadbf..8dcc75c 100644 --- a/tests/optimizers/test_wrapped_optax.py +++ b/tests/optimizers/test_wrapped_optax.py @@ -390,7 +390,13 @@ def test_variable_parameterization(self): # Create a custom parameterization whose update method increments `beta` by 1 # at each step. p = filter_project.filter_project(beta=1) - p.update = lambda x, step: dataclasses.replace(x, beta=x.beta + 1) + + def update_fn(params, step): + del step + params.metadata.beta += 1 + return params + + p.update = update_fn opt = wrapped_optax.parameterized_wrapped_optax( opt=optax.adam(0.01), @@ -423,4 +429,4 @@ def loss_fn(density): state = step_fn(state) # Check that beta has actually been incremented. - self.assertEqual(state[2].beta, 11) + self.assertEqual(state[2].metadata.beta, 11)