
feat: group-query-attention implementation #74

Merged: 24 commits into main from GQA_2 on Mar 20, 2024

Conversation

@flxst (Member) commented Mar 12, 2024

Re-opened version of #41.

Potential solution for handling the combination of GQA and FlashAttention: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
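
For illustration only (not necessarily the approach taken in this PR): with scaled_dot_product_attention, GQA can be emulated by repeating the key/value heads so that each group of query heads shares one K/V head, after which SDPA can dispatch to a FlashAttention kernel when the inputs allow it. The shapes and the helper name below are made up for the example.

```python
import torch
import torch.nn.functional as F

def gqa_sdpa(q, k, v, is_causal=True):
    """Sketch of grouped-query attention on top of SDPA.

    q: (B, n_q_heads, T, d); k, v: (B, n_kv_heads, T, d), with n_q_heads % n_kv_heads == 0.
    """
    n_rep = q.shape[1] // k.shape[1]
    # Expand each K/V head n_rep times so the head count matches the queries.
    k = k.repeat_interleave(n_rep, dim=1)
    v = v.repeat_interleave(n_rep, dim=1)
    # PyTorch picks a fused (FlashAttention) kernel here when shapes/dtypes permit.
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)

# Toy usage: 8 query heads sharing 2 key/value heads (group size 4).
B, T, d = 2, 16, 64
q = torch.randn(B, 8, T, d)
k = torch.randn(B, 2, T, d)
v = torch.randn(B, 2, T, d)
out = gqa_sdpa(q, k, v)  # (B, 8, T, d)
```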

@flxst added the enhancement label on Mar 12, 2024
@fromm-m previously approved these changes on Mar 12, 2024
@fromm-m dismissed their stale review on March 12, 2024 14:01

wrong PR

@le1nux force-pushed the main branch 3 times, most recently from cb6e816 to 179052b on March 13, 2024 22:14
@lhahn-iis (Contributor) commented

Short update after our changes from today:

After our decision today to try out https://github.com/Dao-AILab/flash-attention/tree/main instead of our own (Group Query) Attention implementation, we provided a first draft.
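
As a point of reference, a minimal sketch of how flash-attn exposes grouped-query attention: flash_attn_func accepts K/V tensors with fewer heads than Q and handles the grouping internally. This is illustrative only, not the draft in this PR; it assumes a CUDA device and fp16/bf16 tensors, with the layout (batch, seqlen, heads, head_dim).

```python
import torch
from flash_attn import flash_attn_func

B, T, d = 2, 128, 64
q = torch.randn(B, T, 8, d, device="cuda", dtype=torch.bfloat16)  # 8 query heads
k = torch.randn(B, T, 2, d, device="cuda", dtype=torch.bfloat16)  # 2 key/value heads
v = torch.randn(B, T, 2, d, device="cuda", dtype=torch.bfloat16)

# GQA: the 2 K/V heads are shared across the 8 query heads by the kernel itself.
out = flash_attn_func(q, k, v, causal=True)  # (B, T, 8, d)
```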

Now two things remain to be done:

  • Benchmark the new implementation against the previous one. For this, @fromm-m agreed to launch some test setups on Leonardo and check the throughput.
  • Add a remark about the installation of flash-attn. The issue is that flash-attn has some prerequisites: at least CUDA 11.6 must be installed, and on top of that some build dependencies are needed. The authors mention here that a non-isolated build environment should be used (probably so the compilation step can access the installed CUDA version). Unfortunately, this is not trivial to express in our own pyproject.toml. My suggestion would be either to check whether this is achievable with a clever trick (e.g. using this, though I have real doubts that it works), or alternatively to add a corresponding remark in the README.md.

I won't be able to have a look at this until Thursday this week.

@fromm-m (Member) commented Mar 19, 2024

(Quoting @lhahn-iis's update above.)

Regarding your first point:

I did a benchmark, and the new implementation is even faster than the previous PyTorch flash attention implementation at the 3B parameter scale.
(Screenshot: benchmark results, 2024-03-19 14:32)

Regarding your second point:
We opened a new issue #86 to refactor the README, where we will also describe the installation of FlashAttention.

@le1nux self-requested a review on March 20, 2024 13:03
@fromm-m requested a review from mali-git on March 20, 2024 13:15
@mali-git merged commit adb1f28 into main on Mar 20, 2024
@fromm-m deleted the GQA_2 branch on June 17, 2024 12:22