diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 29ef7db892..152e4db4ee 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -109,13 +109,14 @@ def patch_model_with_bettertransformer(model):
     return model
 
 
-def patch_update_causal_mask(model, transformers_version, inner_model_name="model"):
+def patch_update_causal_mask(model, transformers_version, inner_model_name="model", patch_fn=None):
     if is_transformers_version(">=", transformers_version):
         inner_model = getattr(model, inner_model_name, None)
         if inner_model is not None:
             if hasattr(inner_model, "_update_causal_mask"):
                 inner_model._orig_update_causal_mask = inner_model._update_causal_mask
-                inner_model._update_causal_mask = types.MethodType(_llama_gemma_update_causal_mask, inner_model)
+                patch_fn = patch_fn or _llama_gemma_update_causal_mask
+                inner_model._update_causal_mask = types.MethodType(patch_fn, inner_model)
 
 
 def unpatch_update_causal_mask(model, inner_model_name="model"):
@@ -2431,6 +2432,107 @@ def __enter__(self):
                 _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
 
 
+def _falcon_update_causal_mask(
+    self,
+    attention_mask: torch.Tensor,
+    input_tensor: torch.Tensor,
+    cache_position: torch.Tensor,
+    past_key_values: "Cache",
+    output_attentions: bool,
+    head_mask: torch.Tensor,
+    alibi: torch.Tensor,
+):
+    # copied from https://github.com/huggingface/transformers/blob/a30c865f991dfec9452cc64bd9a97bfbb96be036/src/transformers/models/falcon/modeling_falcon.py#L1130
+    from transformers.cache_utils import StaticCache
+    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+    if hasattr(self, "_prepare_4d_causal_attention_mask_with_cache_position"):
+        _prepare_4d_causal_attention_mask_with_cache_position = (
+            self._prepare_4d_causal_attention_mask_with_cache_position
+        )
+    else:
+        from transformers.models.falcon.modeling_falcon import _prepare_4d_causal_attention_mask_with_cache_position
+
+    if self.config._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and 0.0 in attention_mask:
+            return attention_mask
+        return None
+
+    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+    # to infer the attention mask.
+    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+    using_static_cache = isinstance(past_key_values, StaticCache)
+
+    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+    if (
+        self.config._attn_implementation == "sdpa"
+        and not using_static_cache
+        and not output_attentions
+        and head_mask is None
+        and alibi is None
+    ):
+        if AttentionMaskConverter._ignore_causal_mask_sdpa(
+            attention_mask,
+            inputs_embeds=input_tensor,
+            past_key_values_length=past_seen_tokens,
+            is_training=self.training,
+        ):
+            return None
+
+    dtype, device = input_tensor.dtype, input_tensor.device
+    # difference from the original: use torch.finfo(torch.float16).min instead of torch.finfo(dtype).min to prevent overflow during fp16/bf16 execution
+    min_dtype = torch.finfo(torch.float16).min
+    batch_size, sequence_length, _ = input_tensor.shape
+    if using_static_cache:
+        target_length = past_key_values.get_max_length()
+    else:
+        target_length = (
+            attention_mask.shape[-1]
+            if isinstance(attention_mask, torch.Tensor)
+            else past_seen_tokens + sequence_length
+        )
+
+    # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+    causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask,
+        sequence_length=sequence_length,
+        target_length=target_length,
+        dtype=dtype,
+        device=device,
+        min_dtype=min_dtype,
+        cache_position=cache_position,
+        batch_size=input_tensor.shape[0],
+    )
+
+    # We take care to integrate alibi bias in the causal_mask here
+    if head_mask is None and alibi is not None:
+        alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
+        causal_mask = torch.masked_fill(
+            alibi / math.sqrt(self.config.hidden_size // self.num_heads),
+            causal_mask < -1,
+            min_dtype,
+        )
+
+    if (
+        self.config._attn_implementation == "sdpa"
+        and attention_mask is not None
+        and attention_mask.device.type == "cuda"
+        and not output_attentions
+    ):
+        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        # Details: https://github.com/pytorch/pytorch/issues/110213
+        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+    return causal_mask
+
+
 class FalconModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
@@ -2438,7 +2540,7 @@ def __enter__(self):
             for layer in self._model.transformer.h:
                 _reinitialize_cos_sin_cached_fp32(layer.self_attention.rotary_emb)
         else:
-            patch_update_causal_mask(self._model, "4.45.0", "transformer")
+            patch_update_causal_mask(self._model, "4.45.0", "transformer", _falcon_update_causal_mask)
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
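
Note on the one intentional deviation from the upstream Falcon implementation: the mask fill value is hard-coded to torch.finfo(torch.float16).min rather than torch.finfo(dtype).min, because the float32 minimum does not survive a later cast of the exported model to half precision. A minimal illustrative sketch of the failure mode (not part of the patch):

import torch

fp32_min = torch.finfo(torch.float32).min          # roughly -3.4e38
print(torch.tensor(fp32_min).to(torch.float16))    # tensor(-inf, dtype=torch.float16) -- can propagate NaNs through softmax
print(torch.finfo(torch.float16).min)              # -65504.0, still finite after the cast

Using the float16 minimum keeps the additive mask finite under fp16/bf16 execution while remaining large enough in magnitude to zero out masked positions after softmax.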