diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index 4b1dbb50b..3a389aa79 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -87,6 +87,7 @@
     InputEmbeddingPatcher,
     InternLM2Patcher,
     InternLMModelPatcher,
+    InternVL2ChatLangModelPatcher,
     InternVLChatImageEmbeddingModelPatcher,
     JaisModelPatcher,
     LlamaModelPatcher,
@@ -1642,7 +1643,11 @@ def with_behavior(
         if behavior == InternVLChatConfigBehavior.LANGUAGE:
             model_type = self._orig_config.llm_config.model_type
             return get_vlm_text_generation_config(
-                model_type, self._orig_config.llm_config, self.int_dtype, self.float_dtype
+                model_type,
+                self._orig_config.llm_config,
+                self.int_dtype,
+                self.float_dtype,
+                InternVL2ChatLangModelPatcher,
             )
 
         if behavior == InternVLChatConfigBehavior.VISION_EMBEDDINGS:
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index e7a777938..fdb2d965b 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -21,6 +21,7 @@
 
 import torch
 import torch.nn.functional as F
+from transformers import PreTrainedModel, TFPreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
 from transformers.utils import is_tf_available
 
@@ -2992,11 +2993,91 @@ def __init__(
     ):
         model.__orig_forward = model.forward
         model.forward = model.extract_feature
+        if model.vision_model.encoder.layers[0].attn.use_flash_attn:
+            for layer in model.vision_model.encoder.layers:
+                layer.attn._orig_use_flash_attn = layer.attn.use_flash_attn
+                layer.attn.use_flash_attn = False
+
         super().__init__(config, model, model_kwargs)
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         self._model.forward = self._model.__orig_forward
+        if hasattr(self._model.vision_model.encoder.layers[0].attn, "_orig_use_flash_attn"):
+            for layer in self._model.vision_model.encoder.layers:
+                layer.attn.use_flash_attn = layer.attn._orig_use_flash_attn
+
+
+class InternVL2ChatLangModelPatcher(DecoderModelPatcher):
+    def __init__(
+        self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any]
+    ):
+        model_type = model.config.model_type
+        patcher_for_model_type = {
+            "llama": LlamaModelPatcher,
+            "qwen2": UpdateCausalMaskModelPatcher,
+            "phi3": Phi3ModelPatcher,
+            "internlm2": InternLM2Patcher,
+        }
+        self._internal_patcher = None
+        self._patched_forward = None
+        internal_patcher_cls = patcher_for_model_type.get(model_type)
+        if internal_patcher_cls is not None:
+            self._internal_patcher = internal_patcher_cls(config, model, model_kwargs)
+            self._patched_forward = self._internal_patcher.patched_forward
+        super().__init__(config, model, model_kwargs)
+
+    @property
+    def patched_forward(self):
+        if self._internal_patcher is not None:
+            return self._internal_patcher.patched_forward
+        return self._patched_forward
+
+    @patched_forward.setter
+    def patched_forward(self, fn):
+        self._patched_forward = fn
+        if self._internal_patcher is not None:
+            self._internal_patcher.patched_forward = fn
+
+    def __enter__(self):
+        if is_torch_version(">=", "2.1.0"):
+            if self._model.config.model_type == "qwen2" and self._model.config._attn_implementation != "sdpa":
+                from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES
+
+                sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
+                self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+                self._model.config._attn_implementation = "sdpa"
+
+                for layer in self._model.model.layers:
+                    layer.self_attn._orig_forward = layer.self_attn.forward
+                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+
+            if self._model.config.model_type == "llama" and self._model.config._attn_implementation != "sdpa":
+                self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+                self._model.config._attn_implementation = "sdpa"
+                if is_transformers_version("<", "4.47"):
+                    from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
+
+                    sdpa_attn = LLAMA_ATTENTION_CLASSES["sdpa"]
+                    for layer in self._model.model.layers:
+                        layer.self_attn._orig_forward = layer.self_attn.forward
+                        layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+
+        if self._internal_patcher is not None:
+            return self._internal_patcher.__enter__()
+        return super().__enter__()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if self._internal_patcher:
+            self._internal_patcher.__exit__(exc_type, exc_value, traceback)
+        else:
+            super().__exit__(exc_type, exc_value, traceback)
+
+        if hasattr(self._model.config, "_orig_attn_implementation"):
+            self._model.config._attn_implementation = self._model.config._orig_attn_implementation
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn, "_orig_forward"):
+                layer.self_attn.forward = layer.self_attn._orig_forward
 
 
 def llava_vision_embed_forward(self, pixel_values):