diff --git a/src/invrs_gym/challenges/sorter/common.py b/src/invrs_gym/challenges/sorter/common.py index 256ee57..10d028c 100644 --- a/src/invrs_gym/challenges/sorter/common.py +++ b/src/invrs_gym/challenges/sorter/common.py @@ -230,7 +230,9 @@ def response( spec = dataclasses.replace( self.spec, thickness_cap=params[THICKNESS_CAP], # type: ignore[arg-type] - thickness_metasurface=params[THICKNESS_METASURFACE], # type: ignore[arg-type] + thickness_metasurface=( + params[THICKNESS_METASURFACE] # type: ignore[arg-type] + ), thickness_spacer=params[THICKNESS_SPACER], # type: ignore[arg-type] ) return simulate_sorter( @@ -375,9 +377,9 @@ def simulate_sorter( layer_thicknesses = [ jnp.zeros(()), # Ambient - spec.thickness_cap.array, - spec.thickness_metasurface.array, - spec.thickness_spacer.array, + jnp.asarray(spec.thickness_cap.array), + jnp.asarray(spec.thickness_metasurface.array), + jnp.asarray(spec.thickness_spacer.array), jnp.asarray(spec.offset_monitor_substrate), # Substrate ]