From 3ee174ca9e8a5bf904b2ef3acb1634d3614e6f3d Mon Sep 17 00:00:00 2001
From: Ekaterina Aidova
Date: Mon, 29 Jul 2024 22:29:55 +0400
Subject: [PATCH] Fix update causal mask for transformers 4.42 (#852)

* fix update causal mask for transformers 4.42

* more models

* revert rope for phi3

* fix phi3

* phi3 issue
---
 optimum/exporters/openvino/model_configs.py |  39 +++++
 optimum/exporters/openvino/model_patcher.py | 170 ++++++++++++++++++--
 2 files changed, 199 insertions(+), 10 deletions(-)

diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index c968e92b2c..b8aed025b1 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -60,6 +60,7 @@
     PersimmonModelPatcher,
     Phi3ModelPatcher,
     QwenModelPatcher,
+    UpdateCausalMaskModelPatcher,
     XverseModelPatcher,
 )
 
@@ -119,6 +120,11 @@ class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return UpdateCausalMaskModelPatcher(self, model, model_kwargs=model_kwargs)
+
 
 @register_in_tasks_manager("qwen2-moe", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class Qwen2MoEOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
@@ -128,6 +134,11 @@ class Qwen2MoEOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return UpdateCausalMaskModelPatcher(self, model, model_kwargs=model_kwargs)
+
 
 @register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
@@ -146,6 +157,11 @@ class StableLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return UpdateCausalMaskModelPatcher(self, model, model_kwargs=model_kwargs)
+
 
 class ChatGLM2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
     def __init__(
@@ -468,6 +484,11 @@ class Starcoder2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return UpdateCausalMaskModelPatcher(self, model, model_kwargs=model_kwargs)
+
 
 @register_in_tasks_manager("internlm2", *["text-generation", "text-generation-with-past"], library_name="transformers")
 class InternLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
@@ -532,6 +553,24 @@ def patch_model_for_export(
         return Phi3ModelPatcher(self, model, model_kwargs=model_kwargs)
 
 
+@register_in_tasks_manager(
+    "phi",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class PhiOpenVINOConfig(PhiOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return UpdateCausalMaskModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
 class OVFalconDummyPastKeyValuesGenerator(FalconDummyPastKeyValuesGenerator):
     def __init__(
         self,
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 377a0fbf43..56d5094b6b 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -101,6 +101,12 @@ def patch_model_with_bettertransformer(model):
     return model
 
 
+def patch_update_causal_mask(model, transformers_version):
+    if is_transformers_version(">=", transformers_version):
+        model.model._orig_update_causal_mask = model.model._update_causal_mask
+        model.model._update_causal_mask = types.MethodType(_llama_gemma_update_causal_mask, model.model)
+
+
 def _mixtral_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
     """ """
     batch_size, sequence_length, hidden_dim = hidden_states.shape
@@ -144,6 +149,8 @@ def _mixtral_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torc
 class MixtralModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
+        patch_update_causal_mask(self._model, "4.42.0")
+
         for layer in self._model.model.layers:
             layer.block_sparse_moe._unpatched_forward = layer.block_sparse_moe.forward
             layer.block_sparse_moe.forward = types.MethodType(
@@ -152,6 +159,9 @@ def __enter__(self):
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
+        if hasattr(self._model.model, "_orig_update_causal_mask"):
+            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
+
         for layer in self._model.model.layers:
             layer.block_sparse_moe.forward = layer.block_sparse_moe._unpatched_forward
 
@@ -549,11 +559,9 @@ def __enter__(self):
 
         # llama/gemma has some accuracy issues with bf16 with transformers >= 4.39
         # fill causal mask in slightly different way for avoid overflow on some platforms
+        patch_update_causal_mask(self._model, "4.39.0")
+
         if is_transformers_version(">=", "4.39.0"):
-            self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
-            self._model.model._update_causal_mask = types.MethodType(
-                _llama_gemma_update_causal_mask, self._model.model
-            )
             register_sin_cos_buffer(self._model)
 
     def __exit__(self, exc_type, exc_value, traceback):
@@ -620,7 +628,7 @@ def _mistral_update_causal_mask(
         return None
 
     dtype, device = input_tensor.dtype, input_tensor.device
-    min_dtype = torch.finfo(dtype).min
+    min_dtype = torch.finfo(torch.float16).min
     sequence_length = input_tensor.shape[1]
     # SlidingWindowCache
     if using_sliding_window_cache:
@@ -1328,6 +1336,128 @@ def __exit__(self, exc_type, exc_value, traceback):
             block.attention.forward = block.attention._orig_forward
 
 
+def phi3_442_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    from transformers.cache_utils import Cache, DynamicCache
+    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # retrieve input_ids and inputs_embeds
+    if input_ids is not None and inputs_embeds is not None:
+        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+    elif input_ids is not None:
+        batch_size, seq_length = input_ids.shape[:2]
+    elif inputs_embeds is not None:
+        batch_size, seq_length = inputs_embeds.shape[:2]
+    else:
+        raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+    past_key_values_length = 0
+
+    if use_cache:
+        use_legacy_cache = not isinstance(past_key_values, Cache)
+        if use_legacy_cache:
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+        past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+    if position_ids is None:
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+        )
+        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+    else:
+        position_ids = position_ids.view(-1, seq_length).long()
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
+        is_padding_right = attention_mask[:, -1].sum().item() != batch_size
+        if is_padding_right:
+            raise ValueError(
+                "You are attempting to perform batched generation with padding_side='right'"
+                " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to "
+                " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+            )
+
+    if self._attn_implementation == "flash_attention_2":
+        # 2d mask is passed through the layers
+        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+    else:
+        # 4d mask is passed through the layers
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask,
+            (batch_size, seq_length),
+            inputs_embeds,
+            past_key_values_length,
+            sliding_window=self.config.sliding_window,
+        )
+
+    hidden_states = inputs_embeds
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = None
+
+    for decoder_layer in self.layers:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        layer_outputs = decoder_layer(
+            hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = None
+    if use_cache:
+        next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/ccdabc5642bf84849af93f591e207dc625c8e1e1/src/transformers/models/phi3/modeling_phi3.py#L729
 def _phi3_self_attn_sdpa_forward(
     self,
@@ -1373,7 +1503,7 @@ def _phi3_self_attn_sdpa_forward(
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
     if past_key_value is not None:
-        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
+        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
         key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
     key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -1411,6 +1541,11 @@ def _phi3_self_attn_sdpa_forward(
 class Phi3ModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
+
+        if is_transformers_version(">=", "4.42.0"):
+            self._model.model._orig_forward = self._model.model.forward
+            self._model.model.forward = types.MethodType(phi3_442_forward, self._model.model)
+
         # https://github.com/huggingface/transformers/blob/30ee508c6c92a1c0aa0281d193c7c0fb815b8d2f/src/transformers/models/phi3/modeling_phi3.py#L113
         # init inv_freq for torchscript tracing
         for layer in self._model.model.layers:
@@ -1425,15 +1560,15 @@ def __enter__(self):
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )
 
-        # phi3 has issue with bf16 inference, precollect sin/cos for rotary_position_embedding for avoid accuracy issues
-        register_sin_cos_buffer(self._model)
-
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
+        if hasattr(self._model.model, "_orig_forward"):
+            self._model.model.forward = self._model.model._orig_forward
+        if hasattr(self._model.model, "_orig_update_causal_mask"):
+            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
         for layer in self._model.model.layers:
             if hasattr(layer.self_attn, "_orig_forward"):
                 layer.self_attn.forward = layer.self_attn._orig_forward
-            layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward
 
 
 def _aquila_self_attn_sdpa_forward(
@@ -2089,6 +2224,8 @@ def _persimmon_self_attn_sdpa_forward(
 class PersimmonModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
+        patch_update_causal_mask(self._model, "4.42.0")
+
         for layer in self._model.model.layers:
             if is_torch_version(">=", "2.1.0"):
                 orig_self_attn_fwd = layer.self_attn.forward
@@ -2097,6 +2234,8 @@ def __enter__(self):
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
+        if hasattr(self._model.model, "_orig_update_causal_mask"):
+            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
         for layer in self._model.model.layers:
             if hasattr(layer.self_attn, "_orig_forward"):
                 layer.self_attn.forward = layer.self_attn._orig_forward
@@ -2221,3 +2360,14 @@ def __exit__(self, exc_type, exc_value, traceback):
         if hasattr(layer.attn, "_orig_attn"):
             layer.attn._attn = layer.attn._orig_attn
             layer.attn.forward = layer.attn._orig_forward
+
+
+class UpdateCausalMaskModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        patch_update_causal_mask(self._model, "4.42.0")
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if hasattr(self._model.model, "_orig_update_causal_mask"):
+            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
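
Usage sketch (not part of the patch): how the new patcher is driven in practice. It assumes `export_config` is one of the OpenVINO export configs touched above (for example Qwen2OpenVINOConfig) and that `model` and `dummy_inputs` are a loaded transformers model and matching example inputs; those two names are illustrative only.

    from optimum.exporters.openvino.model_patcher import UpdateCausalMaskModelPatcher

    # patch_model_for_export, overridden above for qwen2/qwen2-moe/stablelm/starcoder2/phi,
    # hands back the patcher for this architecture
    patcher = export_config.patch_model_for_export(model)
    assert isinstance(patcher, UpdateCausalMaskModelPatcher)

    with patcher:
        # while the context is active, and only for transformers >= 4.42, the inner model's
        # _update_causal_mask is rebound to _llama_gemma_update_causal_mask, which builds the
        # mask with a float16-safe fill value to avoid the bf16 overflow described above
        outputs = model(**dummy_inputs)
    # on __exit__ the original method is restored from _orig_update_causal_mask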