From b640d427dc9d9eb0f2e87ec72e46a135d79cdc97 Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Sat, 10 Feb 2024 20:22:06 -0800 Subject: [PATCH] formatting --- src/invrs_utils/experiment/work_unit.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/invrs_utils/experiment/work_unit.py b/src/invrs_utils/experiment/work_unit.py index 0a05106..9b90e86 100644 --- a/src/invrs_utils/experiment/work_unit.py +++ b/src/invrs_utils/experiment/work_unit.py @@ -73,13 +73,15 @@ def run_work_unit( def step_fn(state: Any): def loss_fn( params: Any, - ) -> Tuple[jnp.ndarray, Tuple[Any, jnp.ndarray, Dict[str, Any], Dict[str, Any]]]: + ) -> Tuple[ + jnp.ndarray, Tuple[Any, jnp.ndarray, Dict[str, Any], Dict[str, Any]] + ]: response, aux = challenge.component.response(params) loss = challenge.loss(response) distance = challenge.distance_to_target(response) metrics = challenge.metrics(response, params, aux) return loss, (response, distance, metrics, aux) - + params = optimizer.params(state) (value, (response, distance, metrics, aux)), grad = jax.value_and_grad( loss_fn, has_aux=True