
Is there any chance to call the backward function directly instead of using the PyTorch autograd mechanism? #73

Open
MayDomine opened this issue Nov 7, 2024 · 3 comments


MayDomine commented Nov 7, 2024

In my application, I have to call FlexAttention's backward inside the backward function of my own autograd function.
That autograd function also does a lot of other work, such as communicating query, key, and value across multiple GPUs. So I was wondering whether there is any chance to call FlexAttention's backward function directly instead of going through the PyTorch autograd mechanism.
FlashAttention exposes its backward as a direct function call. This seems straightforward as long as all saved tensors keep their shapes and the mask is fixed.

Chillee (Contributor) commented Nov 11, 2024

This is something I've seen before - I think usually you can wrap it and manually call torch.autograd.grad with the forwards and backwards. But it might make sense to provide a more explicit way to do it - will think about it.
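To make that suggestion concrete, a minimal sketch of the wrap-and-replay approach might look like the following (assuming PyTorch 2.5+'s `torch.nn.attention.flex_attention`; the class name and the choice to recompute the forward inside backward are illustrative, not part of FlexAttention's API):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Illustrative sketch (not an official API): wrap flex_attention in a custom
# autograd.Function, replay the forward under enable_grad in backward, and let
# torch.autograd.grad produce dQ/dK/dV explicitly.
class FlexAttnManualBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        ctx.save_for_backward(q, k, v)
        return flex_attention(q, k, v)

    @staticmethod
    def backward(ctx, grad_out):
        # Detached copies with requires_grad so the replay builds a fresh graph.
        q, k, v = (t.detach().requires_grad_() for t in ctx.saved_tensors)
        with torch.enable_grad():
            # Replay the forward, then differentiate it manually.
            out = flex_attention(q, k, v)
            return torch.autograd.grad(out, (q, k, v), grad_out)
```

Note that this replays the forward inside backward, similar to activation checkpointing, so it pays an extra forward pass per call; avoiding that recomputation is exactly what a directly callable backward would enable.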

MayDomine (Author) commented

Thank you for your response. While torch.autograd.grad can be useful in some scenarios, it doesn't allow modifications to the original autograd graph, which is necessary for specific optimizations.

For instance, in the case of RingAttention, I need to call flex_attention forward four times within my autograd function's forward pass and perform four backward passes in the backward pass. This requires saving four outputs, each with its own ctx that contains numerous tensors, which cannot be modified directly.

I'm not very familiar with AOTAutograd, but it seems that torch.compile might not be suitable for this particular case. Is there an alternative approach that could make this work? Any guidance would be greatly appreciated.
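For illustration only, the four-call pattern described above might be sketched like this (the names, chunk layout, and output combination are assumptions; the cross-GPU communication and the log-sum-exp rescaling that real RingAttention requires are omitted):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Rough sketch of the multi-chunk pattern: one flex_attention call per KV chunk
# in forward, and one manual torch.autograd.grad replay per chunk in backward.
class RingStyleAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, *kvs):  # kvs = (k0, v0, k1, v1, ...), one pair per chunk
        ctx.save_for_backward(q, *kvs)
        outs = [flex_attention(q, kvs[i], kvs[i + 1]) for i in range(0, len(kvs), 2)]
        # Placeholder combination; real ring attention rescales each partial
        # output with its log-sum-exp before combining.
        return torch.stack(outs).sum(0)

    @staticmethod
    def backward(ctx, grad_out):
        q, *kvs = ctx.saved_tensors
        dq, dkvs = torch.zeros_like(q), []
        for i in range(0, len(kvs), 2):
            qi, ki, vi = (t.detach().requires_grad_()
                          for t in (q, kvs[i], kvs[i + 1]))
            with torch.enable_grad():
                # Replay this chunk's forward, then differentiate it manually.
                out = flex_attention(qi, ki, vi)
                gq, gk, gv = torch.autograd.grad(out, (qi, ki, vi), grad_out)
            dq += gq
            dkvs += [gk, gv]
        return (dq, *dkvs)
```

Each chunk is replayed and differentiated independently, with grad_out passed through the placeholder sum; a real implementation would also apply the rescaling and communication per chunk, and a direct FlexAttention backward entry point would remove the per-chunk forward recomputation.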

insujang commented
I also have the same problem implementing RingAttention with FlexAttention. Could FlexAttention provide a backward function that can be called directly?
