FlashAttention-2 release
tridao committed Jul 17, 2023
1 parent 6d48e14 commit 4f285b3
Showing 90 changed files with 6,867 additions and 10,641 deletions.
4 changes: 2 additions & 2 deletions .gitmodules
@@ -1,3 +1,3 @@
-[submodule "csrc/flash_attn/cutlass"]
-    path = csrc/flash_attn/cutlass
+[submodule "csrc/cutlass"]
+    path = csrc/cutlass
     url = https://github.com/NVIDIA/cutlass.git
3 changes: 1 addition & 2 deletions AUTHORS
@@ -1,2 +1 @@
-Tri Dao, [email protected]
-Dan Fu, [email protected]
+Tri Dao, [email protected]
2 changes: 2 additions & 0 deletions MANIFEST.in
@@ -2,8 +2,10 @@ recursive-include csrc *.cu
 recursive-include csrc *.h
 recursive-include csrc *.cuh
 recursive-include csrc *.cpp
+recursive-include csrc *.hpp

 recursive-include flash_attn *.cu
 recursive-include flash_attn *.h
 recursive-include flash_attn *.cuh
 recursive-include flash_attn *.cpp
+recursive-include flash_attn *.hpp
270 changes: 119 additions & 151 deletions README.md
@@ -1,46 +1,30 @@
# FlashAttention
This repository provides the official implementation of FlashAttention from the
following paper.
This repository provides the official implementation of FlashAttention and
FlashAttention-2 from the
following papers.

**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness**
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
Paper: https://arxiv.org/abs/2205.14135
IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention.
![FlashAttention](assets/flashattn_banner.jpg)

## Usage
**FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning**
Tri Dao

We've been very happy to see FlashAttention being widely adopted in such a short
time after its release. This [page](https://github.com/HazyResearch/flash-attention/blob/main/usage.md)
contains a partial list of places where FlashAttention is being used.
Paper: https://tridao.me/publications/flash2/flash2.pdf

## Full model code and training script
![FlashAttention-2](assets/flashattention_logo.png)

We have released the full GPT model
[implementation](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py).
We also provide optimized implementations of other layers (e.g., MLP, LayerNorm,
cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x
compared to the baseline implementation from Huggingface, reaching up to 189
TFLOPs/sec per A100, equivalent to 60.6\% model FLOPs utilization (we don't need
any activation checkpointing).

We also include a training
[script](https://github.com/HazyResearch/flash-attention/tree/main/training) to
train GPT2 on Openwebtext and GPT3 on The Pile.

## Triton implementation of FlashAttention

Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton:
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py

As Triton is a higher-level language than CUDA, it might be easier to understand
and experiment with. The notations in the Triton implementation are also closer
to what's used in our paper.
## Usage

We also have an experimental implementation in Triton that supports attention
bias (e.g. ALiBi):
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py
We've been very happy to see FlashAttention being widely adopted in such a short
time after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md)
contains a partial list of places where FlashAttention is being used.

FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE).
Please cite and credit FlashAttention if you use it.

## Installation and features

@@ -53,125 +37,116 @@ We recommend the
container from Nvidia, which has all the required tools to install FlashAttention.

To install:
1. Make sure that PyTorch is installed.
2. Make sure that `packaging` is installed (`pip install packaging`)
3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja
--version` then `echo $?` should return exit code 0). If not (sometimes `ninja
--version` then `echo $?` returns a nonzero exit code), uninstall then reinstall
`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`
compiling can take a very long time (2h) since it does not use multiple CPU
cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine.
4. Then:
```sh
pip install flash-attn
pip install flash-attn --no-build-isolation
```

Alternatively you can compile from source:
```
python setup.py install
```
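If the build succeeds, a quick smoke test along these lines can confirm that the extension imports and runs. This is only a sketch, not part of the repository; it assumes a CUDA-capable GPU and the `flash_attn_func` interface described further below.
```python
# Minimal post-install smoke test (illustrative only).
import torch
from flash_attn import flash_attn_func

# Random fp16 inputs with shape (batch, seqlen, nheads, headdim).
q, k, v = (torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16) for _ in range(3))
out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # expected: torch.Size([2, 1024, 8, 64])
```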

Interface: `src/flash_attention.py`

To run the benchmark against PyTorch standard attention:
```
PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
```
Interface: `src/flash_attention_interface.py`

FlashAttention currently supports:
1. Turing, Ampere, Ada, or Hopper GPUs (e.g., H100, A100, RTX 3090, T4, RTX 2080).
2. fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ...,
128). Head dim > 64 backward requires A100 or H100.

Our tentative roadmap:
1. ~~[Jun 2022] Make package pip-installable~~[Done, thanks to lucidrains].
2. ~~[Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090)~~[Done].
3. ~~[Jun 2022] Support SM75 GPUs (e.g. T4)~~[Done].
4. ~~[Jun 2022] Support bf16~~[Done].
5. ~~[Jul 2022] Implement cross-attention~~[Done].
6. ~~[Jul 2022] Support head dimension 128~~[Done].
7. ~~[Aug 2022] Fuse rotary embedding~~[Done].
8. ~~[Mar 2023] Support SM90 GPUs (H100)~~[Done].
FlashAttention-2 currently supports:
1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
GPUs (T4, RTX 2080) is coming soon; please use FlashAttention 1.x for Turing
GPUs for now.
2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800.


## How to use FlashAttention

Here's a simple example:
```python
import torch
from flash_attn.flash_attention import FlashMHA

# Replace this with your correct GPU device
device = "cuda:0"

# Create attention layer. This is similar to torch.nn.MultiheadAttention,
# and it includes the input and output linear layers
flash_mha = FlashMHA(
    embed_dim=128,  # total channels (= num_heads * head_dim)
    num_heads=8,    # number of heads
    device=device,
    dtype=torch.float16,
)

# Run forward pass with dummy data
x = torch.randn(
    (64, 256, 128),  # (batch, seqlen, embed_dim)
    device=device,
    dtype=torch.float16
)

output = flash_mha(x)[0]
The main functions implement scaled dot product attention (softmax(Q @ K^T *
softmax_scale) @ V):
```
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
```

Alternatively, you can import the inner attention layer only (so that the input
and output linear layers are not included):
```python
from flash_attn.flash_attention import FlashAttention
```
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
Arguments:
    qkv: (batch_size, seqlen, 3, nheads, headdim)
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
        Default to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
    out: (batch_size, seqlen, nheads, headdim).
"""
```
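For reference, here is a minimal usage sketch of `flash_attn_qkvpacked_func` following the signature above; it is not taken from the repository, and the shapes and dtype are illustrative.
```python
# Illustrative call to flash_attn_qkvpacked_func with a packed QKV tensor.
import torch
from flash_attn import flash_attn_qkvpacked_func

batch_size, seqlen, nheads, headdim = 4, 2048, 16, 64
# Q, K, V stacked along dim 2: (batch_size, seqlen, 3, nheads, headdim).
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim,
                  device="cuda", dtype=torch.bfloat16, requires_grad=True)

out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
out.sum().backward()  # gradients accumulate directly in the packed qkv tensor
print(out.shape, qkv.grad.shape)
```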

# Create the nn.Module
flash_attention = FlashAttention()
```
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
Arguments:
    q: (batch_size, seqlen, nheads, headdim)
    k: (batch_size, seqlen, nheads_k, headdim)
    v: (batch_size, seqlen, nheads_k, headdim)
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
        Default to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
Return:
    out: (batch_size, seqlen, nheads, headdim).
"""
```
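As an illustration of the MQA/GQA support described above, here is a sketch (not from the repository) in which 6 query heads share 2 KV heads:
```python
# Grouped-query attention sketch: nheads_q = 6, nheads_k = 2 (6 is a multiple of 2).
import torch
from flash_attn import flash_attn_func

batch_size, seqlen, headdim = 2, 1024, 128
q = torch.randn(batch_size, seqlen, 6, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch_size, seqlen, 2, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(batch_size, seqlen, 2, headdim, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # (batch_size, seqlen, 6, headdim): one output per query head
```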

Or, if you need more fine-grained control, you can import one of the lower-level
functions (this is more similar to the `torch.nn.functional` style):
```python
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
To see how these functions are used in a multi-head attention layer (which
includes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py).

# or
## Upgrading from FlashAttention (1.x) to FlashAttention-2

from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func
These functions have been renamed:
- `flash_attn_unpadded_func` -> `flash_attn_varlen_func`
- `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func`
- `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func`

# etc.
If the inputs have the same sequence lengths in the same batch, it is simpler
and faster to use these functions:
```
flash_attn_qkvpacked_func(qkv, dropout_p, softmax_scale=None, causal=False)
```

There are also separate Python files with various FlashAttention extensions:
```python
# Import the triton implementation (torch.nn.functional version only)
from flash_attn.flash_attn_triton import flash_attn_func

# Import block sparse attention (nn.Module version)
from flash_attn.flash_blocksparse_attention import FlashBlocksparseMHA, FlashBlocksparseAttention

# Import block sparse attention (torch.nn.functional version)
from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
```
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
```
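For the variable-length case, here is a migration sketch for the renames listed above; the exact `flash_attn_varlen_func` signature is assumed to mirror the 1.x unpadded interface (packed tokens plus cumulative sequence lengths), so treat the argument order as an assumption rather than a reference.
```python
# Hypothetical varlen call after renaming flash_attn_unpadded_func -> flash_attn_varlen_func.
import torch
from flash_attn import flash_attn_varlen_func

nheads, headdim = 8, 64
seqlens = [500, 1000]  # two sequences of different lengths packed together
total = sum(seqlens)
q = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(total, nheads, headdim, device="cuda", dtype=torch.float16)
# Cumulative sequence lengths (int32), one entry per sequence boundary.
cu_seqlens = torch.tensor([0, 500, 1500], device="cuda", dtype=torch.int32)

out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens,
                             max(seqlens), max(seqlens),
                             dropout_p=0.0, causal=True)
print(out.shape)  # (total_tokens, nheads, headdim)
```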

## Speedup and Memory Savings
## Performance

We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).

We currently have benchmarks for these GPUs:
* [A100](#a100)
* [RTX 3090](#rtx-3090)
* [T4](#t4)
* [H100](#h100)
<!-- * [RTX 3090](#rtx-3090) -->
<!-- * [T4](#t4) -->

### A100

We display FlashAttention speedup using these parameters (similar to BERT-base):
* Batch size 8
* Head dimension 64
* 12 attention heads

Our graphs show sequence lengths between 128 and 4096 (when standard attention runs out of memory on an A100), but FlashAttention can scale up to sequence length 64K.
We display FlashAttention speedup using these parameters:
* Head dimension 64 or 128, hidden dimension 2048 (i.e. either 32 or 16 heads).
* Sequence length 512, 1k, 2k, 4k, 8k, 16k.
* Batch size set to 16k / seqlen.

#### Speedup

![FlashAttention speedup](assets/flashattn_speedup.jpg)

We generally see 2-4X speedup at sequence lengths between 128 and 4K, and we see more speedup when using dropout and masking, since we fuse the kernels.
At sequence lengths that are popular with language models like 512 and 1K, we see speedups up to 4X when using dropout and masking.
![FlashAttention speedup on A100 80GB SXM5 with FP16/BF16](assets/flash2_a100_fwd_bwd_benchmark.png)

#### Memory

@@ -182,38 +157,37 @@ Memory savings are proportional to sequence length -- since standard attention h
We see 10X memory savings at sequence length 2K, and 20X at 4K.
As a result, FlashAttention can scale to much longer sequence lengths.

#### Head Dimension 128

![FlashAttention speedup, head dimension 128](assets/flashattn_speedup_a100_d128.jpg)

We show speedup with head dimension 128.
Here we show batch size 16 with 12 heads.
Speedup is less than with the smaller head sizes, since we have to make the block size smaller in the tiling.
But speedup is still significant, especially with a causal mask.

### RTX 3090
### H100

For the RTX 3090, we use batch size 12 with 12 attention heads.
Memory savings are the same as on an A100, so we'll only show speedup here.
![FlashAttention speedup on H100 SXM5 with FP16/BF16](assets/flash2_h100_fwd_bwd_benchmark.png)

![FlashAttention speedup GTX 3090](assets/flashattn_speedup_3090.jpg)

We see slightly higher speedups (between 2.5-4.5x) on the GTX 3090, since memory bandwidth on the GDDR6X is lower than A100 HBM (~900 GB/s vs. ~1.5 TB/s).
## Full model code and training script

### T4
We have released the full GPT model
[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/gpt.py).
We also provide optimized implementations of other layers (e.g., MLP, LayerNorm,
cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x
compared to the baseline implementation from Huggingface, reaching up to 225
TFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need
any activation checkpointing).

We again use batch size 12 with 12 attention heads.
We also include a training
[script](https://github.com/Dao-AILab/flash-attention/tree/main/training) to
train GPT2 on Openwebtext and GPT3 on The Pile.

![Flashattention speedup T4](assets/flashattn_speedup_t4.jpg)
## Triton implementation of FlashAttention

T4 SRAM is smaller than the newer GPUs (64 KB), so we see less speedup (we need to make the block sizes smaller, so we end up doing more R/W).
This matches the IO complexity analysis from section 3.2 of [our paper](https://arxiv.org/abs/2205.14135).
Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton:
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py

T4 GPUs are commonly used for inference, so we also measure speedup on the forward pass only (note that these are not directly comparable to the graphs above):
As Triton is a higher-level language than CUDA, it might be easier to understand
and experiment with. The notations in the Triton implementation are also closer
to what's used in our paper.

![FlashAttention speedup T4 fwd](assets/flashattn_speedup_t4_fwd.jpg)
We also have an experimental implementation in Triton that supports attention
bias (e.g. ALiBi):
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py

We see speedups between 2.5x-4.5x on the forward pass.

## Tests
We test that FlashAttention produces the same output and gradient as a reference
@@ -228,21 +202,10 @@ pytest -q -s tests/test_flash_attn.py
```
## When you encounter issues

This alpha release of FlashAttention contains code written for a research
project to validate ideas on speeding up attention.
We have tested it on several models (BERT, GPT2, ViT).
However, there might still be bugs in the implementation that we hope to iron
out in the next few months.
This new release of FlashAttention-2 has been tested on several GPT-style
models, mostly on A100 GPUs.

If you encounter any of these bugs, please open a respective GitHub Issue!

## Acknowledgments
Our implementation uses Apex's
[FMHA](https://github.com/NVIDIA/apex/tree/master/apex/contrib/csrc/fmha) code
as a starting point.

We thank [Young-Jun Ko](https://yjk21.github.io/) for the in-depth explanation of his FMHA implementation
and for his thoughtful answers to our questions about CUDA.
If you encounter any bugs, please open a GitHub Issue!

## Citation
If you use this codebase, or otherwise found our work valuable, please cite:
@@ -253,4 +216,9 @@ If you use this codebase, or otherwise found our work valuable, please cite:
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}
@article{dao2023flashattention2,
title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
author={Dao, Tri},
year={2023}
}
```
Binary file added assets/flash2_a100_fwd_bwd_benchmark.png
Binary file added assets/flash2_h100_fwd_bwd_benchmark.png
Binary file added assets/flashattention_logo.png
