Currently the `mlx.fast.scaled_dot_product_attention` function throws an error if the mask cannot be broadcast to the shape of the scores. This requires you to supply a full Q x K mask when Q != 1, which can get large, especially with long contexts. But there are many use cases where it would be beneficial to use a Q x Q mask and treat the first K-Q keys as causally unmasked. One use case is pre-filling the cache in chunks from a long prompt, where you could re-use a `prefill_step_size` x `prefill_step_size` mask. Another is speculative token generation, where you want to guess a second token based on the sampled one and send both through the model in the next step (rejecting the guess if its logprob is too low). Right now, this requires calculating a 2 x K causal mask on each pass instead of a 2 x 2.
In Python, if you are manually calculating attention and `mask.shape[-2:] != scores.shape[-2:]`, it's easy enough to do:

`scores[..., -mask.shape[-2]:, -mask.shape[-1]:] += mask`

so that a Q x Q mask only affects the lower-right corner of the scores and masks them causally.
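For concreteness, here is a minimal, self-contained sketch of that manual approach; the shapes, random inputs, and the `mx.where`-based mask construction are illustrative assumptions, not code from MLX itself:

```python
import mlx.core as mx

# Illustrative shapes: B batches, H heads, Q new queries, K total keys
# with K >= Q (e.g. a cached context plus Q new tokens).
B, H, Q, K, D = 1, 8, 4, 128, 64
q = mx.random.normal((B, H, Q, D))
k = mx.random.normal((B, H, K, D))
v = mx.random.normal((B, H, K, D))

# Q x Q additive causal mask for the new tokens only; the first K - Q keys
# are left fully visible (causally unmasked).
idx = mx.arange(Q)
mask = mx.where(idx[:, None] >= idx[None, :], 0.0, float("-inf"))

scale = D ** -0.5
scores = (q * scale) @ k.transpose(0, 1, 3, 2)

# Apply the small mask only to the lower-right Q x Q corner of the scores.
scores[..., -mask.shape[-2]:, -mask.shape[-1]:] += mask

out = mx.softmax(scores, axis=-1) @ v
```

With the mask aligned to the bottom-right like this, a single Q x Q mask can be re-used across prefill chunks regardless of how many keys are already in the cache.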
But it would be nice if the `mlx.fast.scaled_dot_product_attention` implementation supported this. I have a diff to `mlx/fast.cpp` that implements it, but it doesn't look like there are any tests for the function yet. I'd be willing to try to write some (covering both the original functionality and this new addition), but I wanted to check whether people think this is a good, workable addition before doing that and opening a pull request.
Let me know and thanks for an awesome framework.
-rv
There are tests for `scaled_dot_product_attention` in `python/tests/test_fast_sdpa.py`.
I don't think I like the concept of assuming a causal mask for the omitted part of the mask. However, we are looking into introducing the ability to specify that the mask is causal by passing the string "causal" instead of a mask. I assume that would cover your use case even more succinctly.
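For illustration, the proposed string-mask interface might look roughly like this; this is a hedged sketch of usage, assuming `mask="causal"` lands as described, with illustrative shapes:

```python
import mlx.core as mx

# Hypothetical usage of the proposed interface: pass the string "causal"
# instead of materializing an explicit mask array.
B, H, Q, K, D = 1, 8, 2, 128, 64
q = mx.random.normal((B, H, Q, D))
k = mx.random.normal((B, H, K, D))
v = mx.random.normal((B, H, K, D))

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** -0.5, mask="causal")
```

A string mask would also avoid materializing any Q x K (or Q x Q) array at all.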