-
Notifications
You must be signed in to change notification settings - Fork 113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Drive enzyme ops removal using pattern rewriter #2229
Conversation
|
||
LogicalResult matchAndRewrite(enzyme::EnzymeOpsRemoverOpInterface iface, | ||
PatternRewriter &rewriter) const final { | ||
return iface.removeEnzymeOps(rewriter); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
High level comment then we can merge: It's critical that we make sure that whatever region op we simplify doesn't have its interface called if there are interior ops for push/pop/set/get. Where is that check happening?
For example we could do something here that says
if (iface.walk(op)[]{ return isa enzyme push/pop op; })) return failure();
before calling iface.removeEnzymeOps
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So currently, implementers of the interface should make sure of that but that is not ideal.
I think in this case we want to process ops in post-order (like the walk did before) and newly created ops should be processed only if they are after in the post-order.
I will try to implement something like this.
@Pangoraw this good to go? |
Yes, this is ok to merge from my POV. I updated the corresponding Enzyme-JAX PR as well. |
This modifies the interface to notify of IR mutations via the pattern rewriter to prevent error during iteration.
Also fixes bugs in the scf.for implementation of the interface which resulted in non-deterministic test failures.