feat: clear jax cache on register #158

Draft: wants to merge 1 commit into master.


nstarman
Contributor

@nstarman nstarman commented Jun 7, 2024

The idea is to identify JAX "structurally". jax.jit produces a jaxlib object with a clear_cache attribute, which is sufficient to identify JAX-jitted functions.
I've put notes for longer-term plans.

Needs tests.
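
The structural check described above can be sketched without importing JAX at all. This is a minimal illustration, not the PR's code: the helper name `looks_like_jax_jitted` and the stand-in class `FakeJitted` are hypothetical, and the sketch assumes that jitted callables are jaxlib objects exposing `clear_cache`.

```python
def looks_like_jax_jitted(obj) -> bool:
    """Duck-typed check: the object's type is defined in a jax/jaxlib module
    and the object exposes a callable ``clear_cache`` attribute."""
    top_level = (type(obj).__module__ or "").split(".")[0]
    return top_level in {"jax", "jaxlib"} and callable(getattr(obj, "clear_cache", None))


class FakeJitted:
    """Hypothetical stand-in for the object returned by jax.jit."""

    def __init__(self):
        self.cleared = False

    def clear_cache(self):
        self.cleared = True

    def __call__(self, x):
        return x


# Pretend the class was defined inside jaxlib (assumption for illustration).
FakeJitted.__module__ = "jaxlib.xla_extension"

fake = FakeJitted()
if looks_like_jax_jitted(fake):
    fake.clear_cache()  # what a registration hook would do
```

Matching on the top-level package name is slightly stricter than a substring test on the module path, which would also match unrelated modules that merely contain "jax" in their name.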

@coveralls

coveralls commented Jun 7, 2024

Pull Request Test Coverage Report for Build 9419781118

Details

  • 1 of 2 (50.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-0.08%) to 99.839%

| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
|---|---|---|---|
| plum/dispatcher.py | 1 | 2 | 50.0% |

| Totals | Coverage Status |
|---|---|
| Change from base Build 9340121872 | -0.08% |
| Covered Lines | 1237 |
| Relevant Lines | 1239 |

💛 - Coveralls


@nstarman
Contributor Author

nstarman commented Jun 9, 2024

@wesselb @PhilipVinc what do you think of this simple solution? There's definitely room for improvement, as indicated in the comment, but I think since this is private API it's good enough to get in as is, then iterate. WDYT?

@PhilipVinc
Collaborator

I would be a bit against this.

The thing you proposed works when directly jitting a dispatch function, like

@dispatch
@jit
def myfun(a: jax.Array, b:int):
   return a * b

@dispatch
@jit
def myfun(a:  jax.Array, b:float):
   return a + b

however it does not detect usages like

@dispatch
def myfun(a: jax.Array, b:int):
   return a * b

@dispatch
def myfun(a:  jax.Array, b:float):
   return a + b

@jit
def my_algorithm(a, b):
    return 12 * myfun(a, b)

# This call causes JAX to trace and compile my_algorithm.
my_algorithm(a, b)

@dispatch
def myfun(a:  jax.Array, b:float):
   return a + b + 10

# JAX has no idea that the function changed, so it returns the same result as before.
my_algorithm(a, b)

The same applies to your function if, inside a jitted function, you call other dispatched functions.

This is the use-case that should be addressed in my opinion, otherwise it's too brittle.
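
The staleness problem above can be reproduced in a few lines of plain Python. This is a toy model under stated assumptions, not JAX or plum: `toy_jit`, `register`, and `registry` are hypothetical names, and "compilation" is simulated by building a closure once per argument-type signature.

```python
registry = {}  # toy method table: type of the second argument -> implementation


def register(tp):
    """Toy stand-in for @dispatch: record the implementation for type ``tp``."""
    def decorator(fn):
        registry[tp] = fn
        return fn
    return decorator


def toy_jit(build):
    """Toy stand-in for jit: call ``build`` once per argument-type signature
    and reuse the specialised callable it returned on later calls."""
    cache = {}

    def wrapper(*args):
        key = tuple(type(a) for a in args)
        if key not in cache:
            cache[key] = build(key)  # "trace and compile"
        return cache[key](*args)

    wrapper.clear_cache = cache.clear
    return wrapper


@register(float)
def add(a, b):
    return a + b


@toy_jit
def my_algorithm(arg_types):
    method = registry[arg_types[1]]  # dispatch is resolved at "trace" time
    return lambda a, b: 12 * method(a, b)


first = my_algorithm(1.0, 2.0)  # compiles against `add`


@register(float)
def add_plus_ten(a, b):
    return a + b + 10


stale = my_algorithm(1.0, 2.0)  # cached closure still calls `add`
my_algorithm.clear_cache()
fresh = my_algorithm(1.0, 2.0)  # rebuilt against `add_plus_ten`
```

In this model `stale` equals `first` even though the method table changed, and only clearing the outer function's cache picks up the new method. The inner function never sees the registration event for the outer cache, which is the brittleness described above.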

@nstarman
Contributor Author

nstarman commented Jun 9, 2024

@PhilipVinc some really good points!

Just for completeness, this also helps address the case

@jit
@dispatch
def myfun(a: jax.Array, b:int):
   return a * b

@jit
@dispatch
def myfun(a:  jax.Array, b:float):
   return a + b

(The decorator order is important jit-wise)

But you are correct that this doesn't work for "higher-order" functions that call other functions which are themselves multiply dispatched, as you show later in your comment.

@nstarman
Contributor Author

nstarman commented Jun 9, 2024

@wesselb, is this perhaps related to your suggestion of a cache? Or is this a separate complication?

@PhilipVinc
Collaborator

How does your PR work with jit(dispatch(...))?
The dispatch logic is executed before the jitting happens, so I do not understand how it would work correctly in this case.

@nstarman
Contributor Author

nstarman commented Jun 9, 2024

See the opening comment of #154 (comment).
jit compiles away the plum dispatching, which is great for speeding up the code. Prior to this PR, when a new dispatch was added JAX did not update the plum dispatch it was going to use, keeping its internally compiled dispatch. This is exactly what you point out in my_algorithm, but applies here to myfun. With this PR when a new dispatch is added to myfun it wipes the cache and forces the next usage to JIT again, thus keeping JAX in sync with plum.

I haven't tested your "higher-order" case, but I hope (🤞) adding jit to myfun would actually work with this PR, because my_algorithm wouldn't compile myfun away but would keep it as a distinct function.

@jit
@dispatch
def myfun(a: jax.Array, b:int):
   return a * b

@jit
@dispatch
def myfun(a:  jax.Array, b:float):
   return a + b

@jit
def my_algorithm(a, b):
    return 12 * myfun(a, b)  # myfun is not compiled away, depending on the JAX version (see discussion in https://github.com/google/jax/issues/9298)

If the jit on myfun had inline=True, then IDK what would happen.

But overall I agree this is still brittle and thus not an optimal solution!

@PhilipVinc
Collaborator

But isn't this

        # Hooks to clear the JIT caches of various libraries.
        # TODO: this needs to be systematized. This should probably work
        #       by entry points to allow for any JIT library to be added.
        # JAX:
        if "jax" in type(method).__module__ and hasattr(method, "clear_cache"):
            method.clear_cache()

only going to detect the dispatch(jit(...)) case, i.e. @dispatch applied on top of @jit?

import jax

def dispatch(fun):
	print("Fun:", type(fun).__module__)
	return fun


@jax.jit
@dispatch
def test(x, y):
	return x+y

print("fun now:", type(test).__module__)

outputs

Fun: builtins
fun now: jaxlib.xla_extension

so in this case it would not be triggered

@nstarman
Contributor Author

Ah. Good catch. I didn't write tests yet for this, which would have caught this mistake. This is the difference between the example I gave in #154 and here.
Yes, I need to refactor this to search through all the registrants, not the current method.
We definitely do want to detect both jit(dispatch(...)) and dispatch(jit(...)).
But neither approach will work for jitted functions that call a separate dispatched function.
Is there any way around this?!

@PhilipVinc
Collaborator

The only way that I can think of, brutal and not ideal, is to clear all jax caches with jax.clear_caches() every time we register a new function.
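
A minimal sketch of that brute-force option, assuming it would be called from a registration hook. The helper name `clear_jax_caches_if_loaded` is hypothetical; `jax.clear_caches()` is the only JAX call used, and JAX is looked up in `sys.modules` so that the helper never imports it on its own.

```python
import sys


def clear_jax_caches_if_loaded() -> bool:
    """Clear all JAX compilation caches, but only if the user already imported jax.

    Returns True when caches were cleared and False when jax is not loaded.
    """
    jax = sys.modules.get("jax")
    if jax is None:
        return False  # nothing could have been jitted yet
    jax.clear_caches()  # global and coarse: every jitted function recompiles
    return True
```

Because this invalidates every compiled function in the process, each late registration costs a full recompilation on the next call of every jitted function.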

@nstarman
Contributor Author

It's definitely not ideal, but as long as the user imports all modules up front (that is, registers everything), it's not so bad. It only becomes annoying when new dispatches are registered later.

@nstarman
Contributor Author

Maybe @patrick-kidger can advise on how to best make plum and jax play well together?

@patrick-kidger

patrick-kidger commented Jun 11, 2024

AFAIK it's not really possible.

  • Adding a new dispatch rule is mutating global state (the methods table);
  • I do not know of any way to have JAX automatically re-JIT when changing global state.

Tbh I think supporting this is an antipattern anyway. Everything prior to jax.jit is basically the 'source code', which you then compile on your first run. Looking at other languages: automatically recompiling when the source code changes is nonstandard in C++/Rust-type languages, and in Julia it caused a lot of headaches: because they handled this, they were silently footgunning their compile times with lots of cache invalidations.

That aside I really don't like the idea that plum should try to do something based on specific third-party libraries. If everyone did this it would be incredibly hard to reason about what my code does. I would much rather they each just do their thing without trying to interfere with each other.

@nstarman
Contributor Author

@patrick-kidger. Thanks for the info!
Are you suggesting that instead we just add something to the docs showing how jit + dispatch can work together, but also showing how this can be dangerous when adding new dispatches, and suggesting jax.clear_caches for that case?
Instead of clearing the cache, this PR could check if there is a non-empty cache and raise a warning.
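
A rough sketch of the warning variant follows. It rests on a loud assumption: that a jitted callable exposes a private `_cache_size()` method reporting how many compiled entries it holds. That attribute is not public API and may change or disappear, so the check is duck-typed. The names `warn_if_already_compiled` and `FakeJitted` are hypothetical.

```python
import warnings


def warn_if_already_compiled(fn, name):
    """Warn when ``fn`` reports a non-empty compilation cache.

    ASSUMPTION: ``fn`` may expose a private ``_cache_size()`` method
    (not public API); objects without it are silently ignored.
    """
    cache_size = getattr(fn, "_cache_size", None)
    if callable(cache_size) and cache_size() > 0:
        warnings.warn(
            f"A new method was registered for {name!r} after it was compiled; "
            "call jax.clear_caches() if the new method should take effect.",
            RuntimeWarning,
            stacklevel=2,
        )
        return True
    return False


class FakeJitted:
    """Hypothetical stand-in for a jitted callable with a cache counter."""

    def __init__(self, entries):
        self._entries = entries

    def _cache_size(self):
        return self._entries
```

Calling `warn_if_already_compiled(FakeJitted(1), "myfun")` emits the warning, while an empty cache or an ordinary function passes silently, so users who register everything before the first call never see it.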

@patrick-kidger

patrick-kidger commented Jun 11, 2024

That sounds much more reasonable to me!

Having the extent of a cross-library interaction be a "hey, did you really mean this?" warning sounds like a good compromise I think.

@wesselb
Member

wesselb commented Jun 13, 2024

> Are you suggesting that instead we just add something to the docs showing how jit + dispatch can work together, but also showing how this can be dangerous when adding new dispatches, and suggesting jax.clear_caches for that case?
> Instead of clearing the cache, this PR could check if there is a non-empty cache and raise a warning.

This also sounds like a good approach to me. :)

> @wesselb, maybe is this related to your suggestion of a cache ? Or is this a separate complication?

What I had in mind was slightly different. The cache would remember all dispatch decisions. The next time the function is called, instead of running the resolver on the arguments, the resolver would not look at the arguments, but immediately return the previously resolved method from the cache. This way you could "compile away" the overhead of dispatch.
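
One possible reading of this idea is a record-then-replay dispatcher, sketched below. This is an interpretation, not plum's design: the class `ReplayDispatcher` and its methods are hypothetical, and dispatch is simplified to a lookup on the type of a single argument.

```python
class ReplayDispatcher:
    """Toy dispatcher that records resolved methods, then replays them in order."""

    def __init__(self):
        self.methods = {}       # type -> implementation
        self.tape = []          # dispatch decisions, in call order
        self.replaying = False
        self._position = 0

    def register(self, tp, fn):
        self.methods[tp] = fn

    def start_replay(self):
        """Stop resolving; reuse the recorded decisions from the beginning."""
        self.replaying = True
        self._position = 0

    def __call__(self, x):
        if self.replaying:
            # No inspection of ``x``: take the next recorded decision.
            method = self.tape[self._position]
            self._position += 1
        else:
            method = self.methods[type(x)]  # run the (toy) resolver
            self.tape.append(method)
        return method(x)


dispatcher = ReplayDispatcher()
dispatcher.register(int, lambda x: x + 1)
dispatcher.register(str, lambda x: x.upper())

recorded = [dispatcher(1), dispatcher("a")]      # resolver runs, tape is filled
dispatcher.start_replay()
replayed = [dispatcher(2.5), dispatcher("b")]    # tape is replayed blindly
```

The replay of `2.5` goes through the method recorded for `int` even though no `float` method exists, which shows both the saving (no resolution) and the risk (decisions are only valid if the call sequence is unchanged).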

5 participants