add support for flash decoding on xpu
Signed-off-by: Liu, Kaixuan <[email protected]>
kaixuanliu committed Jan 17, 2025
1 parent 2590794 commit f96edd2
Showing 2 changed files with 58 additions and 29 deletions.
66 changes: 46 additions & 20 deletions optimum/exporters/ipex/cache_utils.py
@@ -5,6 +5,8 @@
from intel_extension_for_pytorch.llm.modules import PagedAttention
from transformers import Cache, PretrainedConfig

from optimum.intel.utils.import_utils import is_ipex_version


class IPEXPagedCache(Cache):
"""
@@ -43,6 +45,10 @@ def __init__(
) -> None:
super().__init__()
self.batch_size = batch_size
self.device = device
self.flash_decoding = (
is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99")
)
# Used in `generate` to keep tally of how many tokens the cache has seen
self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
default_block_size = 16 if device.type == "cpu" else 64
@@ -69,14 +75,43 @@ def __init__(
key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size)
elif device.type == "xpu":
key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
if self.flash_decoding:
key_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
value_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size)
else:
key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1)
value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size)
for i in range(config.num_hidden_layers):
new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device)
new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
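For reference, the flash-decoding path stores the XPU cache token-major (one (head, head_size) row per slot of a block), while the legacy XPU kernels expect the head-major layout kept in the else branch above. A minimal sketch comparing the two layouts, with toy sizes rather than the defaults used here:

import torch

num_blocks, block_size = 8, 64      # illustrative values only
num_kv_heads, head_size = 4, 128

# Flash-decoding layout on XPU: (num_blocks, block_size, num_kv_heads, head_size)
flash_key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

# Legacy XPU layout: (num_blocks, num_kv_heads, head_size, block_size, 1)
legacy_key_cache = torch.zeros(num_blocks, num_kv_heads, head_size, block_size, 1)

print(flash_key_cache.shape)   # torch.Size([8, 64, 4, 128])
print(legacy_key_cache.shape)  # torch.Size([8, 4, 128, 64, 1])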

def reshape_and_cache(
self,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if self.device.type == "xpu" and self.flash_decoding:
PagedAttention.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slots,
)
else:
PagedAttention.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slots,
)
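The new helper only dispatches between the two IPEX kernels; both scatter each incoming token's key/value row into the cache position named by its flat slot id (block_id * block_size + offset), they just write into different layouts. A rough pure-PyTorch illustration of that scatter for the flash-decoding (token-major) layout — a sketch with toy shapes, not the IPEX kernel itself:

import torch

def scatter_kv_flash(key, value, key_cache, value_cache, slots):
    # key, value:              (num_tokens, num_kv_heads, head_size)
    # key_cache, value_cache:  (num_blocks, block_size, num_kv_heads, head_size)
    # slots:                   (num_tokens,) flat ids: block_id * block_size + offset
    block_size = key_cache.shape[1]
    block_ids = slots // block_size
    offsets = slots % block_size
    key_cache[block_ids, offsets] = key
    value_cache[block_ids, offsets] = value

key = torch.randn(3, 2, 8)
value = torch.randn(3, 2, 8)
key_cache = torch.zeros(4, 16, 2, 8)
value_cache = torch.zeros(4, 16, 2, 8)
scatter_kv_flash(key, value, key_cache, value_cache, torch.tensor([0, 1, 16]))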

def update_for_prefill(
self,
key_states: torch.Tensor,
@@ -94,7 +129,7 @@ def update_for_prefill(
block_table = self.free_blocks.nonzero().view(-1)[0:nb]
self.block_tables[i][0:nb] = block_table
self.free_blocks[block_table] = 0
slots_range = torch.arange(input_lens[i], device=key_states.device)
slots_range = torch.arange(input_lens[i], device=self.device)
block_indices = slots_range // self.block_size
slot_offsets = slots_range % self.block_size
all_block_indices.append(self.block_tables[i][block_indices])
@@ -104,12 +139,8 @@ def update_for_prefill(
all_slot_offsets = torch.cat(all_slot_offsets)
self.slots = all_block_indices * self.block_size + all_slot_offsets
# Update the cache
PagedAttention.reshape_and_cache(
key_states,
value_states,
self.key_cache[layer_idx],
self.value_cache[layer_idx],
self.slots,
self.reshape_and_cache(
key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
)

# Update the number of seen tokens
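A worked example of the prefill slot arithmetic above, with toy numbers: block_size = 16 and a 20-token prompt whose sequence was handed blocks [3, 7] fills slots 48–63 and then 112–115:

import torch

block_size = 16
block_table = torch.tensor([3, 7])        # blocks assigned to this sequence
positions = torch.arange(20)              # 20 prompt tokens
block_indices = positions // block_size   # index into the block table
slot_offsets = positions % block_size     # position inside the block
slots = block_table[block_indices] * block_size + slot_offsets
print(slots[:3])    # tensor([48, 49, 50])
print(slots[-3:])   # tensor([113, 114, 115])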
@@ -127,7 +158,7 @@ def update_for_decode(
if layer_idx == 0:
start_block_idx = self._seen_tokens // self.block_size
slot_offset_in_block = (self._seen_tokens) % self.block_size
self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32)
self.slots = torch.zeros([batch_size], device=self.device, dtype=torch.int32)
for i in range(batch_size):
if slot_offset_in_block[i] == 0:
# need a new block:
@@ -138,12 +169,8 @@ def update_for_decode(
self.free_blocks[self.block_tables[i][b_idx]] = 0
self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
# Update the cache
PagedAttention.reshape_and_cache(
key_states,
value_states,
self.key_cache[layer_idx],
self.value_cache[layer_idx],
self.slots,
self.reshape_and_cache(
key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots
)

# Update the number of seen tokens
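During decode each sequence writes exactly one new token, so a fresh block only has to be claimed when the write position lands on a block boundary (slot offset 0). A toy illustration of that condition:

import torch

block_size = 16
seen_tokens = torch.tensor([15, 16, 33])   # tokens already cached per sequence
slot_offset = seen_tokens % block_size     # tensor([15,  0,  1])
needs_new_block = slot_offset == 0         # only the middle sequence crosses a boundary
print(needs_new_block)                     # tensor([False,  True, False])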
@@ -193,16 +220,15 @@ def get_max_length(self) -> Optional[int]:

def reset(self):
"""Resets the cache values while preserving the objects"""
self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device)
self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.device)
self.block_tables.fill_(-1)
self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.block_tables.device)
self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.device)
self.max_seq_len = 0

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
device = self.block_tables.device
origin_table = self.block_tables.clone()
updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device))
updated_block_tables = self.block_tables.index_select(0, beam_idx.to(self.device))
mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
num_blocks = mask.cumsum(-1)[:, -1]
updated_table = torch.zeros_like(beam_idx)
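The index_select above picks, for each surviving beam, the block-table row of the beam it was forked from; a toy illustration of that row permutation only (independent of the bookkeeping that follows in the full method):

import torch

block_tables = torch.tensor([[0, 1, -1],
                             [2, 3, -1]])   # two sequences; -1 marks unused entries
beam_idx = torch.tensor([1, 1])             # both surviving beams came from sequence 1
print(block_tables.index_select(0, beam_idx))
# tensor([[ 2,  3, -1],
#         [ 2,  3, -1]])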
21 changes: 12 additions & 9 deletions optimum/exporters/ipex/modeling_utils.py
@@ -628,10 +628,10 @@ def postprocess_attention_output(self, attn_output):
return attn_output

# Maybe removed after torch 2.6 released
def has_flash_attn(self, query):
if query.device.type == "cpu":
def has_flash_attn(self):
if self.module_device.type == "cpu":
return is_torch_version(">", "2.4.99")
elif query.device.type == "xpu":
elif self.module_device.type == "xpu":
return is_torch_version(">", "2.5.99")
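has_flash_attn now gates on the module's device type rather than on the query tensor, and the check itself is a plain torch version threshold per device. A stand-alone sketch of the same thresholds, using packaging directly rather than the repo's is_torch_version helper:

import torch
from packaging import version

def flash_attn_available(device_type: str) -> bool:
    torch_ver = version.parse(torch.__version__.split("+")[0])
    if device_type == "cpu":
        return torch_ver > version.parse("2.4.99")   # effectively torch >= 2.5
    if device_type == "xpu":
        return torch_ver > version.parse("2.5.99")   # effectively torch >= 2.6
    return False

print(flash_attn_available("cpu"), flash_attn_available("xpu"))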

def attention_interface(
@@ -652,20 +652,23 @@ def attention_interface(
is_causal=True,
)
self.use_sdpa = True
elif self.has_flash_attn(query):
elif self.has_flash_attn():
attn_output = torch.empty_like(query)
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int()
query_max_len = input_lens.max() if past_len == 0 else 1
query_len_tensor = (
seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0], device=query.device).int()
)
max_input_lens = input_lens.max().item()
query_max_len = max_input_lens if past_len == 0 else 1
PagedAttention.flash_attn_varlen_func(
attn_output,
query.contiguous() if query.device.type == "xpu" else query,
key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
key_cache,
value_cache,
query_len_tensor,
seq_len_tensor,
query_max_len,
input_lens.max(),
max_input_lens,
1.0 / math.sqrt(self.head_dim),
True,
past_key_value.block_tables,
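flash_attn_varlen_func works on packed (ragged) batches, so it is driven by cumulative-length tensors: seq_len_tensor marks where each sequence's keys begin, and query_len_tensor does the same for the queries, which shrink to one token per sequence once past_len > 0. A small sketch of how those tensors come out for toy lengths, mirroring the two branches above:

import torch

input_lens = torch.tensor([5, 3, 7], dtype=torch.int32)   # per-sequence lengths

# Prefill (past_len == 0): queries and keys share the same ragged layout.
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
print(seq_len_tensor)      # tensor([ 0,  5,  8, 15], dtype=torch.int32)

# Decode (past_len > 0): one query token per sequence.
query_len_tensor = torch.arange(seq_len_tensor.shape[0]).int()
print(query_len_tensor)    # tensor([0, 1, 2, 3], dtype=torch.int32)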
