From ea3bcab68b47a9b92ce759e08acbc77abfb5a10a Mon Sep 17 00:00:00 2001 From: Enrique Piqueras Date: Thu, 25 Apr 2024 12:03:49 -0700 Subject: [PATCH] Fix jax.tree_util.register_dataclass in older JAX versions. PiperOrigin-RevId: 628149376 --- flax/struct.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/flax/struct.py b/flax/struct.py index 1b7dc6d4..314ce3c6 100644 --- a/flax/struct.py +++ b/flax/struct.py @@ -124,11 +124,14 @@ def replace(self, **updates): data_clz.replace = replace # Remove this guard once minimux JAX version is >0.4.26. - if hasattr(jax.tree_util, 'register_dataclass'): - jax.tree_util.register_dataclass( - data_clz, data_fields, meta_fields - ) - else: + try: + if hasattr(jax.tree_util, 'register_dataclass'): + jax.tree_util.register_dataclass( + data_clz, data_fields, meta_fields + ) + else: + raise NotImplementedError + except NotImplementedError: def iterate_clz(x): meta = tuple(getattr(x, name) for name in meta_fields)