Update README.md
ardagoreci authored Sep 16, 2024
1 parent eafeb26 commit 621531e
Showing 1 changed file with 1 addition and 1 deletion.
README.md (2 changes: 1 addition & 1 deletion)
@@ -58,7 +58,7 @@ Despite these optimisations, our profiling experiments show that over 60% of the

[@alexzhang13](https://github.com/alexzhang13) implemented a custom MSA pair weighted averaging kernel in Triton that is fast and memory-efficient.

- - We had observed that one of the key memory bottlnecks is the MSA pair weighted averaging operation. The AlphaFold3 paper states that this operation replaces the MSARowAttentionWithPairBias operation with a "cheaper" pair weighted averaging, but implementing this naively in PyTorch results in 4x increase in memory usage compared to the Deepspeed4Science MSARowAttentionWithPairBias kernel. We hypothesized that this was due to the memory efficiency gains from the tiling and recomputation tricks in FlashAttention, which is also incorporated into the Deepspeed4Science MSARowAttentionWithPairBias kernel.
+ - We had observed that one of the key memory bottlenecks is the MSA pair weighted averaging operation. The AlphaFold3 paper states that this operation replaces the MSARowAttentionWithPairBias operation with a "cheaper" pair weighted averaging, but implementing this naively in PyTorch results in 4x increase in memory usage compared to the Deepspeed4Science MSARowAttentionWithPairBias kernel. We hypothesized that this was due to the memory efficiency gains from the tiling and recomputation tricks in FlashAttention, which is also incorporated into the Deepspeed4Science MSARowAttentionWithPairBias kernel.
- A naive implementation of the pair weighted averaging allocates a (*, N_seq, N_res, N_res, heads, c_hidden) intermediate tensor, which is too large to fit in GPU memory for even moderately long sequences.
- Alex's kernel allows scaling the network to thousands of tokens on a single GPU!
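
For context on the memory bottleneck the diff above describes, here is a minimal PyTorch sketch of MSA pair weighted averaging (Algorithm 10 in the AlphaFold3 paper). The function names and shapes are illustrative assumptions, not the repository's actual API: the naive broadcast materializes the (N_seq, N_res, N_res, heads, c_hidden) intermediate mentioned in the bullets, while the einsum variant contracts over the residue axis without ever allocating it.

```python
# Minimal sketch, assuming the Algorithm 10 shapes; not the repo's real code.
import torch

def msa_pair_weighted_averaging_naive(v, b, g):
    """
    v: value projection of the MSA rep,    (N_seq, N_res, heads, c_hidden)
    b: pair-bias logits from the pair rep, (N_res, N_res, heads)
    g: sigmoid gate from the MSA rep,      (N_seq, N_res, heads, c_hidden)
    """
    # Per-head weights over the second residue axis j.
    w = torch.softmax(b, dim=1)  # (N_res, N_res, heads)

    # Naive broadcast: materializes a (N_seq, N_res, N_res, heads, c_hidden)
    # intermediate before the sum over j -- the tensor that is too large to
    # fit in GPU memory for even moderately long sequences.
    o = (w[None, :, :, :, None] * v[:, None, :, :, :]).sum(dim=2)
    return g * o  # (N_seq, N_res, heads, c_hidden)

def msa_pair_weighted_averaging_einsum(v, b, g):
    # Same math, but einsum contracts over j without materializing the
    # rank-5 intermediate.
    w = torch.softmax(b, dim=1)
    return g * torch.einsum('ijh,sjhc->sihc', w, v)

if __name__ == "__main__":
    s, n, h, c = 8, 16, 4, 8
    v = torch.randn(s, n, h, c)
    b = torch.randn(n, n, h)
    g = torch.sigmoid(torch.randn(s, n, h, c))
    assert torch.allclose(
        msa_pair_weighted_averaging_naive(v, b, g),
        msa_pair_weighted_averaging_einsum(v, b, g),
        atol=1e-5,
    )
```

A fused Triton kernel such as the one described above can go further than the einsum form, tiling the computation and recomputing intermediates in the style of FlashAttention so that even per-tile weight blocks stay in fast on-chip memory.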

