
Sliding window support for GPU flash attention #962

Merged
merged 8 commits into from
Feb 1, 2025

Conversation

kelvin-zou
Contributor

This PR addresses the SRAM OOM issue in the current GPU flash attention, caused by loading large bias-matrix blocks (along with segment IDs). Through investigation, we also bundle the following optimizations, mirroring the TPU splash attention kernel:

  1. Support an arbitrary `mask_fn` in the GPU flash attention kernel.
  2. Remove the special-cased causal flag, since causal is just one mask; this makes the kernel consistent with the TPU Pallas flash attention kernel.
  3. Optimize sliding-window attention by adding prebuilt index offsets, so that fully masked blocks can be skipped.
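To make the `mask_fn` idea concrete, here is a minimal NumPy sketch of a causal sliding-window mask. The function name and `(q_pos, k_pos)` signature are illustrative assumptions; the actual kernel API in the PR may differ.

```python
import numpy as np

def sliding_window_causal_mask(window_size):
    # Hypothetical mask_fn factory: a query at position q may attend to
    # keys in the trailing window (q - window_size, q], i.e. causal
    # attention restricted to the last `window_size` positions.
    def mask_fn(q_pos, k_pos):
        return (k_pos <= q_pos) & (k_pos > q_pos - window_size)
    return mask_fn

# Build the full boolean mask for a tiny sequence via broadcasting.
q = np.arange(8)[:, None]
k = np.arange(8)[None, :]
mask = sliding_window_causal_mask(4)(q, k)
```

Passing the mask as a predicate rather than a materialized bias matrix is what avoids loading a large bias block into SRAM.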

Benchmarks show much better performance for seq_len > 16k with a 4k sliding window. jax-cudnn performs poorly because it implements the window naively through a bias mask, while Pallas supports neither sliding windows nor arbitrary masking. There is a very minor performance penalty (<5%) for the regular causal case, since it requires one more memory lookup to retrieve the dynamic index; in exchange, the unified code path is more readable.

Benchmark result:

| setup | jax | axlearn | jax-cudnn | jax-pallas |
|---|---|---|---|---|
| num_heads=4,seq_len=8192,sw_sz=1024 | 2.555328 | 0.307232 | 1.613344 | 1.087104 |
| num_heads=4,seq_len=8192,sw_sz=4096 | 2.554784 | 0.937472 | 1.610048 | 1.092768 |
| num_heads=4,seq_len=16384,sw_sz=1024 | 14.380320 | 0.591552 | 4.430400 | 3.901536 |
| num_heads=4,seq_len=16384,sw_sz=4096 | 14.342656 | 1.857568 | 4.413408 | 3.879168 |
| num_heads=4,seq_len=32768,sw_sz=1024 | 44.018017 | 1.144128 | 16.732927 | 14.926560 |
| num_heads=4,seq_len=32768,sw_sz=4096 | 44.005024 | 3.651648 | 16.696735 | 14.951200 |

@kelvin-zou kelvin-zou marked this pull request as ready for review February 1, 2025 01:20
@kelvin-zou kelvin-zou requested review from ruomingp, markblee and a team as code owners February 1, 2025 01:20
@kelvin-zou kelvin-zou added this pull request to the merge queue Feb 1, 2025
Merged via the queue into apple:main with commit c1c6e29 Feb 1, 2025
6 checks passed
@kelvin-zou kelvin-zou deleted the sliding_window branch February 1, 2025 06:57