Feature request: Add Dropout #7
Yes, it is on our plan; actually WIP. Since Triton now provides a pseudo-random generator, we can implement a memory-efficient flash attention with dropout without having to save the dropout mask (which would require O(n^2) memory). The essence is to re-generate the same dropout mask in the backward pass as was used in the forward pass.
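For context, a minimal sketch of the seeded-dropout idea, adapted from Triton's low-memory dropout tutorial (function names are illustrative, not this repo's kernel): the same `(seed, offsets)` pair regenerates the same mask on every call, so the backward pass can rebuild it instead of storing it.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _seeded_dropout(x_ptr, out_ptr, n_elements, p, seed, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # Deterministic per-element randomness: the same seed and offsets
    # always produce the same values, so the mask can be regenerated.
    random = tl.rand(seed, offsets)
    keep = random > p
    # Inverted dropout: scale kept elements by 1 / (1 - p).
    out = tl.where(keep, x / (1 - p), 0.0)
    tl.store(out_ptr + offsets, out, mask=mask)


def seeded_dropout(x: torch.Tensor, p: float, seed: int) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _seeded_dropout[grid](x, out, n_elements, p, seed, BLOCK_SIZE=1024)
    return out
```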
I have implemented a prototype which seems to work here, but it's hard to test correctness without separately implementing the dropout layer and comparing against it, since it uses a different random seed (and generator) than torch.
Yes, testing for randomness is tricky. There is no proper reference to compare against elementwise, so correctness has to be checked statistically. This method can be applied to a separate dropout operator. It can also be applied to the dropout part of a more complex operator, but the overall testing for correctness is then more complicated than for operators without randomness involved.
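A rough sketch of the kind of statistical check this implies, assuming a `seeded_dropout(x, p, seed)` helper like the sketch above: instead of an elementwise comparison against torch (whose RNG differs), verify the drop ratio and the output mean, plus determinism under a fixed seed.

```python
import torch

x = torch.ones(1_000_000, device="cuda")
p = 0.3
y = seeded_dropout(x, p=p, seed=42)

# Roughly a fraction p of elements should be zeroed out.
drop_ratio = (y == 0).float().mean().item()
assert abs(drop_ratio - p) < 1e-2

# Inverted scaling by 1 / (1 - p) should preserve the mean.
assert abs(y.mean().item() - 1.0) < 1e-2

# Determinism: the same seed must reproduce the same mask
# (the property the backward pass relies on).
y2 = seeded_dropout(x, p=p, seed=42)
assert torch.equal(y, y2)
```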
Finally done in #23.
The PyTorch base implementation of scaled_dot_product_attention provides dropout as an arg (see the sketch below). Fusing it into the Triton kernel would replicate that functionality, as dropout is applied to the attention scores, not the output. In the CUDA version, it is supported here. There have been attempts at integrating this into Triton before.
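For reference, a minimal usage sketch of the PyTorch API being discussed: `dropout_p` applies dropout to the attention weights inside the fused op, which is the semantics a fused Triton kernel would need to match. Shapes and values here are illustrative.

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim)
q, k, v = (torch.randn(2, 8, 128, 64, device="cuda") for _ in range(3))

# Dropout is applied to the attention scores inside the kernel,
# not to the final output.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)
```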