vLLM-Ext: Full enabling of ALiBi
Changes:
- Added ALiBi biases back to the decode stage.
- Optimized ALiBi memory usage.
  - Added environment variable "VLLM_PROMPT_ALIBI_MAX_SEQ_LEN" to allow
    large models to run with restricted prompt lengths.
  - Prompt biases are instantiated once in __init__ rather than on each
    forward pass (a sketch follows this list).
  - Prompt and decode biases are shared across encoder/decoder layers.
- Added environment variable "VLLM_ALIBI_USE_FLOAT32_BIASES" to resolve
  accuracy issue on long sequences.
- Updated jais, mpt, falcon, baichuan, and bloom to work with ALiBi.
  - Due to bloom's 176B parameter size, I was unable to test that model,
    though its changes are the simplest.
- Works in lazy and eager mode.
- ALiBi is restricted to "VLLM_PROMPT_USE_FUSEDSDPA=false" and
  "VLLM_CONTIGUOUS_PA=true" (a guard is sketched after this list).
- Added position offsets to improve quality at BS > 1 with sequences of
  varying length.
- BS > 1 may have accuracy issues on FW < 1.19.0 due to a limitation in
  softmax; resolved on FW >= 1.19.0.
- NTT patch for GQA
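
A minimal sketch of the one-time bias setup described above. The helper and
class names (`_get_alibi_slopes`, `AlibiPromptBias`), shapes, and dtype default
are assumptions; only the environment variable names and the once-in-`__init__`
placement come from this change list.

```python
import math
import os

import torch


def _get_alibi_slopes(num_heads: int) -> torch.Tensor:
    # Standard ALiBi slope schedule (geometric series); helper name is assumed.
    closest_pow2 = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** (-8.0 / closest_pow2)
    slopes = [base ** (i + 1) for i in range(closest_pow2)]
    if closest_pow2 != num_heads:
        extra_base = 2.0 ** (-4.0 / closest_pow2)
        slopes += [extra_base ** (2 * i + 1)
                   for i in range(num_heads - closest_pow2)]
    return torch.tensor(slopes)


class AlibiPromptBias:
    """Hypothetical holder that builds the prompt ALiBi bias once in __init__
    so it can be reused every forward and shared across layers."""

    def __init__(self, num_heads: int, max_model_len: int):
        # Cap the prompt bias so very large models can run with a
        # restricted prompt length.
        max_prompt_len = int(
            os.environ.get("VLLM_PROMPT_ALIBI_MAX_SEQ_LEN", max_model_len))
        # Optionally keep the biases in float32 to avoid accuracy loss on
        # long sequences (flag parsing shown here is an assumption).
        use_fp32 = os.environ.get("VLLM_ALIBI_USE_FLOAT32_BIASES",
                                  "false").lower() in ("1", "true")
        dtype = torch.float32 if use_fp32 else torch.bfloat16

        slopes = _get_alibi_slopes(num_heads)                    # [H]
        pos = torch.arange(max_prompt_len)
        rel = (pos[None, :] - pos[:, None]).to(dtype)            # j - i, [L, L]
        # Negative penalty that grows with distance; future positions are
        # handled by the causal mask, not by this tensor.
        self.bias = slopes.to(dtype)[:, None, None] * rel[None]  # [H, L, L]
```

Building a single [H, L, L] table here and reusing it across encoder/decoder
layers is what keeps the memory cost flat as model depth grows.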
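
The configuration restriction above could be enforced with a small guard like
the one below; the helper, function name, and default values are assumptions,
and only the two environment variable names come from the change list.

```python
import os


def _env_flag(name: str, default: str) -> bool:
    # Treat "1"/"true" (case-insensitive) as enabled; parsing is assumed.
    return os.environ.get(name, default).lower() in ("1", "true")


def check_alibi_env() -> None:
    # ALiBi needs the non-fused prompt path and contiguous paged attention.
    if _env_flag("VLLM_PROMPT_USE_FUSEDSDPA", "false"):
        raise RuntimeError("ALiBi requires VLLM_PROMPT_USE_FUSEDSDPA=false")
    if not _env_flag("VLLM_CONTIGUOUS_PA", "true"):
        raise RuntimeError("ALiBi requires VLLM_CONTIGUOUS_PA=true")
```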

Co-authored-by: Tanner Voas <[email protected]>
Co-authored-by: Haihao Xiang <[email protected]>
Signed-off-by: Tanner Voas <[email protected]>
tannervoas742 and xhaihao committed Dec 6, 2024
1 parent 41ff369 commit d97b251
78 changes: 67 additions & 11 deletions vllm_hpu_extension/ops.py
@@ -105,10 +105,24 @@ def pa(attn, value, block_groups, block_mapping, block_scales, batch_size,
pa_impl = pipelined_pa if pipelined_pa_enabled else pa


def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
block_bias, block_scales, block_groups, scale, matmul_qk_op,
matmul_av_op, batch2block_matmul_op, block2batch_matmul_op,
keys_fetch_func, values_fetch_func):
def flat_pa(
query,
key_cache,
value_cache,
block_list,
block_mapping,
block_bias,
block_scales,
block_groups,
scale,
position_bias,
matmul_qk_op,
matmul_av_op,
batch2block_matmul_op,
block2batch_matmul_op,
keys_fetch_func,
values_fetch_func,
):
batch_size = query.size(0)
q_heads = query.size(1)
kv_heads = key_cache.size(2)
@@ -118,20 +132,39 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
value = values_fetch_func(value_cache, block_list).transpose(1, 2)
block_bias = block_bias.view(key.size(0), 1, 1, -1)
if kv_heads != q_heads:
block_bias = block_bias.unsqueeze(1)
query = query.unflatten(1, (kv_heads, -1))
key = key.unflatten(1, (kv_heads, 1))
value = value.unflatten(1, (kv_heads, 1))
key = key.transpose(3, 4)
else:
key = key.transpose(2, 3)

attn = matmul_qk_op(query, key) + block_bias
if position_bias is not None:
position_bias = position_bias.unflatten(1, (kv_heads, -1))
if block_bias is not None:
block_bias = block_bias.unsqueeze(2)
key = key.transpose(-2, -1)

attn = matmul_qk_op(query, key)
if position_bias is not None:
if attn.dtype != position_bias.dtype:
attn = attn.to(dtype=position_bias.dtype)
attn.add_(position_bias.unsqueeze(-2))
if block_bias is not None:
if attn.dtype != block_bias.dtype:
block_bias = block_bias.to(dtype=attn.dtype)
attn.add_(block_bias)

if attn.dtype != block_mapping.dtype:
block_mapping = block_mapping.to(dtype=attn.dtype)
if attn.dtype != block_scales.dtype:
block_scales = block_scales.to(dtype=attn.dtype)
if attn.dtype != value.dtype:
value = value.to(dtype=attn.dtype)
attn = pa_impl(attn, value, block_groups, block_mapping, block_scales=block_scales,
batch_size=batch_size, matmul_av_op=matmul_av_op,
batch2block_matmul_op=batch2block_matmul_op, block2batch_matmul_op=block2batch_matmul_op)
attn = block2batch(attn, block_mapping, block2batch_matmul_op)
if attn.dtype != query.dtype:
attn = attn.to(dtype=query.dtype)
attn = attn.squeeze(-2)

if kv_heads != q_heads:
attn = attn.flatten(1, 2)
return attn
@@ -163,6 +196,8 @@ def prompt_attention(
key: torch.Tensor,
value: torch.Tensor,
attn_bias: Optional[torch.Tensor] = None,
position_bias: Optional[torch.Tensor] = None,
position_bias_offset: Optional[torch.Tensor] = None,
p: float = 0.0,
scale: Optional[float] = None,
matmul_qk_op=torch.matmul,
@@ -181,13 +216,33 @@ def prompt_attention(
query = query.unflatten(1, (kv_heads, -1))
key = key.unflatten(1, (kv_heads, 1))
value = value.unflatten(1, (kv_heads, 1))
if position_bias is not None:
position_bias = position_bias.unflatten(1, (kv_heads, -1))
if position_bias_offset is not None:
position_bias_offset = position_bias_offset.unflatten(1, (kv_heads, -1))
if attn_bias is not None:
attn_bias = attn_bias.unsqueeze(2)
attn_weights = matmul_qk_op(query * scale, key.transpose(-1, -2))
key = key.transpose(-2, -1)
attn_weights = matmul_qk_op(query * scale, key)

if position_bias is not None:
if attn_weights.dtype != position_bias.dtype:
attn_weights = attn_weights.to(dtype=position_bias.dtype)
attn_weights.add_(position_bias)
if position_bias_offset is not None:
attn_weights.add_(position_bias_offset.unsqueeze(-1).unsqueeze(-1))
if attn_bias is not None:
if attn_weights.dtype != attn_bias.dtype:
attn_bias = attn_bias.to(dtype=attn_weights.dtype)
attn_weights.add_(attn_bias)

attn_weights = softmax_op(attn_weights, dim=-1)
if attn_weights.dtype != value.dtype:
value = value.to(dtype=attn_weights.dtype)
attn_weights = matmul_av_op(attn_weights, value)
if attn_weights.dtype != query.dtype:
attn_weights = attn_weights.to(dtype=query.dtype)

if query_heads != kv_heads:
attn_weights = attn_weights.flatten(1, 2)
else:
@@ -206,6 +261,7 @@ def prompt_attention(
attn_weights = fsdpa_op(query, key, value, None, 0.0, True,
scale, softmax_mode, recompute_mode,
valid_seq_lengths, 'right')

attn_weights = attn_weights.transpose(1, 2)
return attn_weights
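
As a reference outside the extension, the non-fused prompt path shown in the
diff boils down to scaled QK^T plus an ALiBi position bias plus a causal mask,
followed by softmax and AV. The standalone sketch below uses small random
tensors, a power-of-two head count, and an inline slope formula, all of which
are assumptions; GQA unflattening and the position_bias_offset term are
omitted for brevity.

```python
import torch

B, H, L, D = 2, 4, 8, 16            # batch, heads (power of two), seq len, head dim
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)
scale = D ** -0.5

# ALiBi slopes for a power-of-two head count, and the [H, L, L] bias table.
slopes = torch.tensor([2.0 ** (-8.0 * (i + 1) / H) for i in range(H)])
rel = (torch.arange(L)[None, :] - torch.arange(L)[:, None]).float()   # j - i
position_bias = slopes[:, None, None] * rel[None]                     # [H, L, L]
causal_mask = torch.full((L, L), float("-inf")).triu(diagonal=1)      # block future keys

scores = (q * scale) @ k.transpose(-2, -1)                 # [B, H, L, L]
scores = scores + position_bias[None] + causal_mask        # add bias, then mask
out = torch.softmax(scores, dim=-1) @ v                    # [B, H, L, D]
```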

