Sliding window support for GPU flash attention #962
Merged
This PR addresses the initial SRAM OOM issue in the current GPU flash attention, caused by loading a large bias-matrix block (along with the segment IDs). During the investigation we also bundled two optimizations into the change, following the splash attention optimizations on TPU:
Benchmarks show much better performance for seq_len > 16k with a sliding window of 4k. jax-cudnn performs poorly because it implements the window naively through bias_mask, while pallas supports neither sliding windows nor even naive masking.
There is a very minor performance penalty (<5%) for the regular causal case, since it requires one more memory lookup to retrieve the dynamic index. However, this consistent change improves code readability.
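To illustrate the idea behind the dynamic index lookup, here is a minimal JAX sketch (not the PR's actual kernel; block sizes and function names are assumptions) of how a causal sliding window can be resolved per block and masked on the fly instead of materializing a full bias matrix:

```python
# Illustrative sketch only: each query block derives the range of key blocks it
# must visit and an elementwise mask on the fly, rather than loading a
# [seq_len, seq_len] bias matrix block into SRAM.
import jax.numpy as jnp


def kv_block_bounds(q_block_idx, block_q, block_k, window):
  """Key-block range a query block must visit under causal + sliding window."""
  q_start = q_block_idx * block_q
  q_end = q_start + block_q - 1
  # Causal upper bound: no key position may exceed the block's last query position.
  hi = q_end // block_k
  # Sliding-window lower bound: keys older than `window` positions are skipped.
  lo = max(q_start - window + 1, 0) // block_k
  return lo, hi


def block_mask(q_block_idx, k_block_idx, block_q, block_k, window):
  """Elementwise mask for one (query block, key block) tile, computed on the fly."""
  q_ids = q_block_idx * block_q + jnp.arange(block_q)[:, None]
  k_ids = k_block_idx * block_k + jnp.arange(block_k)[None, :]
  causal = k_ids <= q_ids
  in_window = k_ids > q_ids - window
  return causal & in_window
```

In a kernel, bounds like these would drive which key blocks the grid iterates over for each query block; computing them dynamically is the extra index lookup mentioned above, and it is what lets a sliding window skip work instead of relying on a full bias mask.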
Benchmark result: