JaxTyping introduces reference cycles that would otherwise not exist #258
Comments
Thanks for the report! I agree, being able to avoid reference cycles would be nice. I suspect the first one could be broken by using a weak reference. We really do need to be able to get (Well, two caveats here:
Meanwhile I suspect the second one could be manually broken via something like If you're willing to submit a PR + test then I'd be very happy to take that :)
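To make the weak-reference idea concrete, here is a minimal sketch. It is not jaxtyping's actual code: `decorate_cyclic`, `decorate_weak`, the `calls` attribute and `freed_by_refcount` are all hypothetical names chosen for illustration, and the behaviour shown assumes CPython's reference counting.

```python
import gc
import weakref


def decorate_cyclic(fn):
    """Hypothetical decorator whose wrapper refers back to itself."""
    def wrapped_fn(*args, **kwargs):
        # Naming wrapped_fn here stores it in a closure cell, giving the
        # cycle: wrapped_fn -> closure cell -> wrapped_fn.
        wrapped_fn.calls += 1
        return fn(*args, **kwargs)

    wrapped_fn.calls = 0
    return wrapped_fn


def decorate_weak(fn):
    """Same behaviour, but the self-reference goes through a weak reference."""
    def wrapped_fn(*args, **kwargs):
        self_fn = wrapped_ref()  # strong reference only for this call
        self_fn.calls += 1
        return fn(*args, **kwargs)

    wrapped_fn.calls = 0
    # The closure captures wrapped_ref, which does not keep wrapped_fn alive.
    wrapped_ref = weakref.ref(wrapped_fn)
    return wrapped_fn


def freed_by_refcount(decorator) -> bool:
    """Return True if the decorated function dies without the cyclic GC."""
    was_enabled = gc.isenabled()
    gc.disable()  # rule out an automatic collection during the check
    try:
        wrapped = decorator(lambda x: x + 1)
        assert wrapped(1) == 2
        probe = weakref.ref(wrapped)
        del wrapped  # drop the only external strong reference
        freed = probe() is None
        gc.collect()  # clean up any cycle left behind
        return freed
    finally:
        if was_enabled:
            gc.enable()


print(freed_by_refcount(decorate_cyclic))  # False on CPython
print(freed_by_refcount(decorate_weak))    # True on CPython
```

The trade-off is that the wrapper has to dereference the weak reference on every call, which is safe here only because the function is necessarily alive while it is being called.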
Regarding the second one:
causes a lot of tests to fail. I was able to "fix" it by adding a
That said, I don't fully understand what this code is doing, so I don't know whether I'm breaking something else :). Does the
Yup I think that
Second change: #260
Okay, I think this is done! :) Are there any other examples you're bumping into or shall we close this?
There aren't any other examples that I'm bumping into. If I find any, I'll send some more pull requests, but no need to keep this issue open just in case. Thanks for the prompt reviews and helping to get these changes landed!
Particularly when building large/complicated systems, it can end up being highly preferable to delete Python objects by decrementing their reference counts to 0, as opposed to relying on the Python Garbage Collector (which is needed to delete unreachable objects that are either in or are referenced by a reference cycle).
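As a small illustration of the difference (a sketch assuming CPython; `Node` and `freed_without_gc` are made-up names for this example), an acyclic object is freed the moment its last reference goes away, while an object in a cycle stays alive until the garbage collector runs:

```python
import gc
import weakref


class Node:
    """Plain object that can hold a reference to another object."""

    def __init__(self):
        self.other = None


def freed_without_gc(make_cycle: bool) -> bool:
    """Return True if a Node is freed by reference counting alone."""
    was_enabled = gc.isenabled()
    gc.disable()  # make sure the cyclic collector does not run on its own
    try:
        node = Node()
        if make_cycle:
            node.other = node  # self-reference: the refcount cannot reach 0
        probe = weakref.ref(node)  # observes node without keeping it alive
        del node  # drop the only external strong reference
        freed = probe() is None
        gc.collect()  # clean up the cycle, if there was one
        return freed
    finally:
        if was_enabled:
            gc.enable()


print(freed_without_gc(make_cycle=False))  # True on CPython
print(freed_without_gc(make_cycle=True))   # False on CPython
```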
For this reason, it would be really nice if decorating code with JaxTyping could avoid introducing any reference cycles where reference cycles do not already exist in the code being decorated. This is not currently the case. There are at least two places I know of where JaxTyping introduces reference cycles even if the user code does not contain any:
- wrapped_fn referring back to itself
- an fn <-> scope reference cycle

Do you think avoiding reference cycles is a nice property to aim for, and do you think it should be achievable?
I think it should be possible to write a test helper that would let you do something like:
and will tell you what the reference cycles are if it fails. Let me know if that would be helpful, and I can look into sending a pull request.
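For illustration only, one way such a helper could look is sketched below. The name `assert_no_new_cycles` and its interface are assumptions made for this example, not an existing jaxtyping API and not necessarily what ended up in the pull requests. It relies on the standard library's `gc.DEBUG_SAVEALL` flag, which makes `gc.collect()` store unreachable objects in `gc.garbage` instead of freeing them.

```python
import contextlib
import gc


@contextlib.contextmanager
def assert_no_new_cycles():
    """Fail if the block leaves behind objects only the cyclic GC can free.

    Hypothetical helper for illustration; not part of jaxtyping.
    """
    was_enabled = gc.isenabled()
    old_flags = gc.get_debug()
    gc.disable()  # no automatic collections while we measure
    gc.collect()  # free cyclic garbage that predates the block
    start = len(gc.garbage)
    # With DEBUG_SAVEALL, collect() appends unreachable objects to
    # gc.garbage rather than freeing them, so they can be inspected.
    gc.set_debug(gc.DEBUG_SAVEALL)
    try:
        yield
        gc.collect()
        leaked = [type(obj).__name__ for obj in gc.garbage[start:]]
    finally:
        gc.set_debug(old_flags)
        del gc.garbage[start:]
        gc.collect()  # now actually free whatever was saved
        if was_enabled:
            gc.enable()
    if leaked:
        raise AssertionError(
            f"objects kept alive by reference cycles: {leaked}"
        )


class Node:
    pass


# Passes: the object is freed by reference counting alone.
with assert_no_new_cycles():
    node = Node()
    del node

# Fails: the self-reference can only be cleaned up by the collector.
try:
    with assert_no_new_cycles():
        node = Node()
        node.me = node
        del node
except AssertionError as err:
    print(err)
```

Reporting only type names keeps the sketch short; a fuller helper could use `gc.get_referrers` to show which objects form each cycle.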