use flash attn for decode
Signed-off-by: jiqing-feng <[email protected]>
jiqing-feng committed Jan 14, 2025
1 parent 95b7043 commit 71aa6b0
Showing 1 changed file with 5 additions and 4 deletions.
optimum/exporters/ipex/modeling_utils.py (5 additions, 4 deletions)

--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -652,18 +652,19 @@ def attention_interface(
                 is_causal=True,
             )
             self.use_sdpa = True
-        elif self.has_flash_attn(query) and past_len == 0:
-            # prefill, remove padding
+        elif self.has_flash_attn(query):
             attn_output = torch.empty_like(query)
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+            query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0])
+            query_max_len = input_lens.max() if past_len == 0 else 1
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query.contiguous() if query.device.type == "xpu" else query,
                 key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
                 value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
-                seq_len_tensor,
+                query_len_tensor,
                 seq_len_tensor,
-                input_lens.max(),
+                query_max_len,
                 input_lens.max(),
                 1.0 / math.sqrt(self.head_dim),
                 True,
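
In essence, flash_attn_varlen_func was previously used only for prefill (past_len == 0); with this commit it also covers decode, where every sequence in the batch contributes exactly one query token, so the query-side offsets collapse to 0, 1, 2, ... and the maximum query length is 1, while the key/value offsets still come from the cumulative KV lengths. The standalone sketch below mirrors that bookkeeping; build_varlen_args is a hypothetical helper for illustration only and is not part of the repository.

import torch

def build_varlen_args(input_lens: torch.Tensor, past_len: int):
    # Hypothetical helper mirroring the tensor construction in the diff above.
    # input_lens: per-sequence KV lengths for the batch; past_len == 0 means prefill.
    # Cumulative KV offsets [0, len0, len0+len1, ...] used for the key/value side.
    seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
    if past_len == 0:
        # Prefill: every token of every sequence is a query, so the query offsets
        # coincide with the KV offsets and the max query length is the longest prompt.
        query_len_tensor = seq_len_tensor
        query_max_len = int(input_lens.max())
    else:
        # Decode: each sequence contributes exactly one new query token, so the
        # query offsets are simply 0, 1, 2, ... and the max query length is 1.
        query_len_tensor = torch.arange(seq_len_tensor.shape[0])
        query_max_len = 1
    return query_len_tensor, seq_len_tensor, query_max_len

# Example: a batch of 3 sequences with KV lengths 5, 7 and 3 during decode.
q_offsets, kv_offsets, q_max = build_varlen_args(torch.tensor([5, 7, 3]), past_len=4)
print(q_offsets)   # 0, 1, 2, 3
print(kv_offsets)  # 0, 5, 12, 15
print(q_max)       # 1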
