Skip to content

Commit

Permalink
[ROCm][Hardware][AMD] Enable group query attention for triton FA (#4406)
Browse files Browse the repository at this point in the history
  • Loading branch information
hongxiayang authored Apr 27, 2024
1 parent 87f545b commit 18d23f6
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 41 deletions.
53 changes: 25 additions & 28 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,36 +253,31 @@ def forward(
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
if self.use_triton_flash_attn or self.use_naive_attn:
if self.use_triton_flash_attn:
out, _ = self.attn_func(
query,
key,
value,
None,
prefill_meta.seq_start_loc,
prefill_meta.seq_start_loc,
prefill_meta.max_prompt_len,
prefill_meta.max_prompt_len,
True,
self.scale,
)
elif self.use_naive_attn:
if self.num_kv_heads != self.num_heads:
# Interleave for MQA workaround.
key = self.repeat_kv(key, self.num_queries_per_kv)
value = self.repeat_kv(value, self.num_queries_per_kv)
if self.use_naive_attn:
out = self.attn_func(
query,
key,
value,
prefill_meta.prompt_lens,
self.scale,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else:
out, _ = self.attn_func(
query,
key,
value,
None,
prefill_meta.seq_start_loc,
prefill_meta.seq_start_loc,
prefill_meta.max_prompt_len,
prefill_meta.max_prompt_len,
True,
self.scale,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
out = self.attn_func(
query,
key,
value,
prefill_meta.prompt_lens,
self.scale,
)
else:
out = self.attn_func(
q=query,
Expand All @@ -295,8 +290,10 @@ def forward(
softmax_scale=self.scale,
causal=True,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out

# common code for prefill
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
Expand Down
24 changes: 11 additions & 13 deletions vllm/attention/ops/triton_flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def _attn_fwd_inner(
num_warps=4,
),
],
key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
)
@triton.jit
def attn_fwd(
Expand Down Expand Up @@ -330,8 +330,8 @@ def attn_fwd(
philox_seed,
philox_offset_base,
encoded_softmax,
hq,
hk,
HQ: tl.constexpr,
HK: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr,
Expand Down Expand Up @@ -403,7 +403,7 @@ def attn_fwd(
# We still need to write 0s to the result
# tl.store(O_block_ptr,
# acc.to(Out.type.element_ty), boundary_check=(0,1))
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# + offs_m
# We store inf to LSE, not -inf because in the bwd pass,
# we subtract this
Expand All @@ -414,11 +414,9 @@ def attn_fwd(
# TODO: Should dropout and return encoded softmax be handled here?
return

is_mqa = hq != hk
if is_mqa: # noqa: SIM108
off_h_k = off_h_q % hk
else:
off_h_k = off_h_q
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE: tl.constexpr = HQ // HK
off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q

n_extra_tokens = 0
if seqlen_k < BLOCK_N:
Expand Down Expand Up @@ -471,7 +469,7 @@ def attn_fwd(
bias_ptr = None
if ENABLE_DROPOUT:
batch_philox_offset = philox_offset_base \
+ (off_z * hq + off_h_q) \
+ (off_z * HQ + off_h_q) \
* seqlen_q * seqlen_k
else:
batch_philox_offset = 0
Expand Down Expand Up @@ -624,7 +622,7 @@ def attn_fwd(
z = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
# few rows. This is only true for the last M block. For others,
# overflow_size will be -ve
Expand Down Expand Up @@ -784,8 +782,8 @@ def forward(
philox_seed=philox_seed,
philox_offset_base=philox_offset,
encoded_softmax=encoded_softmax,
hq=nheads_q,
hk=nheads_k,
HQ=nheads_q,
HK=nheads_k,
ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k,
Expand Down

0 comments on commit 18d23f6

Please sign in to comment.