From f96edd27b3a12919be88b12835a3c290f6671047 Mon Sep 17 00:00:00 2001 From: "Liu, Kaixuan" Date: Fri, 17 Jan 2025 17:54:49 -0500 Subject: [PATCH] add support for flash decoding on xpu Signed-off-by: Liu, Kaixuan --- optimum/exporters/ipex/cache_utils.py | 66 +++++++++++++++++------- optimum/exporters/ipex/modeling_utils.py | 21 ++++---- 2 files changed, 58 insertions(+), 29 deletions(-) diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py index f9df2cf69..d3555e73d 100755 --- a/optimum/exporters/ipex/cache_utils.py +++ b/optimum/exporters/ipex/cache_utils.py @@ -5,6 +5,8 @@ from intel_extension_for_pytorch.llm.modules import PagedAttention from transformers import Cache, PretrainedConfig +from optimum.intel.utils.import_utils import is_ipex_version + class IPEXPagedCache(Cache): """ @@ -43,6 +45,10 @@ def __init__( ) -> None: super().__init__() self.batch_size = batch_size + self.device = device + self.flash_decoding = ( + is_ipex_version(">", "2.4.99") if device.type == "cpu" else is_ipex_version(">", "2.5.99") + ) # Used in `generate` to keep tally of how many tokens the cache has seen self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device) default_block_size = 16 if device.type == "cpu" else 64 @@ -69,14 +75,43 @@ def __init__( key_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size) value_cache_shape = (self.num_blocks, self.num_kv_heads, self.block_size, head_size) elif device.type == "xpu": - key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1) - value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size) + if self.flash_decoding: + key_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size) + value_cache_shape = (self.num_blocks, self.block_size, self.num_kv_heads, head_size) + else: + key_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size, 1) + value_cache_shape = (self.num_blocks, self.num_kv_heads, head_size, self.block_size) for i in range(config.num_hidden_layers): new_layer_key_cache = torch.zeros(key_cache_shape, dtype=dtype, device=device) new_layer_value_cache = torch.zeros(value_cache_shape, dtype=dtype, device=device) self.key_cache.append(new_layer_key_cache) self.value_cache.append(new_layer_value_cache) + def reshape_and_cache( + self, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slots: torch.Tensor, + ): + if self.device.type == "xpu" and self.flash_decoding: + PagedAttention.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + slots, + ) + else: + PagedAttention.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slots, + ) + def update_for_prefill( self, key_states: torch.Tensor, @@ -94,7 +129,7 @@ def update_for_prefill( block_table = self.free_blocks.nonzero().view(-1)[0:nb] self.block_tables[i][0:nb] = block_table self.free_blocks[block_table] = 0 - slots_range = torch.arange(input_lens[i], device=key_states.device) + slots_range = torch.arange(input_lens[i], device=self.device) block_indices = slots_range // self.block_size slot_offsets = slots_range % self.block_size all_block_indices.append(self.block_tables[i][block_indices]) @@ -104,12 +139,8 @@ def update_for_prefill( all_slot_offsets = torch.cat(all_slot_offsets) self.slots = all_block_indices * self.block_size + all_slot_offsets # Update the cache - PagedAttention.reshape_and_cache( - key_states, - value_states, - self.key_cache[layer_idx], - self.value_cache[layer_idx], - self.slots, + self.reshape_and_cache( + key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots ) # Update the number of seen tokens @@ -127,7 +158,7 @@ def update_for_decode( if layer_idx == 0: start_block_idx = self._seen_tokens // self.block_size slot_offset_in_block = (self._seen_tokens) % self.block_size - self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32) + self.slots = torch.zeros([batch_size], device=self.device, dtype=torch.int32) for i in range(batch_size): if slot_offset_in_block[i] == 0: # need a new block: @@ -138,12 +169,8 @@ def update_for_decode( self.free_blocks[self.block_tables[i][b_idx]] = 0 self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i] # Update the cache - PagedAttention.reshape_and_cache( - key_states, - value_states, - self.key_cache[layer_idx], - self.value_cache[layer_idx], - self.slots, + self.reshape_and_cache( + key_states, value_states, self.key_cache[layer_idx], self.value_cache[layer_idx], self.slots ) # Update the number of seen tokens @@ -193,16 +220,15 @@ def get_max_length(self) -> Optional[int]: def reset(self): """Resets the cache values while preserving the objects""" - self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device) + self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.device) self.block_tables.fill_(-1) - self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.block_tables.device) + self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.device) self.max_seq_len = 0 def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" - device = self.block_tables.device origin_table = self.block_tables.clone() - updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device)) + updated_block_tables = self.block_tables.index_select(0, beam_idx.to(self.device)) mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0) num_blocks = mask.cumsum(-1)[:, -1] updated_table = torch.zeros_like(beam_idx) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 41dd5693d..cc57de14e 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -628,10 +628,10 @@ def postprocess_attention_output(self, attn_output): return attn_output # Maybe removed after torch 2.6 released - def has_flash_attn(self, query): - if query.device.type == "cpu": + def has_flash_attn(self): + if self.module_device.type == "cpu": return is_torch_version(">", "2.4.99") - elif query.device.type == "xpu": + elif self.module_device.type == "xpu": return is_torch_version(">", "2.5.99") def attention_interface( @@ -652,20 +652,23 @@ def attention_interface( is_causal=True, ) self.use_sdpa = True - elif self.has_flash_attn(query): + elif self.has_flash_attn(): attn_output = torch.empty_like(query) seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) - query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int() - query_max_len = input_lens.max() if past_len == 0 else 1 + query_len_tensor = ( + seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0], device=query.device).int() + ) + max_input_lens = input_lens.max().item() + query_max_len = max_input_lens if past_len == 0 else 1 PagedAttention.flash_attn_varlen_func( attn_output, query.contiguous() if query.device.type == "xpu" else query, - key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache, - value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache, + key_cache, + value_cache, query_len_tensor, seq_len_tensor, query_max_len, - input_lens.max(), + max_input_lens, 1.0 / math.sqrt(self.head_dim), True, past_key_value.block_tables,