In my application, I have to call flex attention's backward in the backward function of my autograd function.
But in my autograd function, I do a lot of things, such as communicating query, key, and value across multiple GPUs. So I was wondering whether there is any chance to call flex attention's backward function directly instead of going through PyTorch's autograd mechanism.
FlashAttention exposes its backward function call. It seems easy if we expect all saved tensors to keep their shapes and the mask to stay fixed.
This is something I've seen before - I think you can usually wrap it and manually call torch.autograd.grad to drive the forward and backward passes. But it might make sense to provide a more explicit way to do it - will think about it.
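Something along these lines - a minimal sketch of the wrapping approach, assuming a recent PyTorch build where `flex_attention` lives under `torch.nn.attention.flex_attention`; the `WrappedFlexAttention` class name and the detach/`enable_grad` handling are illustrative, not an official API:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention


class WrappedFlexAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        # autograd.Function's forward runs under no_grad, so re-enable grad
        # to build flex_attention's graph; backward() below will replay it.
        with torch.enable_grad():
            q_ = q.detach().requires_grad_()
            k_ = k.detach().requires_grad_()
            v_ = v.detach().requires_grad_()
            out = flex_attention(q_, k_, v_)
        # Stash the leaves and the output node so the graph stays alive.
        ctx.save_for_backward(q_, k_, v_, out)
        return out.detach()

    @staticmethod
    def backward(ctx, grad_out):
        q_, k_, v_, out = ctx.saved_tensors
        # This drives flex_attention's backward through the graph captured above.
        dq, dk, dv = torch.autograd.grad(out, (q_, k_, v_), grad_out)
        return dq, dk, dv
```

Usage would just be `out = WrappedFlexAttention.apply(q, k, v)`.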
Thank you for your response. While torch.autograd.grad can be useful in some scenarios, it doesn't allow modifications to the original autograd graph, which is necessary for certain optimizations.
For instance, in the case of RingAttention, I need to call flex_attention's forward four times within my autograd function's forward pass and perform four backward passes in its backward pass. This requires saving four outputs, each with its own ctx containing numerous tensors, which cannot be modified directly.
I'm not very familiar with AOTAutograd, but it seems that torch.compile might not be suitable for this particular case. Is there an alternative approach that could make this work? Any guidance would be greatly appreciated.
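To make the shape of the problem concrete, here is a rough sketch of the four-call pattern with the wrapping approach. The class name is hypothetical, summing the partial outputs stands in for the real combination step, and the cross-rank KV exchange plus log-sum-exp rescaling that real RingAttention needs are omitted:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention


class MultiCallFlexAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, num_steps=4):
        outs, leaves = [], []
        with torch.enable_grad():
            for _ in range(num_steps):
                q_, k_, v_ = (t.detach().requires_grad_() for t in (q, k, v))
                outs.append(flex_attention(q_, k_, v_))
                leaves.append((q_, k_, v_))
                # In real RingAttention, k/v would be exchanged with the next
                # rank here and the partial outputs rescaled by their log-sum-exp.
        # Stashing tensors directly on ctx just to keep the sketch short:
        # four graphs, four sets of flex_attention's saved tensors.
        ctx.outs, ctx.leaves = outs, leaves
        # Summing the partials stands in for the real combination step.
        return sum(outs).detach()

    @staticmethod
    def backward(ctx, grad_out):
        dq = dk = dv = None
        for out, (q_, k_, v_) in zip(ctx.outs, ctx.leaves):
            # One manual flex_attention backward per forward call.
            gq, gk, gv = torch.autograd.grad(out, (q_, k_, v_), grad_out)
            dq = gq if dq is None else dq + gq
            dk = gk if dk is None else dk + gk
            dv = gv if dv is None else dv + gv
        return dq, dk, dv, None
```

Each of the four graph segments keeps its own copy of whatever flex_attention saves for backward, which is exactly the ctx bloat described above.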