update readme with fav3 and no fa usage (#109)
feifeibear authored Dec 11, 2024
1 parent 26d3abd commit 6b532f3
Showing 2 changed files with 42 additions and 25 deletions.
66 changes: 41 additions & 25 deletions README.md
@@ -34,7 +34,44 @@ Furthermore, Ring-Attention utilizes asynchronous peer-to-peer communication, wh
<img src="./media/usp.png">
</p>

### 1. Installation

FlashAttention is the most important external dependency and is often the cause of installation and runtime errors with yunchang. Yunchang supports flash_attn 2.6.x and 2.7.x, covering both the v2 and v3 kernels. Additionally, yunchang can use torch's SDPA (scaled dot-product attention) for sequence parallelism, so flash_attn does not have to be installed at all.

As shown in the figure below, there are three ways to use yunchang, depending on which flash_attn version (if any) is available; a short selection sketch follows the figure:

1. On hardware that supports FA v3 (e.g., H100, B100), ring_flash_attn uses FA v3.

2. On hardware that supports FA v2 (e.g., A100, L40), ring_flash_attn uses FA v2.

3. On hardware that does not support FA at all, such as NPUs, use torch's SDPA. In this case there is no need to install `flash_attn`; instead, apply `UlyssesAttention(sp_pg, attn_type=FlashAttentionImpl.TORCH)`.

<p align="center">
<img src="./media/usp_fa.png">
</p>
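
One possible way to pick `attn_type` at runtime is sketched below. This is not yunchang's own dispatch logic; the import path of `FlashAttentionImpl` and the module name `flash_attn_interface` for the FA v3 build are assumptions you may need to adjust to your installed versions.

```python
# A minimal sketch for choosing an attention backend; adjust imports to your install.
from yunchang.kernels import FlashAttentionImpl  # import path assumed


def pick_attn_type():
    try:
        import flash_attn_interface  # noqa: F401  # module name assumed for the FA v3 beta build (Hopper)
        return FlashAttentionImpl.FA3
    except ImportError:
        pass
    try:
        import flash_attn  # noqa: F401  # FA v2 (e.g., A100, L40)
        return FlashAttentionImpl.FA
    except ImportError:
        # No flash_attn available (e.g., NPUs): fall back to torch's SDPA.
        return FlashAttentionImpl.TORCH
```

The resulting value can then be passed as `attn_type` to `LongContextAttention(...)` or `UlyssesAttention(sp_pg, ...)` as shown in the usage section below.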

Option 1: pip install from PyPI.

`pip install flash-attn`

`pip install yunchang`

#### Apply FlashAttention V3

Since FA V3 is a beta release, you need to install FlashAttention V3 from source.

Follow the [FlashAttention beta-release](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#flashattention-3-beta-release) to install V3 for NVIDIA Hopper GPUs.

We used the Nov 10, 2024 commit `b443207c1fc4c98e4532aad4e88cfee1d590d996`.


Option 2: build and install from a local clone of the repository.

`pip install .`

For AMD GPUs, see [install_amd.md](./docs/install_amd.md).


### 2. Usage

Please refer to [test/test_hybrid_qkvpacked_attn.py](./test/test_hybrid_qkvpacked_attn.py) and [test/test_hybrid_attn.py](./test/test_hybrid_attn.py) for usage.

@@ -62,13 +99,15 @@ set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)
# attn_type could be FA, FA3, TORCH.
longctx_attn = LongContextAttention(ring_impl_type="zigzag", attn_type=FlashAttentionImpl.FA)

# If you use Ulysses alone on hardware where flash_attn is not supported, use the following instead:
# UlyssesAttention(sp_pg, attn_type=FlashAttentionImpl.TORCH)

# extract a local shard for the global Q, K, V.
local_q = EXTRACT_FUNC_DICT["zigzag"](
    Q, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
).detach().clone()
...


local_out = usp_attn(
    local_q,
    local_k,
    ...
)
```
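
For orientation, here is a condensed, self-contained sketch of the same flow. The import paths, the global `(batch, seqlen, heads, head_dim)` layout of Q/K/V, and the flash_attn-style keyword arguments of the forward call are assumptions; the test files referenced above remain the authoritative examples.

```python
# Condensed end-to-end sketch (launch with torchrun --nproc_per_node=<world_size>).
import torch
import torch.distributed as dist
from yunchang import LongContextAttention, set_seq_parallel_pg, EXTRACT_FUNC_DICT  # paths assumed
from yunchang.kernels import FlashAttentionImpl

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

sp_ulysses_degree = 2                 # hypothetical degrees; their product
sp_ring_degree = world_size // 2      # must equal world_size
set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)

usp_attn = LongContextAttention(ring_impl_type="zigzag",
                                attn_type=FlashAttentionImpl.FA)

# Global tensors, identical on every rank (same seed); layout assumed (batch, seqlen, heads, head_dim).
torch.manual_seed(0)
B, S, H, D = 2, 8192, 16, 128
Q, K, V = (torch.randn(B, S, H, D, dtype=torch.bfloat16, device="cuda") for _ in range(3))

# Each rank keeps only its zigzag shard along the sequence dimension.
local_q, local_k, local_v = (
    EXTRACT_FUNC_DICT["zigzag"](
        t, rank, world_size=world_size, rd=sp_ring_degree, ud=sp_ulysses_degree
    ).detach().clone()
    for t in (Q, K, V)
)

# Keyword arguments assumed to mirror flash_attn's interface.
local_out = usp_attn(local_q, local_k, local_v, causal=True)
print(rank, local_out.shape)
```

With `attn_type=FlashAttentionImpl.TORCH`, the same flow works without flash_attn installed.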


### 3. Test

1 change: 1 addition & 0 deletions yunchang/hybrid/attn_layer.py
@@ -173,6 +173,7 @@ def __init__(
        self.use_sync = use_sync
        self.ring_attn_fn = RING_IMPL_QKVPACKED_DICT[ring_impl_type]
        self.attn_type = attn_type

    def forward(
        self,
        qkv,
