Skip to content

Commit

Permalink
Fix jax.tree_util.register_dataclass in older JAX versions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 628149376
  • Loading branch information
epiqueras authored and Flax Authors committed Apr 25, 2024
1 parent b8ccb15 commit ea3bcab
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions flax/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,14 @@ def replace(self, **updates):
data_clz.replace = replace

# Remove this guard once minimux JAX version is >0.4.26.
if hasattr(jax.tree_util, 'register_dataclass'):
jax.tree_util.register_dataclass(
data_clz, data_fields, meta_fields
)
else:
try:
if hasattr(jax.tree_util, 'register_dataclass'):
jax.tree_util.register_dataclass(
data_clz, data_fields, meta_fields
)
else:
raise NotImplementedError
except NotImplementedError:

def iterate_clz(x):
meta = tuple(getattr(x, name) for name in meta_fields)
Expand Down

0 comments on commit ea3bcab

Please sign in to comment.