diff --git a/src/invrs_gym/utils/transforms.py b/src/invrs_gym/utils/transforms.py index c037200..55e83d1 100644 --- a/src/invrs_gym/utils/transforms.py +++ b/src/invrs_gym/utils/transforms.py @@ -16,9 +16,9 @@ def rescaled_density_array( upper_bound: float, ) -> jnp.ndarray: """Return a density array for specified lower and upper bounds.""" - array = density.array - density.lower_bound - array /= density.upper_bound - density.lower_bound - array *= upper_bound - lower_bound + array = jnp.asarray(density.array - density.lower_bound) + array /= jnp.asarray(density.upper_bound - density.lower_bound) + array *= jnp.asarray(upper_bound - lower_bound) return jnp.asarray(array + lower_bound)