diff --git a/optimum/exporters/ipex/cache_utils.py b/optimum/exporters/ipex/cache_utils.py index 91821989f..7154c4449 100755 --- a/optimum/exporters/ipex/cache_utils.py +++ b/optimum/exporters/ipex/cache_utils.py @@ -44,7 +44,7 @@ def __init__( self.batch_size = batch_size # Used in `generate` to keep tally of how many tokens the cache has seen self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device) - self.block_size = 16 + self.block_size = 64 self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape( batch_size, -1 diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py index 03937754a..8c5ef5030 100644 --- a/optimum/exporters/ipex/model_patcher.py +++ b/optimum/exporters/ipex/model_patcher.py @@ -14,7 +14,7 @@ from transformers.models.bert.modeling_bert import BertIntermediate from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel -from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model +from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model from transformers.models.llama.modeling_llama import ( LlamaDecoderLayer, LlamaModel, @@ -27,6 +27,7 @@ from .modeling_utils import ( _IPEX_MINIMUM_VERSION_FOR_PATCHING, + _IPEXGPT2MLP, _falcon_model_forward, _gpt2_block_forward, _gpt2_model_forward, @@ -111,6 +112,7 @@ def _patch_gpt2_model(model): convert_functions(model, GPT2Model, "forward", _gpt2_model_forward) convert_functions(model, GPT2Block, "forward", _gpt2_block_forward) convert_class(model, GPT2Attention, _IPEXGPT2Attention, model.config) + convert_class(model, GPT2MLP, _IPEXGPT2MLP, model.config) return model diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py index ca51c47fb..41dd5693d 100755 --- a/optimum/exporters/ipex/modeling_utils.py +++ b/optimum/exporters/ipex/modeling_utils.py @@ -19,9 +19,12 @@ import torch from torch import nn from transformers.cache_utils import Cache +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask_for_sdpa, +) from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions -from optimum.intel.utils.import_utils import is_ipex_version +from optimum.intel.utils.import_utils import is_ipex_version, is_torch_version from optimum.intel.utils.modeling_utils import _setattr_from_module from .cache_utils import IPEXPagedCache @@ -43,6 +46,7 @@ LinearAdd, LinearAddAdd, LinearGelu, + LinearNewGelu, PagedAttention, ) @@ -194,7 +198,10 @@ def _llama_model_forward( next_decoder_cache = () if use_cache else None position_embeddings = self.rotary_emb(hidden_states, position_ids) - if past_key_values_length == 0: + + input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + + if past_key_values_length == 0 and past_key_values is not None: # first token, remove the padding from hidden_states, varlen do not accept attention mask hidden_states_copy = hidden_states index = attention_mask.view(-1) != 0 @@ -207,7 +214,13 @@ def _llama_model_forward( else: hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + if past_key_values is None: + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask=attention_mask, + input_shape=(input_ids.shape[0], input_ids.shape[-1]), + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: @@ -309,7 +322,9 @@ def _falcon_model_forward( # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) - if past_key_values_length == 0: + input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + + if past_key_values_length == 0 and past_key_values is not None: # first token, remove the padding from hidden_states, varlen do not accept attention mask hidden_states_copy = hidden_states index = attention_mask.view(-1) != 0 @@ -321,7 +336,14 @@ def _falcon_model_forward( position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1)) else: hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + + if past_key_values is None: + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask=attention_mask, + input_shape=(input_ids.shape[0], input_ids.shape[-1]), + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) next_decoder_cache = None all_self_attentions = () if output_attentions else None @@ -436,7 +458,9 @@ def _gpt2_model_forward( hidden_states = self.drop(hidden_states) - if past_length == 0: + input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + + if past_length == 0 and past_key_values is not None: # first token, remove the padding from hidden_states, varlen do not accept attention mask hidden_states_copy = hidden_states index = attention_mask.view(-1) != 0 @@ -444,7 +468,13 @@ def _gpt2_model_forward( else: hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - input_lens = attention_mask.cumsum(-1)[:, -1].to(torch.int32) + if past_key_values is None: + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask=attention_mask, + input_shape=(input_ids.shape[0], input_ids.shape[-1]), + inputs_embeds=inputs_embeds, + past_key_values_length=past_length, + ) presents = None all_self_attentions = () if output_attentions else None @@ -528,7 +558,10 @@ def _gpt2_block_forward( attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] # residual connection - hidden_states = attn_output + residual + if hasattr(self.attn, "linear_add"): + hidden_states = self.attn.linear_add(attn_output, residual) + else: + hidden_states = attn_output + residual if encoder_hidden_states is not None: # add one self-attention block for cross-attention @@ -557,7 +590,10 @@ def _gpt2_block_forward( hidden_states = self.ln_2(hidden_states) feed_forward_hidden_states = self.mlp(hidden_states) # residual connection - hidden_states = residual + feed_forward_hidden_states + if hasattr(self.mlp, "linear_add"): + hidden_states = self.mlp.linear_add(feed_forward_hidden_states, residual) + else: + hidden_states = residual + feed_forward_hidden_states if use_cache: outputs = (hidden_states,) + outputs @@ -577,6 +613,7 @@ def __init__(self, module, config) -> None: self.kv_head_mapping = torch.arange( 0, self.num_key_value_heads, dtype=torch.int32, device=self.module_device ).repeat_interleave(self.num_groups) + self.use_sdpa = False def qkv_gemm(self, hidden_states): raise NotImplementedError("Need to implement in specific model class") @@ -585,34 +622,58 @@ def rope(self, *args, **kwargs): raise NotImplementedError("Need to implement in specific model class") def postprocess_attention_output(self, attn_output): + if self.use_sdpa: + attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) return attn_output - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[IPEXPagedCache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if past_key_value is None and kwargs.get("layer_past", None) is not None: - past_key_value = kwargs.pop("layer_past", None) - input_lens = kwargs.pop("input_lens", None) - past_len = 0 - if past_key_value is not None: - past_len = past_key_value.get_seq_length() - query, key, value = self.qkv_gemm(hidden_states) - query, key = self.rope(query, key, **kwargs) + # Maybe removed after torch 2.6 released + def has_flash_attn(self, query): + if query.device.type == "cpu": + return is_torch_version(">", "2.4.99") + elif query.device.type == "xpu": + return is_torch_version(">", "2.5.99") - if past_key_value is not None: - key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens) - - attn_output = torch.empty_like(query) - if past_len == 0: + def attention_interface( + self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len + ): + if past_key_value is None: + n_rep = query.shape[1] // key.shape[1] + attn_output = torch.nn.functional.scaled_dot_product_attention( + query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2), + key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1]) + .transpose(1, 2) + .repeat_interleave(n_rep, 1), + value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1]) + .transpose(1, 2) + .repeat_interleave(n_rep, 1), + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=True, + ) + self.use_sdpa = True + elif self.has_flash_attn(query): + attn_output = torch.empty_like(query) + seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) + query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int() + query_max_len = input_lens.max() if past_len == 0 else 1 + PagedAttention.flash_attn_varlen_func( + attn_output, + query.contiguous() if query.device.type == "xpu" else query, + key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache, + value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache, + query_len_tensor, + seq_len_tensor, + query_max_len, + input_lens.max(), + 1.0 / math.sqrt(self.head_dim), + True, + past_key_value.block_tables, + None, + ) + elif past_len == 0: # prefill, remove padding + attn_output = torch.empty_like(query) seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int())) varlen_attention( query.contiguous() if query.device.type == "xpu" else query, @@ -632,6 +693,7 @@ def forward( ) else: # decode + attn_output = torch.empty_like(query) PagedAttention.single_query_cached_kv_attention( attn_output, query, @@ -646,6 +708,35 @@ def forward( None, ) + return attn_output + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[IPEXPagedCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if past_key_value is None and kwargs.get("layer_past", None) is not None: + past_key_value = kwargs.pop("layer_past", None) + input_lens = kwargs.pop("input_lens", None) + past_len = 0 + if past_key_value is not None: + past_len = past_key_value.get_seq_length() + query, key, value = self.qkv_gemm(hidden_states) + query, key = self.rope(query, key, **kwargs) + + key_cache, value_cache = None, None + if past_key_value is not None: + key_cache, value_cache = past_key_value.update(key, value, self.layer_idx, attention_mask, input_lens) + + attn_output = self.attention_interface( + query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens, past_len + ) + attn_output = self.postprocess_attention_output(attn_output) if not output_attentions: attn_weights = None @@ -720,9 +811,23 @@ class _IPEXGPT2Attention(_IPEXAttention): def __init__(self, module, config) -> None: self.num_key_value_heads = config.num_key_value_heads super().__init__(module, config) + _setattr_from_module(self, module) + self.c_attn_linear = nn.Linear(self.c_attn.weight.shape[0], self.c_attn.weight.shape[1]) + self.c_attn_linear.weight = nn.Parameter(self.c_attn.weight.t()) + self.c_attn_linear.bias = self.c_attn.bias + self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1]) + self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t()) + self.c_proj_linear.bias = self.c_proj.bias + if self.module_device.type == "cpu": + if self.c_proj_linear not in ["LinearAllreduce"]: + self.linear_add = LinearAdd(self.c_proj_linear) + + elif self.module_device.type == "xpu": + if self.c_proj_linear not in ["LinearAllreduce"]: + self.linear_add = XPULinearAdd(self.c_proj_linear) def qkv_gemm(self, hidden_states): - query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=-1) + query, key, value = self.c_attn_linear(hidden_states).split(self.split_size, dim=-1) query = query.view(-1, self.num_heads, self.head_dim) key = key.view(-1, self.num_heads, self.head_dim) value = value.view(-1, self.num_heads, self.head_dim) @@ -732,9 +837,11 @@ def rope(self, query, key, *args, **kwargs): return query, key def postprocess_attention_output(self, attn_output): + if self.use_sdpa: + attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(-1, attn_output.shape[-2] * attn_output.shape[-1]) - attn_output = self.c_proj(attn_output) - attn_output = self.resid_dropout(attn_output) + if not hasattr(self, "linear_add"): + attn_output = self.c_proj(attn_output) return attn_output @@ -805,6 +912,40 @@ def forward( return output +class _IPEXGPT2MLP(nn.Module): + def __init__(self, module, config) -> None: + super().__init__() + _setattr_from_module(self, module) + self.config = config + self.module_device = next(module.parameters()).device + self.c_fc_linear = nn.Linear(self.c_fc.weight.shape[0], self.c_fc.weight.shape[1]) + self.c_fc_linear.weight = nn.Parameter(self.c_fc.weight.t()) + self.c_fc_linear.bias = self.c_fc.bias + self.c_proj_linear = nn.Linear(self.c_proj.weight.shape[0], self.c_proj.weight.shape[1]) + self.c_proj_linear.weight = nn.Parameter(self.c_proj.weight.t()) + self.c_proj_linear.bias = self.c_proj.bias + if self.module_device.type == "cpu": + self.linear_new_gelu = LinearNewGelu(self.c_fc_linear) + + if self.module_device.type == "cpu": + if self.c_proj_linear not in ["LinearAllreduce"]: + self.linear_add = LinearAdd(self.c_proj_linear) + + elif self.module_device.type == "xpu": + if self.c_proj_linear not in ["LinearAllreduce"]: + self.linear_add = XPULinearAdd(self.c_proj_linear) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + if hasattr(self, "linear_new_gelu"): + hidden_states = self.linear_new_gelu(hidden_states) + else: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + if not hasattr(self, "linear_add"): + hidden_states = self.c_proj(hidden_states) + return hidden_states + + # Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694 class _IPEXLlamaDecoderLayer(nn.Module): def __init__(self, module, config): diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index af36d06f4..3263e31db 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -316,9 +316,9 @@ def prepare_inputs_for_generation(self, *args, **kwargs): return self.model.prepare_inputs_for_generation(*args, **kwargs) def generate(self, *args, **kwargs): - if is_ipex_version("<", "2.4.0") and self._add_patch and kwargs.get("assistant_model", None): + if self._add_patch and kwargs.get("assistant_model", None): raise ValueError( - f"Assisted decoding is not supported for patched models if ipex < 2.4, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}" + f"Assisted decoding is not supported for patched models for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}" ) # Patch functions to support ipex_paged cache if self._add_patch: