Hello, @tridao. I noticed that you've recently been working on extending the head dimension (headdim) beyond 256 on the Hopper architecture. Coincidentally, I'm making a similar attempt, but not on the Hopper architecture. I've extended the headdim supported by FlashAttention to 1024, and it can be even larger. My main improvement to FlashAttention is to lower the tiling level from the Attention level to the MMA level, that is, to achieve a fully GEMM tiling style, which achieves O(1)⚡️SRAM complexity. For the sake of concise expression, I've named this new Attention tiling technique FFPA: Yet another Faster Flash Prefill Attention with O(1)⚡️SRAM complexity for headdim > 256, 1.8x~3x↑🎉 faster vs SDPA EA.
As we know, in GEMM we don't run out of shared memory (smem) due to an overly large K dimension. However, in FlashAttention an overly large headdim can cause a shortage of smem, which shouldn't have to be the case. I've implemented some experimental kernels using MMA PTX, which have been open-sourced in ffpa-attn-mma; only the forward pass is supported so far.
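To make the smem pressure concrete, here is a back-of-the-envelope comparison (a rough model only; the block sizes Br = Bc = 128, dk = 16, fp16 tiles, and 2 pipeline stages are illustrative assumptions, not the actual kernel configuration):

```python
# Rough shared-memory budget: FA-2-style tiling keeps Q (Br x d), K and V
# (Bc x d) tiles resident, so smem grows with headdim d; GEMM-style
# fine-grained tiling only keeps Br x dk / Bc x dk slices per pipeline stage.
# Block sizes and stage count below are illustrative assumptions.
def fa2_style_smem_kb(d, Br=128, Bc=128, bytes_per_elem=2):
    return (Br * d + 2 * Bc * d) * bytes_per_elem / 1024

def ffpa_style_smem_kb(d, Br=128, Bc=128, dk=16, bytes_per_elem=2, stages=2):
    return stages * (Br * dk + 2 * Bc * dk) * bytes_per_elem / 1024

for d in (128, 256, 512, 1024):
    print(f"d={d:4d}: FA-2-style ~{fa2_style_smem_kb(d):4.0f} KB, "
          f"fine-grained ~{ffpa_style_smem_kb(d):3.0f} KB")
# With roughly 100-228 KB of smem per SM (architecture dependent), the
# FA-2-style working set stops fitting somewhere around d = 512-1024,
# while the fine-grained working set stays constant in d.
```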
Extending FlashAttention's support for large headdim is meaningful in the context of DeepSeek MLA. For example, once the headdim supported by FlashAttention exceeds 512, we can achieve a fully Fused MLA in the prefill stage. Of course, this involves a large number of modifications to FlashAttention, and I can't complete them all in my spare time. Therefore, I've opened this discussion to share my ideas.
As an experiment, I used MMA PTX to implement this fine-grained tiling and achieved good performance. I think similar things can also be done with CUTLASS. In my experiments, FFPA with headdim = 512 and MMA Acc F32 reaches 105 TFLOPS (105/119 ≈ 88% of peak) on the NVIDIA L20, and 149 TFLOPS on the NVIDIA 4090.
(Benchmark tables omitted. Legend: * = MMA Acc F32 or MMA Acc: QK F32 + PV F16, ^ = MMA Acc F16, T = TFLOPS; recorded speedups ~1.8x↑🎉 and ~1.9x↑🎉.)

I hope the idea of FFPA can help extend FlashAttention to support larger head dimensions (headdim), thus making it a more versatile algorithm. The following is an introduction and benchmark of FFPA. For more details, please refer to the ffpa-attn-mma repository.
🤖FFPA: Yet another Faster Flash Prefill Attention with O(1)⚡️GPU SRAM complexity for large headdim🐑
📚FFPA L1~L3 Design | 📈L20 ~1.9x↑🎉 | 📈A30 ~1.8x↑🎉 | 📈3080 ~2.9x↑🎉 | 📈4090 ~2.1x↑🎉
🤖FFPA: 1.8x~3x🎉 faster vs SDPA EA with or without MMA Acc F32
🤖[WIP] FFPA: Yet another Faster Flash Prefill Attention with O(1) SRAM complexity & O(d/4) or O(1) register complexity for large headdim (D > 256), almost 1.8x~3x 🎉 faster than SDPA EA with or without MMA Acc F32 on many devices: 📈L20 ~1.9x↑🎉, 📈A30 ~1.8x↑🎉, 📈3080 ~2.9x↑🎉, 📈4090 ~2.1x↑🎉. FFPA Attention Algo: Fine-grained tiling for large headdim; FA-2 Attention Algo: Coarse-grained tiling for small headdim.
💡NOTE: This project is still in its early development stage and currently provides some kernels and benchmarks for reference. More features will be added in the future.
📖 Contents
📖 FFPA L1~L3: FlashAttention + QKV Fine-grained Tiling at MMA level💡
We have extended FlashAttention for large headdim (D > 256) by implementing Fine-grained Tiling at the MMA level (GEMM style) for the Q@K^T and P@V matmul. This approach results in a constant SRAM usage of Br * 16 or Bc * 16 (Br = Bc) for Q, K, and V, leading to an overall SRAM complexity of O(2 * Br * 16) ≈ O(1) and a register complexity of O(d/4) or O(1). Consequently, this method allows us to extend headdim beyond 256 and achieve faster performance compared to SDPA with or without MMA Accumulation F32 (1.8x~3x 🎉 faster than SDPA EA).
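To make the tiling order concrete, here is a NumPy reference sketch of the FFPA-style loop structure. It is a numerical model only, with illustrative block sizes Br = Bc = 64 and dk = 16; the real kernels work on MMA fragments in registers/SMEM via MMA PTX, not NumPy arrays:

```python
# NumPy reference sketch of FFPA-style fine-grained (GEMM-style) tiling.
# Only the loop structure and online-softmax math are modeled; Br, Bc, dk
# are illustrative, not the kernel's actual template parameters.
import numpy as np

def ffpa_prefill_sketch(Q, K, V, Br=64, Bc=64, dk=16):
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d), dtype=np.float32)
    for i in range(0, N, Br):                       # one Q row-block at a time
        m = np.full(Br, -np.inf, dtype=np.float32)  # running row max
        l = np.zeros(Br, dtype=np.float32)          # running softmax denominator
        acc = np.zeros((Br, d), dtype=np.float32)   # output accumulator
        for j in range(0, N, Bc):                   # FA-2-style loop over KV blocks
            # Q@K^T tiled over headdim: only Br x dk and Bc x dk slices of
            # Q and K need to be on-chip at once, independent of d.
            s = np.zeros((Br, Bc), dtype=np.float32)
            for k0 in range(0, d, dk):
                s += Q[i:i+Br, k0:k0+dk] @ K[j:j+Bc, k0:k0+dk].T
            s *= scale
            # Online softmax rescaling, same as FlashAttention-2.
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            alpha = np.exp(m - m_new)
            l = alpha * l + p.sum(axis=1)
            acc *= alpha[:, None]
            # P@V tiled over headdim: only a Bc x dk slice of V is resident.
            for k0 in range(0, d, dk):
                acc[:, k0:k0+dk] += p @ V[j:j+Bc, k0:k0+dk]
            m = m_new
        O[i:i+Br] = acc / l[:, None]
    return O

# Sanity check against naive attention for a large headdim (D > 256).
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    N, d = 256, 512
    Q, K, V = (rng.standard_normal((N, d)).astype(np.float32) for _ in range(3))
    s = Q @ K.T / np.sqrt(d)
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    ref = (p / p.sum(axis=-1, keepdims=True)) @ V
    assert np.allclose(ffpa_prefill_sketch(Q, K, V), ref, atol=1e-3)
```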
We have named this new attention tiling technique FFPA: Faster Flash Prefill Attention. We have designed three (L1~L3) levels of FFPA based on SRAM and register complexity considerations. None of the levels introduces any additional VRAM requirements, so the HBM memory complexity remains the same as FlashAttention. 👇By leveraging this approach, we can achieve better performance than SDPA EA for large headdim (D > 256). An approximate SRAM and register complexity analysis for L1~L3 is as follows (d = headdim; C, Br, Bc = constants; Br = Bc). 👇 (Complexity table omitted; see the ffpa-attn-mma repository.)

📚Core Features🎉🎉: I have implemented FFPA L1~L3 using pure MMA PTX instructions, which supports many features such as Split-Q, SMEM Swizzle/Padding, QKV Multi-Stages(1~4), Tile MMAs/Warps, Mixed MMA F32/F16 Acc (Q@K^T MMA Acc F32 + P@V MMA Acc F16), Fully Shared QKV SMEM, Prefetch QKV g2s, Persist Q s2r/g2s, Fully QKV Fine-grained Tiling (GEMM style), Collective Store, etc.
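As a small numerical aside on the Mixed MMA F32/F16 Acc feature: because the softmax probabilities are bounded in [0, 1] and sum to 1 per row, accumulating P@V in F16 loses little accuracy, while the unbounded Q@K^T logits are kept in F32. The snippet below only emulates the accumulator dtype in NumPy (with made-up tile sizes); it is not the MMA PTX code path:

```python
# Emulate F16 vs F32 accumulation of a P@V tile to see how small the error is.
import numpy as np

rng = np.random.default_rng(0)
Br, Bc, d = 64, 64, 512
logits = rng.standard_normal((Br, Bc)).astype(np.float32)   # from Q@K^T, F32 acc
p = np.exp(logits - logits.max(axis=1, keepdims=True))
p = (p / p.sum(axis=1, keepdims=True)).astype(np.float16)   # P tile in fp16
v = rng.standard_normal((Bc, d)).astype(np.float16)         # V tile in fp16

acc_f32 = np.zeros((Br, d), dtype=np.float32)                # F32 accumulator
acc_f16 = np.zeros((Br, d), dtype=np.float16)                # F16 accumulator
for k in range(Bc):                                          # accumulate rank-1 updates
    update = np.outer(p[:, k], v[k])                         # fp16 x fp16 outer product
    acc_f32 += update.astype(np.float32)
    acc_f16 += update                                        # accumulated in fp16

err = np.abs(acc_f16.astype(np.float32) - acc_f32).max()
print(f"max |F16-acc - F32-acc| over a {Br}x{d} P@V tile: {err:.4f}")
```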
L1 kernel template signature: ffpa_attn_templates_L1.cuh

📖 Prerequisites
📖 Installation
The FFPA implemented in this repo can be installed as a python library, namely, the ffpa-attn library (optional).
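Once installed, usage might look roughly like the sketch below. The module name ffpa_attn, the ffpa(...) entry point, and the [B, H, N, D] tensor layout are assumptions for illustration only; check the ffpa-attn-mma repository for the actual installation steps and API:

```python
# Hypothetical usage sketch -- the import name, function name, and tensor
# layout are assumed for illustration; see the ffpa-attn-mma repo for the real API.
import torch
import ffpa_attn  # assumed module name of the optional python library

B, H, N, D = 1, 48, 8192, 512            # large headdim (D > 256), prefill shapes
q = torch.randn(B, H, N, D, dtype=torch.half, device="cuda")
k = torch.randn(B, H, N, D, dtype=torch.half, device="cuda")
v = torch.randn(B, H, N, D, dtype=torch.half, device="cuda")

# Forward-only prefill attention; the backward pass is not supported yet.
o = ffpa_attn.ffpa(q, k, v)
```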
📖 FFPA L1 (Level 1): Benchmark 🎉🎉
L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, O(d/4) register complexity, the same GPU HBM memory complexity as FlashAttention. B=1, H=48, N=8192, D=320-1024 (FA2 not supported 👀). (Notes: * = MMA Acc F32, ^ = MMA Acc F16, Softmax Acc dtype is always F32, T = TFLOPS. 👇Benchmark)

(Benchmark tables omitted. Each table's legend: * = MMA Acc F32 or MMA Acc: QK F32 + PV F16, ^ = MMA Acc F16, T = TFLOPS; the recorded speedups range from ~1.8x↑🎉 to ~2.9x↑🎉.)

📖 Python Testing
👇You can test many custom FFPA kernels via Python and compare their performance. The --gen-bench and --plot options help you generate a benchmark table in Markdown style and speedup bar plots on your device. Contributions of your benchmark tables and plots are welcome via a PR 🎉🎉.
📖 References