From b0f0260f82f541eb3033b298e85a4ec15c897b0f Mon Sep 17 00:00:00 2001 From: nstarman Date: Fri, 7 Jun 2024 11:52:44 -0400 Subject: [PATCH] feat: clear jax cache on register Signed-off-by: nstarman --- plum/dispatcher.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/plum/dispatcher.py b/plum/dispatcher.py index 0f2b5adc..80a1ef1b 100644 --- a/plum/dispatcher.py +++ b/plum/dispatcher.py @@ -110,6 +110,14 @@ def _add_method( f = self._get_function(method) for signature in signatures: f.register(method, signature, precedence) + + # Hooks to clear the JIT caches of various libraries. + # TODO: this needs to be systematized. This should probably work + # by entry points to allow for any JIT library to be added. + # JAX: + if "jax" in type(method).__module__ and hasattr(method, "clear_cache"): + method.clear_cache() + return f def clear_cache(self):