
scaled_dot_product_attention: allow Q x Q mask instead of just Q x K mask #1842

Open
rvermillion opened this issue Feb 7, 2025 · 1 comment


@rvermillion

Currently the mlx.fast.scaled_dot_product_attention function throws an error if the mask cannot be broadcast to the shape of the scores. This requires you to supply a full Q x K mask when Q != 1, which can get large, especially with long contexts. But there are many use cases where it would be beneficial to supply a Q x Q mask and treat the first K - Q keys as unmasked (visible to every query). One use case is pre-filling the cache in chunks from a long prompt, where you could re-use a prefill_step_size x prefill_step_size mask. Another is speculative token generation, where you want to guess a second token based on the sampled one and send both through the model in the next step (rejecting the guess if its logprob is too low). Right now, this requires calculating a 2 x K causal mask on each pass instead of a 2 x 2.

In Python, if you are manually calculating attention and mask.shape[-2:] != scores.shape[-2:], it's easy enough to do:

scores[..., -mask.shape[-2]:, -mask.shape[-1]:] += mask

so that a Q x Q mask only affects the lower-right corner of the scores and masks it causally.
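For illustration, here is a minimal end-to-end sketch of that workaround (the shapes, names, and values are illustrative, not from the issue):

```python
# A sketch of the manual workaround: apply a small Q x Q additive causal mask
# to the lower-right corner of the scores, leaving the first K - Q keys
# visible to every query. Shapes here are illustrative.
import math
import mlx.core as mx

B, H, Q, K, D = 1, 8, 2, 1024, 64          # e.g. a speculative step with 2 new queries
q = mx.random.normal((B, H, Q, D))
k = mx.random.normal((B, H, K, D))
v = mx.random.normal((B, H, K, D))

# Q x Q additive causal mask: 0 on and below the diagonal, -inf above it
mask = mx.triu(mx.full((Q, Q), float("-inf")), k=1)

scores = (q @ k.transpose(0, 1, 3, 2)) * (1.0 / math.sqrt(D))
# Only the last Q keys interact causally with the Q queries; the earlier
# K - Q keys stay unmasked.
scores[..., -mask.shape[-2]:, -mask.shape[-1]:] += mask
out = mx.softmax(scores, axis=-1) @ v
```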

But it would be nice if the mlx.fast.scaled_dot_product_attention implementation supported this. I have a diff to mlx/fast.cpp that implements it, but it doesn't look like there are any tests for the function yet. I'd be willing to try to write some (covering both the original functionality and this new addition), but I wanted to check whether people thought this was a good, workable addition before doing that and creating a pull request.

Let me know and thanks for an awesome framework.

-rv

@angeloskath
Member

Thanks for the kind words!

There are tests for scaled_dot_product_attention in python/tests/test_fast_sdpa.py.

I don't think I like the concept of assuming a causal mask for the omitted part of the mask. However, we are looking into introducing the ability to specify that the mask is causal by passing the string "causal" instead of a mask. I assume that would cover your use case even more succinctly.
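A sketch of what that could look like, assuming the string-based mask is added as described (hypothetical at the time of this comment; shapes are illustrative):

```python
# Hypothetical usage of the proposed "causal" string mask, aligning the last
# query with the last key, so no Q x K (or Q x Q) mask array is needed.
import mlx.core as mx

q = mx.random.normal((1, 8, 2, 64))
k = mx.random.normal((1, 8, 1024, 64))
v = mx.random.normal((1, 8, 1024, 64))

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=64 ** -0.5, mask="causal")
```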
