From 936d2729a1054f02d1ca927d2ed5fb03c19d8b2e Mon Sep 17 00:00:00 2001
From: Ekaterina Aidova
Date: Mon, 28 Oct 2024 20:43:21 +0400
Subject: [PATCH] Restore SDPA in Gemma2 models for transformers > 4.45 (#976)

* Restore SDPA in Gemma2 models for transformers > 4.45

* Update tests/openvino/test_modeling.py

* Update tests/openvino/test_modeling.py
---
 optimum/exporters/openvino/model_patcher.py | 20 ++++++++++++++++++++
 tests/openvino/test_modeling.py             |  8 ++++++++
 2 files changed, 28 insertions(+)

diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 3bc9452ff9..7e5cd76a76 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -2505,6 +2505,26 @@ def patched_forward(*args, **kwargs):
 
         self.patched_forward = patched_forward
 
+    def __enter__(self):
+        super().__enter__()
+        if is_transformers_version(">=", "4.45.0"):
+            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES
+
+            sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"]
+            eager_attn = GEMMA2_ATTENTION_CLASSES["eager"]
+
+            for layer in self._model.model.layers:
+                if isinstance(layer.self_attn, eager_attn):
+                    layer.self_attn._orig_forward = layer.self_attn.forward
+                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if is_transformers_version(">=", "4.45.0"):
+            for layer in self._model.model.layers:
+                if hasattr(layer.self_attn, "_orig_forward"):
+                    layer.self_attn.forward = layer.self_attn._orig_forward
+
 
 def _decilm_attn_forward(
     self,
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 119e004035..082ffef285 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -863,6 +863,10 @@ def test_compare_to_transformers(self, model_arch):
         if model_arch in self.REMOTE_CODE_MODELS:
             model_kwargs = {"trust_remote_code": True}
 
+        # starting from transformers 4.45.0, gemma2 uses eager attention by default, while the OpenVINO export uses sdpa
+        if model_arch == "gemma2" and is_transformers_version(">=", "4.45.0"):
+            model_kwargs["attn_implementation"] = "sdpa"
+
         ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG, **model_kwargs)
         self.assertIsInstance(ov_model.config, PretrainedConfig)
         self.assertTrue(ov_model.use_cache)
@@ -1094,6 +1098,10 @@ def test_beam_search(self, model_arch):
                 "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True),
                 "trust_remote_code": True,
             }
+
+        # starting from transformers 4.45.0, gemma2 uses eager attention by default, while the OpenVINO export uses sdpa
+        if model_arch == "gemma2" and is_transformers_version(">=", "4.45.0"):
+            model_kwargs["attn_implementation"] = "sdpa"
         # Qwen tokenizer does not support padding, chatglm, glm4 testing models produce nan that incompatible with beam search
         if model_arch in ["qwen", "chatglm", "glm4"]:
             return
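
Usage note (editorial addition, not part of the patch): the sketch below illustrates the behaviour the test changes rely on. For Gemma2 on transformers >= 4.45, the PyTorch reference model is loaded with attn_implementation="sdpa" so that its outputs are comparable with the SDPA-patched OpenVINO export. The tiny checkpoint id is a placeholder, and the version check uses packaging directly instead of the internal is_transformers_version helper used in the patch; treat this as an assumption-laden example rather than code from optimum-intel.

# Minimal sketch of the test-side behaviour; model_id is a hypothetical placeholder.
import transformers
from packaging import version
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.intel import OVModelForCausalLM

model_id = "hf-internal-testing/tiny-random-Gemma2ForCausalLM"  # placeholder id

model_kwargs = {}
if version.parse(transformers.__version__) >= version.parse("4.45.0"):
    # transformers >= 4.45 defaults Gemma2 to eager attention, while the exported
    # OpenVINO model runs the SDPA-patched attention, so the PyTorch reference
    # must also use SDPA for the outputs to be comparable.
    model_kwargs["attn_implementation"] = "sdpa"

ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True)
ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello world", return_tensors="pt")

ov_logits = ov_model(**inputs).logits
ref_logits = ref_model(**inputs).logits
# The two logits tensors are expected to be close (the tests use an allclose-style check).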