
Flex attention with dropout #77

Open
zbh2047 opened this issue Nov 13, 2024 · 3 comments

zbh2047 commented Nov 13, 2024

Hi,
I've found the flex attention package really useful and flexible. However, it seems that flex attention does not support dropout, which is still quite widely used. Will this be supported in the future?

I've also considered implementing dropout in the mask, although that is not equivalent to applying dropout after the softmax. Even in this setting, I'm not sure how to make the implementation correct, since the dropout mask cannot be generated on the fly (it must be the same in both the forward and backward passes).

Can anyone elaborate on this? Thank you so much!
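
For reference, by "dropout after softmax" I mean the standard attention dropout, roughly the following in eager PyTorch (just an illustrative sketch for clarity, not flex attention code):

import torch
import torch.nn.functional as F

def attention_with_post_softmax_dropout(q, k, v, p=0.1, training=True):
    # q, k, v: (B, H, S, D). Scaled dot-product attention with dropout applied
    # to the attention probabilities *after* the softmax.
    scale = q.shape[-1] ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale  # (B, H, S, S)
    probs = torch.softmax(scores, dim=-1)
    probs = F.dropout(probs, p=p, training=training)       # the step in question
    return torch.matmul(probs, v)                           # (B, H, S, D)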

drisspg (Contributor) commented Nov 13, 2024

You are correct, we don't currently have post-softmax dropout implemented. We have this on the feature list, but we have seen decreasing adoption of it throughout the industry, so it isn't high priority.

zbh2047 (Author) commented Nov 14, 2024

Thank you for the reply. In that case, I would just like to know whether it is possible to implement a pre-softmax dropout under the current framework. The main question is whether I can use a rand function inside mask_mod or score_mod: will the forward and backward passes compute the same mask? Another question is whether I can avoid calling create_block_mask again for each forward pass.
Looking forward to your thoughts. Thank you!

drisspg (Contributor) commented Nov 16, 2024

So the naive way to implement this is:

import torch

from torch.nn.attention.flex_attention import flex_attention
from functools import partial

B, H, S, D = 1, 4, 256, 64

dropout_prob = 0.1
# Pre-computed keep mask over (batch, head, query, key) positions:
# True means "keep this score", False means "drop it".
keep_mask = torch.rand((B, H, S, S), device="cuda") > dropout_prob

def dropout(score, b, h, q_idx, kv_idx):
    # Dropped positions get -inf so they vanish after the softmax.
    return torch.where(keep_mask[b, h, q_idx, kv_idx], score, -float("inf"))


if __name__ == "__main__":
    make_tensor = partial(torch.randn, (B, H, S, D), device="cuda", dtype=torch.float16, requires_grad=True)

    query, key, value = make_tensor(), make_tensor(), make_tensor()
    compiled_flex = torch.compile(flex_attention, fullgraph=True)
    out = compiled_flex(query, key, value, score_mod=dropout)
    print(out)

There are probably other fun things you can do to reduce the extra memory needed to store the mask, but this is the most straightforward approach; one option is sketched below.
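
For example, one option (just a sketch, and it assumes sharing the same dropout pattern across all heads is acceptable) is to keep a single (B, S, S) boolean mask and re-randomize it in place before each forward pass. Since score_mod only closes over a regular tensor, the backward of that same step reads the identical values as long as you don't mutate the mask between the forward and backward, and no block mask needs to be rebuilt for this:

import torch
from functools import partial
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 1, 4, 256, 64
dropout_prob = 0.1

# One boolean keep-mask shared across heads: (B, S, S) instead of (B, H, S, S).
keep_mask = torch.empty((B, S, S), device="cuda", dtype=torch.bool)

def resample_dropout_mask():
    # Call once per training step, before the forward pass; don't mutate the
    # mask between the forward and backward of the same step.
    keep_mask.copy_(torch.rand((B, S, S), device="cuda") > dropout_prob)

def dropout(score, b, h, q_idx, kv_idx):
    # The head index is ignored, so every head sees the same dropout pattern.
    return torch.where(keep_mask[b, q_idx, kv_idx], score, -float("inf"))

if __name__ == "__main__":
    make_tensor = partial(torch.randn, (B, H, S, D), device="cuda",
                          dtype=torch.float16, requires_grad=True)
    query, key, value = make_tensor(), make_tensor(), make_tensor()
    compiled_flex = torch.compile(flex_attention, fullgraph=True)

    for step in range(2):  # toy training loop
        resample_dropout_mask()
        out = compiled_flex(query, key, value, score_mod=dropout)
        out.sum().backward()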
