-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: clear jax cache on register #158
base: master
Are you sure you want to change the base?
Conversation
Pull Request Test Coverage Report for Build 9419781118Details
💛 - Coveralls |
Pull Request Test Coverage Report for Build 9419813345Details
💛 - Coveralls |
Signed-off-by: nstarman <[email protected]>
Pull Request Test Coverage Report for Build 9419998190Details
💛 - Coveralls |
@wesselb @PhilipVinc what do you think of this simple solution? There's definitely room for improvement, as indicated in the comment, but I think since this is private API it's good enough to get in as is, then iterate. WDYT? |
I would be a bit against this. The thing you proposed works when directly jitting a dispatch function, like @dispatch
@jit
def myfun(a: jax.Array, b:int):
return a * b
@dispatch
@jit
def myfun(a: jax.Array, b:float):
return a + b however it does not detect usages like @dispatch
def myfun(a: jax.Array, b:int):
return a * b
@dispatch
def myfun(a: jax.Array, b:float):
return a + b
@jit
def my_algorithm(a, b):
return 12 * myfun(a, b)
# this causes jax to jit,
my_algorithm(a, b)
@dispatch
def myfun(a: jax.Array, b:float):
return a + b + 10
# jax has no idea that the function changed, so he will return same result as before
my_algorithm(a, b) This also applies to your function if inside a jitted function you call other dispatch functions. This is the use-case that should be addressed in my opinion, otherwise it's too brittle. |
@PhilipVinc some really good points! Just for completeness, this also helps address the case @jit
@dispatch
def myfun(a: jax.Array, b:int):
return a * b
@jit
@dispatch
def myfun(a: jax.Array, b:float):
return a + b (The decorator order is important jit-wise) But your are correct that this doesn't work for "higher-order" functions that call other functions which themselves are multiply-dispactched, as you show later in your comment. |
@wesselb, maybe is this related to your suggestion of a |
How does your PR work with |
See the opening comment of #154 (comment). I haven't tested your "higher-order" case, but I hope (🤞) adding
If the But overall I agree this is still brittle and thus not an optimal solution! |
But isn't this # Hooks to clear the JIT caches of various libraries.
# TODO: this needs to be systematized. This should probably work
# by entry points to allow for any JIT library to be added.
# JAX:
if "jax" in type(method).__module__ and hasattr(method, "clear_cache"):
method.clear_cache() only going to detect if you import jax
def dispatch(fun):
print("Fun:", type(fun).__module__)
return fun
@jax.jit
@dispatch
def test(x, y):
return x+y
print("fun now:", type(test).__module__) outputs
so in this case it would not be triggered |
Ah. Good catch. I didn't write tests yet for this, which would have caught this mistake. This is the difference between the example I gave in #154 and here. |
The only way that I can think of, brutal and not ideal, is to clear all jax caches with |
It's definitely not ideal, but so long as a user imports in all modules (as in registers everything) then it's not so bad. It's only when new dispatches are registered that this becomes annoying. |
Maybe @patrick-kidger can advise on how to best make |
AFAIK it's not really possible.
Tbh I think supporting this is an antipattern anyway. Everything prior to That aside I really don't like the idea that |
@patrick-kidger. Thanks for the info! |
That sounds much more reasonable to me! Having the extent of a cross-library interaction be a "hey, did you really mean this?" warning sounds like a good compromise I think. |
This also sounds like a good approach to me. :)
What I had in mind was slightly different. The cache would remember all dispatch decisions. The next time the function is called, instead of running the resolver on the arguments, the resolver would not look at the arguments, but immediately return the previously resolved method from the cache. This way you could "compile away" the overhead of dispatch. |
The idea is to "structurally" identify jax. jax.jit produces as jaxlib object with attribute clear_cache. That's sufficient to identify JAX jitted functions.
I've put notes for longer-term plans.
Needs tests.