From 873fc194ad0cb2cafe6268d137a9dca43935cc9c Mon Sep 17 00:00:00 2001 From: Tanner Voas Date: Thu, 14 Nov 2024 08:22:10 +0000 Subject: [PATCH] vLLM-Ext: Resolved ALiBi bias regression - Works in lazy and eager mode Co-authored-by: Tanner Voas Co-authored-by: Haihao Xiang Signed-off-by: Tanner Voas --- vllm_hpu_extension/ops.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py index eeef4c4e..fd22e8a7 100644 --- a/vllm_hpu_extension/ops.py +++ b/vllm_hpu_extension/ops.py @@ -140,10 +140,24 @@ def block_softmax(batch_size, attn, block_mapping, block_scales, block_groups): return attn -def flat_pa(query, key_cache, value_cache, block_list, block_mapping, - block_bias, block_scales, block_groups, scale, matmul_qk_op, - matmul_av_op, batch2block_matmul_op, block2batch_matmul_op, - keys_fetch_func, values_fetch_func): +def flat_pa( + query, + key_cache, + value_cache, + block_list, + block_mapping, + block_bias, + block_scales, + block_groups, + scale, + alibi_slopes, + matmul_qk_op, + matmul_av_op, + batch2block_matmul_op, + block2batch_matmul_op, + keys_fetch_func, + values_fetch_func, +): batch_size = query.size(0) q_heads = query.size(1) kv_heads = key_cache.size(2) @@ -161,7 +175,11 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping, else: key = key.transpose(2, 3) - attn = matmul_qk_op(query, key) + block_bias + attn = matmul_qk_op(query, key) + if alibi_slopes is not None: + attn.add_(alibi_slopes.unsqueeze(-2)) + if block_bias is not None: + attn = attn + block_bias attn = block_softmax(batch_size, attn, block_mapping, block_scales, block_groups) attn = matmul_av_op(attn, value) attn = block2batch(attn, block_mapping, block2batch_matmul_op)