vLLM-Ext: Resolved ALiBI bias regression
- Works in lazy and eager mode

Co-authored-by: Tanner Voas <[email protected]>
Co-authored-by: Haihao Xiang <[email protected]>
Signed-off-by: Tanner Voas <[email protected]>
tannervoas742 and xhaihao committed Nov 18, 2024
1 parent 250622e commit 1a652d3
Showing 1 changed file with 25 additions and 7 deletions.
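For context, ALiBi (Attention with Linear Biases) penalizes each attention score in proportion to the query-key distance, using one fixed slope per head; the regression fixed here was that this bias no longer reached flat_pa. The sketch below is a minimal illustration of the conventional slope and bias construction from the ALiBi paper. It is not part of this commit, and the get_alibi_slopes helper name is assumed for the example; in flat_pa the new alibi_slopes argument is added to the scores directly, so the bias it carries is expected to be precomputed by the caller.

import math
import torch

def get_alibi_slopes(num_heads: int) -> torch.Tensor:
    # Geometric sequence of per-head slopes from the ALiBi paper; head counts
    # that are not a power of two interleave a second, finer-grained sequence.
    closest_pow2 = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** (-(2.0 ** -(math.log2(closest_pow2) - 3)))
    slopes = torch.pow(base, torch.arange(1, closest_pow2 + 1, dtype=torch.float32))
    if closest_pow2 != num_heads:
        extra_base = 2.0 ** (-(2.0 ** -(math.log2(2 * closest_pow2) - 3)))
        num_extra = min(closest_pow2, num_heads - closest_pow2)
        extra = torch.pow(extra_base,
                          torch.arange(1, 2 * num_extra + 1, 2, dtype=torch.float32))
        slopes = torch.cat([slopes, extra], dim=0)
    return slopes

# Per-head causal bias: -slope * (query_pos - key_pos), zero on and above the diagonal.
slopes = get_alibi_slopes(8)                          # [num_heads]
pos = torch.arange(128)
dist = (pos[:, None] - pos[None, :]).clamp(min=0)     # [q_len, k_len]
bias = -slopes[:, None, None] * dist                  # [num_heads, q_len, k_len]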
vllm_hpu_extension/ops.py: 25 additions & 7 deletions (32 changes)
@@ -140,10 +140,24 @@ def block_softmax(batch_size, attn, block_mapping, block_scales, block_groups):
     return attn
 
 
-def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
-            block_bias, block_scales, block_groups, scale, matmul_qk_op,
-            matmul_av_op, batch2block_matmul_op, block2batch_matmul_op,
-            keys_fetch_func, values_fetch_func):
+def flat_pa(
+    query,
+    key_cache,
+    value_cache,
+    block_list,
+    block_mapping,
+    block_bias,
+    block_scales,
+    block_groups,
+    scale,
+    alibi_slopes,
+    matmul_qk_op,
+    matmul_av_op,
+    batch2block_matmul_op,
+    block2batch_matmul_op,
+    keys_fetch_func,
+    values_fetch_func,
+):
     batch_size = query.size(0)
     q_heads = query.size(1)
     kv_heads = key_cache.size(2)
@@ -161,7 +175,11 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
     else:
         key = key.transpose(2, 3)
 
-    attn = matmul_qk_op(query, key) + block_bias
+    attn = matmul_qk_op(query, key)
+    if alibi_slopes is not None:
+        attn[:alibi_slopes.size(0)].add_(alibi_slopes.unsqueeze(-2))
+    if block_bias is not None:
+        attn = attn + block_bias
     attn = block_softmax(batch_size, attn, block_mapping, block_scales, block_groups)
     attn = matmul_av_op(attn, value)
     attn = block2batch(attn, block_mapping, block2batch_matmul_op)
@@ -249,7 +267,7 @@ def prompt_attention_with_context(
     matmul_av_op,
     softmax_op,
     keys_fetch_func,
-    values_fetch_func,
+    values_fetch_func,
 ) -> torch.Tensor:
     htorch.core.mark_step()
     query.mul_(scale)
@@ -514,5 +532,5 @@ def scaled_fp8_quant(
                                           False,
                                           False,
                                           dtype=torch.float8_e4m3fn)[0]
-
     return output, scale
+
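With alibi_slopes=None and a block_bias present, the new code reduces to the previous unconditional attn = matmul_qk_op(query, key) + block_bias, so non-ALiBi models are unaffected. When ALiBi is active, the bias is added in place to the leading alibi_slopes.size(0) entries along attn's first dimension and broadcast across the singleton dimension created by unsqueeze(-2). A shape-only sketch of that broadcast, with made-up sizes rather than the kernel's real block-flattened layout:

import torch

attn = torch.zeros(4, 8, 1, 16)       # e.g. [blocks, heads, 1, block_size]; sizes invented for illustration
alibi_bias = torch.randn(3, 8, 16)    # bias precomputed for the first three blocks
attn[:alibi_bias.size(0)].add_(alibi_bias.unsqueeze(-2))   # unsqueeze aligns shapes to [3, 8, 1, 16]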