From 6d210750cf332230fa6df07d082fdf1f9e436d53 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 11 Dec 2024 16:14:34 +0000 Subject: [PATCH 01/31] use real varlen attn Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 28 ++++++++++++++++++------ 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index e741575ed..7ece93a9d 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -614,20 +614,34 @@ def forward( if past_len == 0: # prefill, remove padding seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) - varlen_attention( - query.contiguous() if query.device.type == "xpu" else query, - key.contiguous() if key.device.type == "xpu" else key, - value.contiguous() if value.device.type == "xpu" else value, + # varlen_attention( + # query.contiguous() if query.device.type == "xpu" else query, + # key.contiguous() if key.device.type == "xpu" else key, + # value.contiguous() if value.device.type == "xpu" else value, + # attn_output, + # seq_len_tensor, + # seq_len_tensor, + # input_lens.max(), + # input_lens.max(), + # 0.0, + # 1.0 / math.sqrt(self.head_dim), + # False, + # True, + # False, + # None, + # ) + PagedAttention.flash_attn_varlen_func( attn_output, + query, + key_cache, + value_cache, seq_len_tensor, seq_len_tensor, input_lens.max(), input_lens.max(), - 0.0, 1.0 / math.sqrt(self.head_dim), - False, True, - False, + past_key_value.block_tables, None, ) else: From b792875be1e2ae97275afc0fe53f28f6b202190d Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Dec 2024 10:18:29 +0000 Subject: [PATCH 02/31] optimize gpt2 by using linear instead of conv1D Signed-off-by: jiqing-feng --- optimum/exporters/ipex/cache_utils.py | 2 +- optimum/exporters/ipex/modeling_utils.py | 26 ++++++++---------------- 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py index dec1e8189..e1f6aa19b 100755 --- a/optimum/exporters/ipex/cache_utils.py +++ b/optimum/exporters/ipex/cache_utils.py @@ -44,7 +44,7 @@ def __init__( self.batch_size = batch_size # Used in `generate` to keep tally of how many tokens the cache has seen self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device) - self.block_size = 16 + self.block_size = 64 self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape( batch_size, -1 diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 7ece93a9d..28e6720ee 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -614,22 +614,6 @@ def forward( if past_len == 0: # prefill, remove padding seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) - # varlen_attention( - # query.contiguous() if query.device.type == "xpu" else query, - # key.contiguous() if key.device.type == "xpu" else key, - # value.contiguous() if value.device.type == "xpu" else value, - # attn_output, - # seq_len_tensor, - # seq_len_tensor, - # input_lens.max(), - # input_lens.max(), - # 0.0, - # 1.0 / math.sqrt(self.head_dim), - # False, - # True, - # False, - # None, - # ) PagedAttention.flash_attn_varlen_func( attn_output, query, @@ -734,9 +718,16 @@ class 
_IPEXGPT2Attention(_IPEXAttention): def __init__(self, module, config) -> None: self.num_key_value_heads = config.num_key_value_heads super().__init__(module, config) + _setattr_from_module(self, module) + self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1]) + self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t()) + self.c_attn_linear.bias = self.c_attn.bias + self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1]) + self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t()) + self.c_proj_linear.bias = self.c_proj.bias def qkv_gemm(self, hidden_states): - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1) + query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1) query = query.view(-1, self.num_heads, self.head_dim) key = key.view(-1, self.num_heads, self.head_dim) value = value.view(-1, self.num_heads, self.head_dim) @@ -748,7 +739,6 @@ def rope(self, query, key, *args, **kwargs): def postprocess_attention_output(self, attn_output): attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) attn_output = self.c_proj(attn_output) - attn_output = self.resid_dropout(attn_output) return attn_output From 36884cb77e5afae8bd0f04863f994d2122d03e53 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Dec 2024 15:48:33 +0000 Subject: [PATCH 03/31] fix usage without pkv Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 61 +++++++++++++++++------- 1 file changed, 44 insertions(+), 17 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index fdc7ea86b..039d5201b 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -30,6 +30,7 @@ logger = logging.getLogger(__name__) _IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.4.0" +_IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN = "2.5.0" if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING): @@ -588,6 +589,44 @@ def postprocess_attention_output(self, attn_output): attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) return attn_output + def varlen_attn(self, query, key, value, past_key_value, input_lens): + # prefill, remove padding + attn_output = torch.empty_like(query) + seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) + if past_key_value and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN): + PagedAttention.flash_attn_varlen_func( + attn_output, + query, + key, + value, + seq_len_tensor, + seq_len_tensor, + input_lens.max(), + input_lens.max(), + 1.0 / math.sqrt(self.head_dim), + True, + past_key_value.block_tables, + None, + ) + else: + varlen_attention( + query.contiguous() if query.device.type == "xpu" else query, + key.contiguous() if key.device.type == "xpu" else key, + value.contiguous() if value.device.type == "xpu" else value, + attn_output, + seq_len_tensor, + seq_len_tensor, + input_lens.max(), + input_lens.max(), + 0.0, + 1.0 / math.sqrt(self.head_dim), + False, + True, + False, + None, + ) + return attn_output + def forward( self, hidden_states: torch.Tensor, @@ -609,27 +648,15 @@ def forward( if past_key_value is not None: key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens) + else: + key_cache, value_cache = key, value - attn_output = torch.empty_like(query) if past_len == 0: - # prefill, remove padding - seq_len_tensor 
= torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) - PagedAttention.flash_attn_varlen_func( - attn_output, - query, - key_cache, - value_cache, - seq_len_tensor, - seq_len_tensor, - input_lens.max(), - input_lens.max(), - 1.0 / math.sqrt(self.head_dim), - True, - past_key_value.block_tables, - None, - ) + # prefill + attn_output = self.varlen_attn(query, key_cache, value_cache, past_key_value, input_lens) else: # decode + attn_output = torch.empty_like(query) PagedAttention.single_query_cached_kv_attention( attn_output, query, From d061e6940bd69336320fca46fe3fb449506f337e Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Dec 2024 16:13:08 +0000 Subject: [PATCH 04/31] use sdpa for no cache forward Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 25 ++++++++---------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 039d5201b..91f406d3c 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -591,9 +591,9 @@ def postprocess_attention_output(self, attn_output): def varlen_attn(self, query, key, value, past_key_value, input_lens): # prefill, remove padding - attn_output = torch.empty_like(query) - seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) if past_key_value and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN): + attn_output = torch.empty_like(query) + seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) PagedAttention.flash_attn_varlen_func( attn_output, query, @@ -609,21 +609,12 @@ def varlen_attn(self, query, key, value, past_key_value, input_lens): None, ) else: - varlen_attention( - query.contiguous() if query.device.type == "xpu" else query, - key.contiguous() if key.device.type == "xpu" else key, - value.contiguous() if value.device.type == "xpu" else value, - attn_output, - seq_len_tensor, - seq_len_tensor, - input_lens.max(), - input_lens.max(), - 0.0, - 1.0 / math.sqrt(self.head_dim), - False, - True, - False, - None, + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + is_causal=True, ) return attn_output From 31c635a32f454c6db057731932a58c803cf76a2d Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Dec 2024 16:15:08 +0000 Subject: [PATCH 05/31] fix format Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 91f406d3c..1c69844c2 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -38,7 +38,7 @@ f"Please upgrade the IPEX version to at least {_IPEX_MINIMUM_VERSION_FOR_PATCHING} if you want to patch the model." 
) else: - from intel_extension_for_pytorch.llm.functional import rms_norm, rotary_embedding, varlen_attention + from intel_extension_for_pytorch.llm.functional import rms_norm, rotary_embedding from intel_extension_for_pytorch.llm.modules import ( Linear2SiluMul, LinearAdd, From 73a5ef7ed06518ed078b14810a84415124c23c7e Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Dec 2024 16:21:18 +0000 Subject: [PATCH 06/31] fix sdpa Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 55 ++++++++++++------------ 1 file changed, 27 insertions(+), 28 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 1c69844c2..f84c1cd54 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -591,31 +591,23 @@ def postprocess_attention_output(self, attn_output): def varlen_attn(self, query, key, value, past_key_value, input_lens): # prefill, remove padding - if past_key_value and is_ipex_version(">=", _IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN): - attn_output = torch.empty_like(query) - seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) - PagedAttention.flash_attn_varlen_func( - attn_output, - query, - key, - value, - seq_len_tensor, - seq_len_tensor, - input_lens.max(), - input_lens.max(), - 1.0 / math.sqrt(self.head_dim), - True, - past_key_value.block_tables, - None, - ) - else: - attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, - attn_mask=None, - is_causal=True, - ) + attn_output = torch.empty_like(query) + seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) + PagedAttention.flash_attn_varlen_func( + attn_output, + query, + key, + value, + seq_len_tensor, + seq_len_tensor, + input_lens.max(), + input_lens.max(), + 1.0 / math.sqrt(self.head_dim), + True, + past_key_value.block_tables, + None, + ) + return attn_output def forward( @@ -639,12 +631,19 @@ def forward( if past_key_value is not None: key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens) - else: - key_cache, value_cache = key, value if past_len == 0: # prefill - attn_output = self.varlen_attn(query, key_cache, value_cache, past_key_value, input_lens) + if past_key_value is None or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN): + attn_output = torch.nn.functional.scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + is_causal=True, + ) + else: + attn_output = self.varlen_attn(query, key_cache, value_cache, past_key_value, input_lens) else: # decode attn_output = torch.empty_like(query) From f9c021b4c1130dac4eb069ea3161aaf182449c2c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Dec 2024 16:51:14 +0000 Subject: [PATCH 07/31] revert shape for sdpa Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 18 +++++++++--------- optimum/intel/ipex/modeling_base.py | 2 ++ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index f84c1cd54..336b3871b 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -181,7 +181,7 @@ def _llama_model_forward( position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device ) - position_ids = position_ids.unsqueeze(0) + position_ids = 
position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -195,7 +195,7 @@ def _llama_model_forward( next_decoder_cache = () if use_cache else None position_embeddings = self.rotary_emb(hidden_states, position_ids) - if past_key_values_length == 0: + if past_key_values_length == 0 and past_key_values is not None: # first token, remove the padding from hidden_states, varlen do not accept attention mask hidden_states_copy = hidden_states index = attention_mask.view(-1) != 0 @@ -298,7 +298,7 @@ def _falcon_model_forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + position_ids = cache_position.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -310,7 +310,7 @@ def _falcon_model_forward( # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) - if past_key_values_length == 0: + if past_key_values_length == 0 and past_key_values is not None: # first token, remove the padding from hidden_states, varlen do not accept attention mask hidden_states_copy = hidden_states index = attention_mask.view(-1) != 0 @@ -420,7 +420,7 @@ def _gpt2_model_forward( past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if position_ids is None: position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0) if inputs_embeds is None: inputs_embeds = self.wte(input_ids) @@ -437,7 +437,7 @@ def _gpt2_model_forward( hidden_states = self.drop(hidden_states) - if past_length == 0: + if past_length == 0 and past_key_values is not None: # first token, remove the padding from hidden_states, varlen do not accept attention mask hidden_states_copy = hidden_states index = attention_mask.view(-1) != 0 @@ -636,9 +636,9 @@ def forward( # prefill if past_key_value is None or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN): attn_output = torch.nn.functional.scaled_dot_product_attention( - query, - key, - value, + query.reshape(input_lens.shape[0], -1, query.shape[-1]), + key.reshape(input_lens.shape[0], -1, key.shape[-1]), + value.reshape(input_lens.shape[0], -1, value.shape[-1]), attn_mask=None, is_causal=True, ) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 8611bddd2..d8f830e51 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -276,6 +276,8 @@ def forward( attention_mask: Optional[torch.FloatTensor] = None, **kwargs, ) -> CausalLMOutputWithPast: + if self.add_patch and input_ids is not None and attention_mask is None: + attention_mask = torch.ones_like(input_ids) return self.model(input_ids=input_ids, attention_mask=attention_mask, **kwargs) def _prepare_generation_config( From d0694070531181dcbbe879af9717d0a84af56a94 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 12 Dec 2024 17:11:44 +0000 Subject: [PATCH 08/31] fix sdpa precision, still have error Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 336b3871b..54e184582 100755 --- 
a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -578,6 +578,7 @@ def __init__(self, module, config) -> None: self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device ).repeat_interleave(self.num_groups) + self.use_sdpa = False def qkv_gemm(self, hidden_states): raise NotImplementedError("Need to implement in specific model class") @@ -586,7 +587,10 @@ def rope(self, *args, **kwargs): raise NotImplementedError("Need to implement in specific model class") def postprocess_attention_output(self, attn_output): - attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) + if self.use_sdpa: + attn_output = attn_output.reshape(-1, self.embed_dim) + else: + attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) return attn_output def varlen_attn(self, query, key, value, past_key_value, input_lens): @@ -642,6 +646,7 @@ def forward( attn_mask=None, is_causal=True, ) + self.use_sdpa = True else: attn_output = self.varlen_attn(query, key_cache, value_cache, past_key_value, input_lens) else: @@ -754,7 +759,10 @@ def rope(self, query, key, *args, **kwargs): return query, key def postprocess_attention_output(self, attn_output): - attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) + if self.use_sdpa: + attn_output = attn_output.reshape(-1, self.embed_dim) + else: + attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) attn_output = self.c_proj(attn_output) return attn_output From 2c54045670910b199ca56bf8fec5857a02e0e7dd Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 13 Dec 2024 09:54:42 +0000 Subject: [PATCH 09/31] fix sdpa shape Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 52 +++++++++++++++++------- 1 file changed, 37 insertions(+), 15 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 54e184582..69a88b745 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -19,6 +19,9 @@ import torch from torch import nn from transformers.cache_utils import Cache +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask_for_sdpa, +) from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions from optimum.intel.utils.import_utils import is_ipex_version @@ -195,6 +198,9 @@ def _llama_model_forward( next_decoder_cache = () if use_cache else None position_embeddings = self.rotary_emb(hidden_states, position_ids) + + input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + if past_key_values_length == 0 and past_key_values is not None: # first token, remove the padding from hidden_states, varlen do not accept attention mask hidden_states_copy = hidden_states @@ -207,8 +213,12 @@ def _llama_model_forward( position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1)) else: hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - - input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask=attention_mask, + input_shape=(input_ids.shape[0], input_ids.shape[-1]), + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: @@ -310,6 +320,8 @@ def _falcon_model_forward( # create position 
embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) + input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + if past_key_values_length == 0 and past_key_values is not None: # first token, remove the padding from hidden_states, varlen do not accept attention mask hidden_states_copy = hidden_states @@ -322,7 +334,12 @@ def _falcon_model_forward( position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1)) else: hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask=attention_mask, + input_shape=(input_ids.shape[0], input_ids.shape[-1]), + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) next_decoder_cache = None all_self_attentions = () if output_attentions else None @@ -437,6 +454,8 @@ def _gpt2_model_forward( hidden_states = self.drop(hidden_states) + input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + if past_length == 0 and past_key_values is not None: # first token, remove the padding from hidden_states, varlen do not accept attention mask hidden_states_copy = hidden_states @@ -444,8 +463,12 @@ def _gpt2_model_forward( hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index] else: hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - - input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask=attention_mask, + input_shape=(input_ids.shape[0], input_ids.shape[-1]), + inputs_embeds=inputs_embeds, + past_key_values_length=past_length, + ) presents = None all_self_attentions = () if output_attentions else None @@ -588,9 +611,8 @@ def rope(self, *args, **kwargs): def postprocess_attention_output(self, attn_output): if self.use_sdpa: - attn_output = attn_output.reshape(-1, self.embed_dim) - else: - attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) return attn_output def varlen_attn(self, query, key, value, past_key_value, input_lens): @@ -640,10 +662,11 @@ def forward( # prefill if past_key_value is None or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN): attn_output = torch.nn.functional.scaled_dot_product_attention( - query.reshape(input_lens.shape[0], -1, query.shape[-1]), - key.reshape(input_lens.shape[0], -1, key.shape[-1]), - value.reshape(input_lens.shape[0], -1, value.shape[-1]), - attn_mask=None, + query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2), + key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1]).transpose(1, 2), + value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1]).transpose(1, 2), + attn_mask=attention_mask, + dropout_p=0.0, is_causal=True, ) self.use_sdpa = True @@ -760,9 +783,8 @@ def rope(self, query, key, *args, **kwargs): def postprocess_attention_output(self, attn_output): if self.use_sdpa: - attn_output = attn_output.reshape(-1, self.embed_dim) - else: - attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) attn_output = 
self.c_proj(attn_output)
         return attn_output


From bce9aa96f802c73184be09cd9d32d06eb0234ba2 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Fri, 13 Dec 2024 10:42:06 +0000
Subject: [PATCH 10/31] upgrade minimum torch version to 2.5

Signed-off-by: jiqing-feng
---
 .github/workflows/test_ipex.yml          |  2 +-
 optimum/exporters/ipex/modeling_utils.py | 20 +++++++++++++++-----
 optimum/intel/ipex/modeling_base.py      |  4 ++--
 setup.py                                 |  2 +-
 tests/ipex/test_modeling.py              |  3 +++
 5 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/.github/workflows/test_ipex.yml b/.github/workflows/test_ipex.yml
index de933e379..ffd1507ab 100644
--- a/.github/workflows/test_ipex.yml
+++ b/.github/workflows/test_ipex.yml
@@ -19,7 +19,7 @@ jobs:
       fail-fast: false
       matrix:
         transformers-version: ["4.46.0", "4.46.3"]
-        torch-version: ["2.4.0", "2.5.*"]
+        torch-version: ["2.5.*"]
 
     runs-on: ubuntu-22.04
 
diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
index 69a88b745..beb162d5c 100755
--- a/optimum/exporters/ipex/modeling_utils.py
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -32,8 +32,7 @@
 logger = logging.getLogger(__name__)
 
-_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.4.0"
-_IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN = "2.5.0"
+_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.5.0"
 
 
 if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):
@@ -213,6 +212,8 @@ def _llama_model_forward(
             position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
     else:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+    if past_key_values is None:
         attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
             attention_mask=attention_mask,
             input_shape=(input_ids.shape[0], input_ids.shape[-1]),
@@ -334,6 +335,8 @@ def _falcon_model_forward(
             position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
     else:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+    if past_key_values is None:
         attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
             attention_mask=attention_mask,
             input_shape=(input_ids.shape[0], input_ids.shape[-1]),
@@ -463,6 +466,8 @@ def _gpt2_model_forward(
             hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index]
     else:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+    if past_key_values is None:
         attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
             attention_mask=attention_mask,
             input_shape=(input_ids.shape[0], input_ids.shape[-1]),
@@ -660,11 +665,16 @@ def forward(
 
         if past_len == 0:
             # prefill
-            if past_key_value is None or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN):
+            if past_key_value is None:
+                n_rep = query.shape[1] // key.shape[1]
                 attn_output = torch.nn.functional.scaled_dot_product_attention(
                     query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2),
-                    key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1]).transpose(1, 2),
-                    value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1]).transpose(1, 2),
+                    key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1])
+                    .transpose(1, 2)
+                    .repeat_interleave(n_rep, 1),
+                    value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1])
+                    .transpose(1, 2)
+                    .repeat_interleave(n_rep, 1),
                     attn_mask=attention_mask,
                     dropout_p=0.0,
                     is_causal=True,
diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index d8f830e51..15cda1ef5 100644
--- a/optimum/intel/ipex/modeling_base.py
+++
b/optimum/intel/ipex/modeling_base.py @@ -299,9 +299,9 @@ def prepare_inputs_for_generation(self, *args, **kwargs): return self.model.prepare_inputs_for_generation(*args, **kwargs) def generate(self, *args, **kwargs): - if is_ipex_version("<", "2.4.0") and self._add_patch and kwargs.get("assistant_model", None): + if self._add_patch and kwargs.get("assistant_model", None): raise ValueError( - f"Assisted decoding is not supported for patched models if ipex < 2.4, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}" + f"Assisted decoding is not supported for patched models for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}" ) # Patch functions to support ipex_paged cache if self._add_patch: diff --git a/setup.py b/setup.py index ca415fca3..5ed2f7ba0 100644 --- a/setup.py +++ b/setup.py @@ -66,7 +66,7 @@ "nncf": ["nncf>=2.14.0"], "openvino": ["nncf>=2.14.0", "openvino>=2024.5.0", "openvino-tokenizers>=2024.5.0"], "neural-compressor": ["neural-compressor[pt]>3.0", "accelerate", "transformers<4.46"], - "ipex": ["intel-extension-for-pytorch>=2.4", "transformers>4.45,<4.47"], + "ipex": ["intel-extension-for-pytorch>=2.5", "transformers>4.45,<4.47"], "diffusers": ["diffusers"], "quality": QUALITY_REQUIRE, "tests": TESTS_REQUIRE, diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 3a6abd9c3..85592cbe9 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -377,6 +377,9 @@ def test_compare_with_and_without_past_key_values(self): outputs_model_without_pkv = model_without_pkv.generate( **tokens, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1 ) + import pdb + + pdb.set_trace() self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv)) self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH + tokens.input_ids.shape[1]) self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH + tokens.input_ids.shape[1]) From 72ac9e608e24933f030c5afef0bb053873167769 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 13 Dec 2024 12:34:14 +0000 Subject: [PATCH 11/31] rm pdb Signed-off-by: jiqing-feng --- tests/ipex/test_modeling.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 85592cbe9..3a6abd9c3 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -377,9 +377,6 @@ def test_compare_with_and_without_past_key_values(self): outputs_model_without_pkv = model_without_pkv.generate( **tokens, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1 ) - import pdb - - pdb.set_trace() self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv)) self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH + tokens.input_ids.shape[1]) self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH + tokens.input_ids.shape[1]) From 3fdb3a5f76679fe8779d677425f68c10ca922728 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 16 Dec 2024 15:50:19 +0000 Subject: [PATCH 12/31] fix non patch path Signed-off-by: jiqing-feng --- optimum/exporters/ipex/model_patcher.py | 4 +- optimum/exporters/ipex/modeling_utils.py | 55 ++++++++++++++++++++++-- 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 03937754a..8c5ef5030 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ 
b/optimum/exporters/ipex/model_patcher.py @@ -14,7 +14,7 @@ from transformers.models.bert.modeling_bert import BertIntermediate from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel -from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model +from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model from transformers.models.llama.modeling_llama import ( LlamaDecoderLayer, LlamaModel, @@ -27,6 +27,7 @@ from .modeling_utils import ( _IPEX_MINIMUM_VERSION_FOR_PATCHING, + _IPEXGPT2MLP, _falcon_model_forward, _gpt2_block_forward, _gpt2_model_forward, @@ -111,6 +112,7 @@ def _patch_gpt2_model(model): convert_functions(model, GPT2Model, "forward", _gpt2_model_forward) convert_functions(model, GPT2Block, "forward", _gpt2_block_forward) convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config) + convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config) return model diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index beb162d5c..aa558c437 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -46,6 +46,7 @@ LinearAdd, LinearAddAdd, LinearGelu, + LinearNewGelu, PagedAttention, ) @@ -557,7 +558,10 @@ def _gpt2_block_forward( attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] # residual connection - hidden_states = attn_output + residual + if hasattr(self.attn, "linear_add"): + hidden_states = self.attn.linear_add(attn_output, residual) + else: + hidden_states = attn_output + residual if encoder_hidden_states is not None: # add one self-attention block for cross-attention @@ -586,7 +590,10 @@ def _gpt2_block_forward( hidden_states = self.ln_2(hidden_states) feed_forward_hidden_states = self.mlp(hidden_states) # residual connection - hidden_states = residual + feed_forward_hidden_states + if hasattr(self.mlp, "linear_add"): + hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual) + else: + hidden_states = residual + feed_forward_hidden_states if use_cache: outputs = (hidden_states,) + outputs @@ -780,6 +787,13 @@ def __init__(self, module, config) -> None: self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1]) self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t()) self.c_proj_linear.bias = self.c_proj.bias + if self.module_device.type == "cpu": + if self.c_proj_linear not in ["LinearAllreduce"]: + self.linear_add = LinearAdd(self.c_proj_linear) + + elif self.module_device.type == "xpu": + if self.c_proj_linear not in ["LinearAllreduce"]: + self.linear_add = XPULinearAdd(self.c_proj_linear) def qkv_gemm(self, hidden_states): query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1) @@ -795,7 +809,8 @@ def postprocess_attention_output(self, attn_output): if self.use_sdpa: attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) - attn_output = self.c_proj(attn_output) + if not hasattr(self, "linear_add"): + attn_output = self.c_proj(attn_output) return attn_output @@ -866,6 +881,40 @@ def forward( return output +class _IPEXGPT2MLP(nn.Module): + def __init__(self, module, config) -> None: + super().__init__() + _setattr_from_module(self, module) + self.config = config + self.module_device = next(module.parameters()).device + self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], 
self.c_fc.weight.shape[1]) + self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t()) + self.c_fc_linear.bias = self.c_fc.bias + self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1]) + self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t()) + self.c_proj_linear.bias = self.c_proj.bias + if self.module_device.type == "cpu": + self.linear_new_gelu = LinearNewGelu(self.c_fc_linear) + + if self.module_device.type == "cpu": + if self.c_proj_linear not in ["LinearAllreduce"]: + self.linear_add = LinearAdd(self.c_proj_linear) + + elif self.module_device.type == "xpu": + if self.c_proj_linear not in ["LinearAllreduce"]: + self.linear_add = XPULinearAdd(self.c_proj_linear) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + if hasattr(self, "linear_new_gelu"): + hidden_states = self.linear_new_gelu(hidden_states) + else: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + if not hasattr(self, "linear_add"): + hidden_states = self.c_proj(hidden_states) + return hidden_states + + # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 class _IPEXLlamaDecoderLayer(nn.Module): def __init__(self, module, config): From 6186aafc44d0000f3878041f2016b3effb6867b9 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 14 Jan 2025 10:10:07 +0000 Subject: [PATCH 13/31] use varlen if flash attn not available Signed-off-by: jiqing-feng --- .github/workflows/test_ipex.yml | 2 +- optimum/exporters/ipex/modeling_utils.py | 59 +++++++++++++++++------- 2 files changed, 43 insertions(+), 18 deletions(-) diff --git a/.github/workflows/test_ipex.yml b/.github/workflows/test_ipex.yml index ffd1507ab..de933e379 100644 --- a/.github/workflows/test_ipex.yml +++ b/.github/workflows/test_ipex.yml @@ -19,7 +19,7 @@ jobs: fail-fast: false matrix: transformers-version: ["4.46.0", "4.46.3"] - torch-version: ["2.5.*"] + torch-version: ["2.4.0", "2.5.*"] runs-on: ubuntu-22.04 diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index aa558c437..481bc7a3b 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -24,7 +24,7 @@ ) from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions -from optimum.intel.utils.import_utils import is_ipex_version +from optimum.intel.utils.import_utils import is_ipex_version, is_torch_version from optimum.intel.utils.modeling_utils import _setattr_from_module from .cache_utils import IPEXPagedCache @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) -_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.5.0" +_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.4.0" if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING): @@ -40,7 +40,7 @@ f"Please upgrade the IPEX version to at least {_IPEX_MINIMUM_VERSION_FOR_PATCHING} if you want to patch the model." 
) else: - from intel_extension_for_pytorch.llm.functional import rms_norm, rotary_embedding + from intel_extension_for_pytorch.llm.functional import rms_norm, rotary_embedding, varlen_attention from intel_extension_for_pytorch.llm.modules import ( Linear2SiluMul, LinearAdd, @@ -627,24 +627,49 @@ def postprocess_attention_output(self, attn_output): attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) return attn_output + # Maybe removed after torch 2.6 released + def has_flash_attn(query): + if query.device.type == "cpu": + return is_torch_version(">", "2.4.99") + elif query.device.type == "xpu": + return is_torch_version(">", "2.5.99") + def varlen_attn(self, query, key, value, past_key_value, input_lens): # prefill, remove padding attn_output = torch.empty_like(query) seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) - PagedAttention.flash_attn_varlen_func( - attn_output, - query, - key, - value, - seq_len_tensor, - seq_len_tensor, - input_lens.max(), - input_lens.max(), - 1.0 / math.sqrt(self.head_dim), - True, - past_key_value.block_tables, - None, - ) + if self.has_flash_attn(query): + PagedAttention.flash_attn_varlen_func( + attn_output, + query, + key, + value, + seq_len_tensor, + seq_len_tensor, + input_lens.max(), + input_lens.max(), + 1.0 / math.sqrt(self.head_dim), + True, + past_key_value.block_tables, + None, + ) + else: + varlen_attention( + query.contiguous() if query.device.type == "xpu" else query, + key.contiguous() if key.device.type == "xpu" else key, + value.contiguous() if value.device.type == "xpu" else value, + attn_output, + seq_len_tensor, + seq_len_tensor, + input_lens.max(), + input_lens.max(), + 0.0, + 1.0 / math.sqrt(self.head_dim), + False, + True, + False, + None, + ) return attn_output From cbc232ba78cd699defd3d479532ee89e4257d3f6 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 14 Jan 2025 10:11:18 +0000 Subject: [PATCH 14/31] revert ipex version change Signed-off-by: jiqing-feng --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index fce8ef6a2..0f02ef15c 100644 --- a/setup.py +++ b/setup.py @@ -66,7 +66,7 @@ "nncf": ["nncf>=2.14.0"], "openvino": ["nncf>=2.14.0", "openvino>=2024.5.0", "openvino-tokenizers>=2024.5.0"], "neural-compressor": ["neural-compressor[pt]>3.0", "accelerate", "transformers<4.46"], - "ipex": ["intel-extension-for-pytorch>=2.5", "transformers>4.45,<4.47", "accelerate"], + "ipex": ["intel-extension-for-pytorch>=2.4", "transformers>4.45,<4.47", "accelerate"], "diffusers": ["diffusers"], "quality": QUALITY_REQUIRE, "tests": TESTS_REQUIRE, From 4dd2e44ad1ab9c7ad5616e83a5720a782ba4f889 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 14 Jan 2025 10:14:29 +0000 Subject: [PATCH 15/31] fix flash attn check Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 481bc7a3b..010c079ce 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -628,7 +628,7 @@ def postprocess_attention_output(self, attn_output): return attn_output # Maybe removed after torch 2.6 released - def has_flash_attn(query): + def has_flash_attn(self, query): if query.device.type == "cpu": return is_torch_version(">", "2.4.99") elif query.device.type == "xpu": From 372d3f8fdd1755d39cc807618c337ea3166816e0 Mon Sep 17 00:00:00 2001 From: 
jiqing-feng Date: Tue, 14 Jan 2025 10:34:33 +0000 Subject: [PATCH 16/31] prefill attn Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 52 +++++++++++++----------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 010c079ce..d3e4f5f85 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -634,16 +634,31 @@ def has_flash_attn(self, query): elif query.device.type == "xpu": return is_torch_version(">", "2.5.99") - def varlen_attn(self, query, key, value, past_key_value, input_lens): - # prefill, remove padding - attn_output = torch.empty_like(query) - seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) - if self.has_flash_attn(query): + def prefill_attn(self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens): + if past_key_value is None: + n_rep = query.shape[1] // key.shape[1] + attn_output = torch.nn.functional.scaled_dot_product_attention( + query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2), + key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1]) + .transpose(1, 2) + .repeat_interleave(n_rep, 1), + value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1]) + .transpose(1, 2) + .repeat_interleave(n_rep, 1), + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=True, + ) + self.use_sdpa = True + elif self.has_flash_attn(query): + # prefill, remove padding + attn_output = torch.empty_like(query) + seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) PagedAttention.flash_attn_varlen_func( attn_output, query, - key, - value, + key_cache, + value_cache, seq_len_tensor, seq_len_tensor, input_lens.max(), @@ -654,6 +669,9 @@ def varlen_attn(self, query, key, value, past_key_value, input_lens): None, ) else: + # prefill, remove padding + attn_output = torch.empty_like(query) + seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) varlen_attention( query.contiguous() if query.device.type == "xpu" else query, key.contiguous() if key.device.type == "xpu" else key, @@ -697,23 +715,9 @@ def forward( if past_len == 0: # prefill - if past_key_value is None: - n_rep = query.shape[1] // key.shape[1] - attn_output = torch.nn.functional.scaled_dot_product_attention( - query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2), - key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1]) - .transpose(1, 2) - .repeat_interleave(n_rep, 1), - value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1]) - .transpose(1, 2) - .repeat_interleave(n_rep, 1), - attn_mask=attention_mask, - dropout_p=0.0, - is_causal=True, - ) - self.use_sdpa = True - else: - attn_output = self.varlen_attn(query, key_cache, value_cache, past_key_value, input_lens) + attn_output = self.prefill_attn( + query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens + ) else: # decode attn_output = torch.empty_like(query) From daddabfc69105d09350331cd15faee98a297c7a7 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 14 Jan 2025 10:43:34 +0000 Subject: [PATCH 17/31] fix cache Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/exporters/ipex/modeling_utils.py 
b/optimum/exporters/ipex/modeling_utils.py index d3e4f5f85..1bf605ca7 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -710,6 +710,7 @@ def forward( query, key, value = self.qkv_gemm(hidden_states) query, key = self.rope(query, key, **kwargs) + key_cache, value_cache = None, None if past_key_value is not None: key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens) From 8e8c95f2e9292bcbf4934b6c5066e7bba7095a03 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 14 Jan 2025 13:04:46 +0000 Subject: [PATCH 18/31] qwen2 model forward Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 159 +++++++++++++++++++++++ 1 file changed, 159 insertions(+) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 1bf605ca7..dd97485a5 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -603,6 +603,138 @@ def _gpt2_block_forward( return outputs # hidden_states, present, (attentions, cross_attentions) +# Adapted from https://github.com/huggingface/transformers/blob/v4.48.0/src/transformers/models/qwen2/modeling_qwen2.py#L499 +def _qwen2_model_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = inputs_embeds.shape[:2] + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + + if past_key_values_length == 0 and past_key_values is not None: + # first token, remove the padding from hidden_states, varlen do not accept attention mask + hidden_states_copy = hidden_states + index = attention_mask.view(-1) != 0 + hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index] + cos = position_embeddings[0] + sin = position_embeddings[1] + cos = (cos.reshape(-1, cos.shape[-1]))[index] + sin = (sin.reshape(-1, sin.shape[-1]))[index] + position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1)) + else: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + if past_key_values is None: + attention_mask = causal_mask + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + if hidden_states.shape[0] != batch_size * seq_length: + (hidden_states_copy.view(-1, hidden_states.shape[-1]))[attention_mask.view(-1) != 0] = hidden_states + hidden_states = hidden_states_copy + hidden_states = hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + class _IPEXAttention(nn.Module): def __init__(self, module, config) -> None: super().__init__() @@ -1005,6 +1137,33 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): return outputs +# Adapted from 
https://github.com/huggingface/transformers/blob/v4.48.0/src/transformers/models/qwen2/modeling_qwen2.py#L228 +class _IPEXQwen2DecoderLayer(nn.Module): + def __init__(self, module, config, layer_idx): + super().__init__() + _setattr_from_module(self, module) + self.self_attn = _IPEXQwen2Attention(config=config, layer_idx=layer_idx) + self.mlp = _IPEXQwen2MLP(config) + + def forward(self, hidden_states: torch.Tensor, **kwargs): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, self_attn_weights = self.self_attn(**kwargs) + hidden_states = residual + hidden_states + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if kwargs.get("output_attentions", None): + outputs += (self_attn_weights,) + + return outputs + + # Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bert/modeling_bert.py#L524 class _IPEXIntermediate(nn.Module): def __init__(self, module, config): From 95b7043d592100702b04464529e1777ddbbab5b9 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 14 Jan 2025 13:54:17 +0000 Subject: [PATCH 19/31] refactor attention Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 54 ++++++++++++------------ 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 1bf605ca7..46f3868cc 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -634,7 +634,9 @@ def has_flash_attn(self, query): elif query.device.type == "xpu": return is_torch_version(">", "2.5.99") - def prefill_attn(self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens): + def attention_interface( + self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len + ): if past_key_value is None: n_rep = query.shape[1] // key.shape[1] attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -650,15 +652,15 @@ def prefill_attn(self, query, key_cache, value_cache, key, value, past_key_value is_causal=True, ) self.use_sdpa = True - elif self.has_flash_attn(query): + elif self.has_flash_attn(query) and past_len == 0: # prefill, remove padding attn_output = torch.empty_like(query) seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) PagedAttention.flash_attn_varlen_func( attn_output, - query, - key_cache, - value_cache, + query.contiguous() if query.device.type == "xpu" else query, + key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache, + value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache, seq_len_tensor, seq_len_tensor, input_lens.max(), @@ -668,7 +670,7 @@ def prefill_attn(self, query, key_cache, value_cache, key, value, past_key_value past_key_value.block_tables, None, ) - else: + elif past_len == 0: # prefill, remove padding attn_output = torch.empty_like(query) seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) @@ -688,6 +690,22 @@ def prefill_attn(self, query, key_cache, value_cache, key, value, past_key_value False, None, ) + else: + # decode + attn_output = torch.empty_like(query) + PagedAttention.single_query_cached_kv_attention( + attn_output, + query, + key_cache, + value_cache, + 
self.kv_head_mapping, + 1.0 / math.sqrt(self.head_dim), + past_key_value.block_tables, + input_lens, + past_key_value.block_size, + input_lens.max(), + None, + ) return attn_output @@ -714,27 +732,9 @@ def forward( if past_key_value is not None: key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens) - if past_len == 0: - # prefill - attn_output = self.prefill_attn( - query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens - ) - else: - # decode - attn_output = torch.empty_like(query) - PagedAttention.single_query_cached_kv_attention( - attn_output, - query, - key_cache, - value_cache, - self.kv_head_mapping, - 1.0 / math.sqrt(self.head_dim), - past_key_value.block_tables, - input_lens, - past_key_value.block_size, - input_lens.max(), - None, - ) + attn_output = self.attention_interface( + query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len + ) attn_output = self.postprocess_attention_output(attn_output) if not output_attentions: From 71aa6b0e2632e49b1bd536875f24c44df33cf728 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 14 Jan 2025 15:06:49 +0000 Subject: [PATCH 20/31] use flash attn for decode Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 46f3868cc..b69dfcfe5 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -652,18 +652,19 @@ def attention_interface( is_causal=True, ) self.use_sdpa = True - elif self.has_flash_attn(query) and past_len == 0: - # prefill, remove padding + elif self.has_flash_attn(query): attn_output = torch.empty_like(query) seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) + query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]) + query_max_len = input_lens.max() if past_len == 0 else 1 PagedAttention.flash_attn_varlen_func( attn_output, query.contiguous() if query.device.type == "xpu" else query, key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache, value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache, + query_len_tensor, seq_len_tensor, - seq_len_tensor, - input_lens.max(), + query_max_len, input_lens.max(), 1.0 / math.sqrt(self.head_dim), True, From 9211803eeb8269e3a50bc3a228df444621b24eff Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 14 Jan 2025 15:08:36 +0000 Subject: [PATCH 21/31] fix dtype Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index b69dfcfe5..41dd5693d 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -655,7 +655,7 @@ def attention_interface( elif self.has_flash_attn(query): attn_output = torch.empty_like(query) seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) - query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]) + query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int() query_max_len = input_lens.max() if past_len == 0 else 1 PagedAttention.flash_attn_varlen_func( attn_output, From 
d3fbd65fa8f8953a46eb7c304a94875cad4aaede Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 14 Jan 2025 16:45:39 +0000 Subject: [PATCH 22/31] enable qwen2 model Signed-off-by: jiqing-feng --- optimum/exporters/ipex/model_patcher.py | 21 ++++++ optimum/exporters/ipex/modeling_utils.py | 92 ++++++------------------ optimum/intel/ipex/modeling_base.py | 4 +- 3 files changed, 46 insertions(+), 71 deletions(-) diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 8c5ef5030..c171aef29 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -20,6 +20,11 @@ LlamaModel, LlamaRMSNorm, ) +from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2DecoderLayer, + Qwen2Model, + Qwen2RMSNorm, +) from transformers.models.vit.modeling_vit import ViTIntermediate from optimum.intel.utils.import_utils import is_ipex_version, is_transformers_version @@ -36,7 +41,9 @@ _IPEXGPT2Attention, _IPEXIntermediate, _IPEXLlamaDecoderLayer, + _IPEXQwen2DecoderLayer, _llama_model_forward, + _qwen2_model_forward, ) @@ -116,6 +123,18 @@ def _patch_gpt2_model(model): return model +def _patch_qwen2_model(model): + """ + Patch qwen2 model: + 1. Use IPEX rope and paged cache + 2. Linear fusion with (2 Linears + Silu + Mul) and (Linear + Add) + """ + convert_functions(model, Qwen2Model, "forward", _qwen2_model_forward) + convert_functions(model, Qwen2RMSNorm, "forward", _ipex_rms_layer_norm_forward) + convert_class(model, Qwen2DecoderLayer, _IPEXQwen2DecoderLayer, model.config) + return model + + def _patch_bert_model(model): """ Patch bert model: @@ -149,6 +168,8 @@ def _patch_model(model): model = _patch_falcon_model(model) elif model.config.model_type == "gpt2": model = _patch_gpt2_model(model) + elif model.config.model_type == "qwen2": + model = _patch_qwen2_model(model) elif model.config.model_type == "bert": model = _patch_bert_model(model) elif model.config.model_type == "vit": diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 8026ec3d7..7a21ef3c3 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -629,9 +629,7 @@ def _qwen2_model_forward( raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) + logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. 
Setting `use_cache=False`.") use_cache = False if inputs_embeds is None: @@ -639,9 +637,6 @@ def _qwen2_model_forward( batch_size, seq_length = inputs_embeds.shape[:2] - if use_cache and past_key_values is None: - past_key_values = DynamicCache() - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 if cache_position is None: cache_position = torch.arange( @@ -686,30 +681,18 @@ def _qwen2_model_forward( if output_hidden_states: all_hidden_states += (hidden_states,) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + input_lens=input_lens, + **kwargs, + ) hidden_states = layer_outputs[0] @@ -750,8 +733,10 @@ def __init__(self, module, config) -> None: def qkv_gemm(self, hidden_states): raise NotImplementedError("Need to implement in specific model class") - def rope(self, *args, **kwargs): - raise NotImplementedError("Need to implement in specific model class") + def rope(self, query, key, **kwargs): + position_embeddings = kwargs.pop("position_embeddings", None) + rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True) + return query, key def postprocess_attention_output(self, attn_output): if self.use_sdpa: @@ -880,13 +865,13 @@ class _IPEXLlamaAttention(_IPEXAttention): def __init__(self, module, config) -> None: super().__init__(module, config) concat_weight = torch.concat([self.q_proj.weight, self.k_proj.weight, self.v_proj.weight]).contiguous() - bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias] + bias_list = [bias for bias in [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] if bias is not None] use_bias = bias_list != [] self.concat_qkv = nn.Linear(concat_weight.shape[1], concat_weight.shape[0], bias=use_bias) self.concat_qkv.weight = nn.Parameter(concat_weight) if use_bias: concat_bias = torch.concat(bias_list, 0).contiguous() - self.concat_linear.bias = nn.Parameter(concat_bias) + self.concat_qkv.bias = nn.Parameter(concat_bias) self.q_slice = self.q_proj.weight.shape[0] self.k_slice = self.q_slice + self.k_proj.weight.shape[0] self.v_slice = self.k_slice + self.v_proj.weight.shape[0] @@ -906,11 +891,6 @@ def qkv_gemm(self, hidden_states): return query, key, value - def rope(self, query, key, **kwargs): - position_embeddings = kwargs.pop("position_embeddings", None) - rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True) - return query, key - class _IPEXFalconAttention(_IPEXAttention): def __init__(self, module, config): @@ -933,11 +913,6 @@ def qkv_gemm(self, hidden_states): value = qkv_out[:, self.k_slice :].view(-1, self.num_key_value_heads, self.head_dim) return query, key, value - def 
rope(self, query, key, **kwargs): - position_embeddings = kwargs.pop("position_embeddings", None) - rotary_embedding(query, key, position_embeddings[1], position_embeddings[0], query.size(-1), True) - return query, key - class _IPEXGPT2Attention(_IPEXAttention): def __init__(self, module, config) -> None: @@ -1138,31 +1113,10 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): return outputs -# Adapted from https://github.com/huggingface/transformers/blob/v4.48.0/src/transformers/models/qwen2/modeling_qwen2.py#L228 -class _IPEXQwen2DecoderLayer(nn.Module): - def __init__(self, module, config, layer_idx): - super().__init__() - _setattr_from_module(self, module) - self.self_attn = _IPEXQwen2Attention(config=config, layer_idx=layer_idx) - self.mlp = _IPEXQwen2MLP(config) - - def forward(self, hidden_states: torch.Tensor, **kwargs): - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - # Self Attention - hidden_states, self_attn_weights = self.self_attn(**kwargs) - hidden_states = residual + hidden_states - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - if kwargs.get("output_attentions", None): - outputs += (self_attn_weights,) - - return outputs +# Currently can just apply llama decoder layer. +class _IPEXQwen2DecoderLayer(_IPEXLlamaDecoderLayer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) # Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/bert/modeling_bert.py#L524 diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 3263e31db..76746c171 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -58,11 +58,11 @@ logger = logging.getLogger(__name__) -_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2") +_IPEX_SUPPORT_MODEL_TYPES = ("llama", "bert", "vit", "falcon", "gpt2", "qwen2") _IPEX_EXPORTED_GENERATION_METHODS = ("sample", "greedy_search", "beam_sample", "beam_search", "assisted_generation") _IPEX_MINIMUM_VERSION_FOR_COMPILE = "2.5.0" # TODO: Some models are already fixed in torch 2.6, will enable them when torch upgrading to 2.6 -_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2") +_COMPILE_NOT_READY_MODEL_TYPES = ("electra", "roformer", "gpt_neox", "beit", "llama", "falcon", "gpt2", "qwen2") def _is_patched_with_ipex(model, task, use_cache: bool = True): From 06798e202b53e7110b84bc50c6acb7dd59ff84a3 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 14 Jan 2025 17:27:22 +0000 Subject: [PATCH 23/31] enable qwen2 test Signed-off-by: jiqing-feng --- tests/ipex/test_modeling.py | 3 ++- tests/ipex/test_pipelines.py | 1 + tests/ipex/utils_tests.py | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 419e1bb42..a91d08ee0 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -233,8 +233,9 @@ class IPEXModelForCausalLMTest(unittest.TestCase): "distilgpt2", "mpt", "opt", + "qwen2", ) - IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama2", "falcon", "gpt2") + IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama2", "falcon", "gpt2", "qwen2") GENERATION_LENGTH = 100 SPEEDUP_CACHE = 1.0 diff --git a/tests/ipex/test_pipelines.py b/tests/ipex/test_pipelines.py index 
f376c6050..bcdc59208 100644 --- a/tests/ipex/test_pipelines.py +++ b/tests/ipex/test_pipelines.py @@ -66,6 +66,7 @@ class PipelinesIntegrationTest(unittest.TestCase): "mistral", "mpt", "opt", + "qwen2", ) QUESTION_ANSWERING_SUPPORTED_ARCHITECTURES = ( "bert", diff --git a/tests/ipex/utils_tests.py b/tests/ipex/utils_tests.py index 8cd93516d..75926e6b0 100644 --- a/tests/ipex/utils_tests.py +++ b/tests/ipex/utils_tests.py @@ -64,4 +64,5 @@ "patched_falcon": "Intel/tiny-random-falcon_ipex_model", "patched_gpt2": "Intel/tiny-random-gpt2_ipex_model", "patched_llama2": "Intel/tiny-random-llama2_ipex_model", + "qwen2": "Jiqing/tiny-random-Qwen2", } From 12dd802859b7ce6667d343782ac8606d053585ff Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 15 Jan 2025 10:09:30 +0000 Subject: [PATCH 24/31] set default block size Signed-off-by: jiqing-feng --- optimum/exporters/ipex/cache_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py index 7154c4449..b91da262f 100755 --- a/optimum/exporters/ipex/cache_utils.py +++ b/optimum/exporters/ipex/cache_utils.py @@ -5,6 +5,10 @@ from transformers import Cache, PretrainedConfig +# May need to tune based on sequence length and different models but default to 16 currently. +BLOCK_SIZE = 16 + + class IPEXPagedCache(Cache): """ A PagedCache that grows dynamically as more tokens are generated. everytime it grows block-size memory, vendor could set the pageCache memory layout. @@ -44,7 +48,7 @@ def __init__( self.batch_size = batch_size # Used in `generate` to keep tally of how many tokens the cache has seen self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device) - self.block_size = 64 + self.block_size = BLOCK_SIZE self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape( batch_size, -1 From c6d2d0f9d140c401761fb67d62b63834c25bbace Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 15 Jan 2025 10:38:14 +0000 Subject: [PATCH 25/31] decoding use single query Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 41dd5693d..4e8de1012 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -652,19 +652,17 @@ def attention_interface( is_causal=True, ) self.use_sdpa = True - elif self.has_flash_attn(query): + elif self.has_flash_attn(query) and past_len == 0: attn_output = torch.empty_like(query) seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) - query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int() - query_max_len = input_lens.max() if past_len == 0 else 1 PagedAttention.flash_attn_varlen_func( attn_output, query.contiguous() if query.device.type == "xpu" else query, key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache, value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache, - query_len_tensor, seq_len_tensor, - query_max_len, + seq_len_tensor, + input_lens.max(), input_lens.max(), 1.0 / math.sqrt(self.head_dim), True, From acfd0cecd3e3e0a272f9fd67fe7557f9ecbbb4fd Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 15 Jan 2025 15:04:39 +0000 
Subject: [PATCH 26/31] fix position_id init for qwen2 Signed-off-by: jiqing-feng --- optimum/exporters/ipex/modeling_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 603d2ee80..b62eb8bf5 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -644,7 +644,11 @@ def _qwen2_model_forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0) causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions From ccbe97aa84a258656d850c19d341d8d897a59e53 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 15 Jan 2025 15:08:50 +0000 Subject: [PATCH 27/31] add patched qwen2 test Signed-off-by: jiqing-feng --- tests/ipex/utils_tests.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/ipex/utils_tests.py b/tests/ipex/utils_tests.py index 75926e6b0..72f407cc1 100644 --- a/tests/ipex/utils_tests.py +++ b/tests/ipex/utils_tests.py @@ -50,6 +50,7 @@ "mt5": "stas/mt5-tiny-random", "opt": "hf-internal-testing/tiny-random-OPTModel", "phi": "echarlaix/tiny-random-PhiForCausalLM", + "qwen2": "Jiqing/tiny-random-Qwen2", "resnet": "hf-internal-testing/tiny-random-resnet", "roberta": "hf-internal-testing/tiny-random-roberta", "roformer": "hf-internal-testing/tiny-random-roformer", @@ -64,5 +65,5 @@ "patched_falcon": "Intel/tiny-random-falcon_ipex_model", "patched_gpt2": "Intel/tiny-random-gpt2_ipex_model", "patched_llama2": "Intel/tiny-random-llama2_ipex_model", - "qwen2": "Jiqing/tiny-random-Qwen2", + "patched_qwen2": "Jiqing/tiny-random-Qwen2_ipex_model", } From ee7dd81ad16dbe0929e09845a3a4396fd31cff2a Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 15 Jan 2025 15:10:52 +0000 Subject: [PATCH 28/31] fix format Signed-off-by: jiqing-feng --- optimum/exporters/ipex/model_patcher.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index fee825ba6..c171aef29 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -15,7 +15,6 @@ from transformers.models.bert.modeling_bert import BertIntermediate from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model -from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model from transformers.models.llama.modeling_llama import ( LlamaDecoderLayer, LlamaModel, @@ -34,7 +33,6 @@ from .modeling_utils import ( _IPEX_MINIMUM_VERSION_FOR_PATCHING, _IPEXGPT2MLP, - _IPEXGPT2MLP, _falcon_model_forward, _gpt2_block_forward, _gpt2_model_forward, From c86fd1ca17191f44ba221c93208f631d405210c9 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 15 Jan 2025 15:37:29 +0000 Subject: [PATCH 29/31] fix pipeline test Signed-off-by: jiqing-feng --- tests/ipex/test_pipelines.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/ipex/test_pipelines.py b/tests/ipex/test_pipelines.py index bcdc59208..62c3877b5 100644 --- 
a/tests/ipex/test_pipelines.py +++ b/tests/ipex/test_pipelines.py @@ -145,10 +145,11 @@ def test_text_generation_pipeline_inference(self, model_arch): "text-generation", model_id, accelerator="ipex", torch_dtype=dtype, device_map=DEVICE ) inputs = "Describe a real-world application of AI." + max_new_tokens = 10 if model_arch != "qwen2" else 2 with torch.inference_mode(): - transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=10) + transformers_output = transformers_generator(inputs, do_sample=False, max_new_tokens=max_new_tokens) with torch.inference_mode(): - ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=10) + ipex_output = ipex_generator(inputs, do_sample=False, max_new_tokens=max_new_tokens) self.assertTrue(isinstance(ipex_generator.model, IPEXModelForCausalLM)) self.assertEqual(transformers_output[0]["generated_text"], ipex_output[0]["generated_text"]) From 5b930362210142d954d09de741140cbb8ff448b9 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 16 Jan 2025 11:31:31 +0000 Subject: [PATCH 30/31] set block size as a env parameter Signed-off-by: jiqing-feng --- optimum/exporters/ipex/cache_utils.py | 5 +++-- optimum/exporters/ipex/modeling_utils.py | 8 +++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py index b91da262f..ded696483 100755 --- a/optimum/exporters/ipex/cache_utils.py +++ b/optimum/exporters/ipex/cache_utils.py @@ -1,3 +1,4 @@ +import os from typing import List, Optional, Tuple import torch @@ -5,8 +6,8 @@ from transformers import Cache, PretrainedConfig -# May need to tune based on sequence length and different models but default to 16 currently. -BLOCK_SIZE = 16 +# Recommend 16 on CPU and 64 on XPU. 
+BLOCK_SIZE = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", 16)) class IPEXPagedCache(Cache): diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index 4e8de1012..41dd5693d 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -652,17 +652,19 @@ def attention_interface( is_causal=True, ) self.use_sdpa = True - elif self.has_flash_attn(query) and past_len == 0: + elif self.has_flash_attn(query): attn_output = torch.empty_like(query) seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) + query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int() + query_max_len = input_lens.max() if past_len == 0 else 1 PagedAttention.flash_attn_varlen_func( attn_output, query.contiguous() if query.device.type == "xpu" else query, key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache, value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache, + query_len_tensor, seq_len_tensor, - seq_len_tensor, - input_lens.max(), + query_max_len, input_lens.max(), 1.0 / math.sqrt(self.head_dim), True, From 31accd2134daa78e8e77cfc3721dbef9c865a7d2 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 16 Jan 2025 13:29:48 +0000 Subject: [PATCH 31/31] set different default value for block size based on device Signed-off-by: jiqing-feng --- optimum/exporters/ipex/cache_utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py index ded696483..f9df2cf69 100755 --- a/optimum/exporters/ipex/cache_utils.py +++ b/optimum/exporters/ipex/cache_utils.py @@ -6,10 +6,6 @@ from transformers import Cache, PretrainedConfig -# Recommend 16 on CPU and 64 on XPU. -BLOCK_SIZE = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", 16)) - - class IPEXPagedCache(Cache): """ A PagedCache that grows dynamically as more tokens are generated. everytime it grows block-size memory, vendor could set the pageCache memory layout. @@ -49,7 +45,8 @@ def __init__( self.batch_size = batch_size # Used in `generate` to keep tally of how many tokens the cache has seen self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device) - self.block_size = BLOCK_SIZE + default_block_size = 16 if device.type == "cpu" else 64 + self.block_size = int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size))) self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape( batch_size, -1
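
As a rough illustration of the block-size handling introduced in the last two patches, the short standalone sketch below reproduces the same selection and rounding-up arithmetic outside of IPEXPagedCache. The env var name OI_PAGED_ATTN_BLOCK_SIZE and the CPU/XPU defaults (16/64) come from the diffs above; the helper names resolve_block_size and num_blocks_for and the script name are illustrative only and not part of optimum-intel.

import os

import torch


def resolve_block_size(device: torch.device) -> int:
    # Same policy as IPEXPagedCache.__init__ after the final patch:
    # default to 16 on CPU and 64 on XPU, overridable via OI_PAGED_ATTN_BLOCK_SIZE.
    default_block_size = 16 if device.type == "cpu" else 64
    return int(os.environ.get("OI_PAGED_ATTN_BLOCK_SIZE", str(default_block_size)))


def num_blocks_for(max_cache_len: int, batch_size: int, block_size: int) -> int:
    # Same rounding-up arithmetic as the cache: enough blocks to cover
    # max_cache_len tokens per sequence, one set of blocks per batch entry.
    blocks_per_seq = max_cache_len // block_size + (max_cache_len % block_size != 0)
    return blocks_per_seq * batch_size


if __name__ == "__main__":
    # e.g. run as: OI_PAGED_ATTN_BLOCK_SIZE=32 python block_size_sketch.py
    device = torch.device("cpu")
    block_size = resolve_block_size(device)
    print(block_size, num_blocks_for(max_cache_len=512, batch_size=2, block_size=block_size))

Setting OI_PAGED_ATTN_BLOCK_SIZE in the environment overrides the device-dependent default, matching the behavior of PATCH 31/31; leaving it unset keeps 16 on CPU and 64 on XPU.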