diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index 46f3868cc2..b69dfcfe55 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -652,18 +652,19 @@ def attention_interface(
                 is_causal=True,
             )
             self.use_sdpa = True
-        elif self.has_flash_attn(query) and past_len == 0:
-            # prefill, remove padding
+        elif self.has_flash_attn(query):
             attn_output = torch.empty_like(query)
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
+            query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0])
+            query_max_len = input_lens.max() if past_len == 0 else 1
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query.contiguous() if query.device.type == "xpu" else query,
                 key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
                 value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
+                query_len_tensor,
                 seq_len_tensor,
-                seq_len_tensor,
-                input_lens.max(),
+                query_max_len,
                 input_lens.max(),
                 1.0 / math.sqrt(self.head_dim),
                 True,
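
Note: the change above extends the flash-attention path from prefill-only to decode as well, by distinguishing the query offsets from the KV offsets. Below is a minimal standalone sketch (plain PyTorch, not the IPEX kernel) of how the two cumulative-offset tensors behave in each phase; the `input_lens` values are made up for illustration, and only the shapes and semantics mirror the diff.

import torch

# Hypothetical per-sequence lengths for a batch of 3 (illustrative values only).
input_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

# Cumulative KV offsets into the packed (unpadded) buffer: [0, 5, 8, 15].
# The varlen flash-attention kernel uses these to locate each sequence's slice.
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))

for past_len in (0, 4):  # 0 = prefill; any positive value = decode
    if past_len == 0:
        # Prefill: every prompt token is a query, so the query offsets
        # coincide with the KV offsets and the longest query equals the
        # longest prompt.
        query_len_tensor = seq_len_tensor        # [0, 5, 8, 15]
        query_max_len = input_lens.max()         # 7
    else:
        # Decode: each sequence contributes exactly one new query token,
        # so the cumulative query offsets are simply 0, 1, ..., batch_size.
        query_len_tensor = torch.arange(seq_len_tensor.shape[0])  # [0, 1, 2, 3]
        query_max_len = 1
    print(past_len, query_len_tensor.tolist(), int(query_max_len))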