Skip to content

Commit

Permalink
feat: clear jax cache on register
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Jun 7, 2024
1 parent e2f8961 commit b0f0260
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions plum/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ def _add_method(
f = self._get_function(method)
for signature in signatures:
f.register(method, signature, precedence)

# Hooks to clear the JIT caches of various libraries.
# TODO: this needs to be systematized. This should probably work
# by entry points to allow for any JIT library to be added.
# JAX:
if "jax" in type(method).__module__ and hasattr(method, "clear_cache"):
method.clear_cache()

return f

def clear_cache(self):
Expand Down

0 comments on commit b0f0260

Please sign in to comment.