disable flash_attn during export internvl2 #1105

Open · wants to merge 1 commit into main
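
This patch disables flash attention while exporting InternVL2 to OpenVINO: InternVLChatImageEmbeddingModelPatcher now switches off use_flash_attn on every vision encoder layer (and restores the original flags on exit), and a new InternVL2ChatLangModelPatcher forces the language model onto the SDPA attention implementation, delegating the remaining patching to the matching llama/qwen2/phi3/internlm2 patcher. The save-disable-restore pattern it relies on, as a minimal standalone sketch (VisionLayer and disable_flash_attn are illustrative stand-ins, not part of this change):

# Minimal sketch of the toggle-and-restore pattern used by the patchers in this PR.
# VisionLayer is a stand-in for the InternVL2 vision encoder layers.
from contextlib import contextmanager


class VisionLayer:
    def __init__(self):
        self.use_flash_attn = True  # flash attention enabled by default


@contextmanager
def disable_flash_attn(layers):
    """Temporarily switch flash attention off, restoring the original flags on exit."""
    originals = [layer.use_flash_attn for layer in layers]
    try:
        for layer in layers:
            layer.use_flash_attn = False  # force the non-flash code path during tracing
        yield layers
    finally:
        for layer, orig in zip(layers, originals):
            layer.use_flash_attn = orig  # put the model back the way we found it


layers = [VisionLayer() for _ in range(3)]
with disable_flash_attn(layers):
    assert all(not layer.use_flash_attn for layer in layers)
assert all(layer.use_flash_attn for layer in layers)
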
7 changes: 6 additions & 1 deletion optimum/exporters/openvino/model_configs.py
@@ -87,6 +87,7 @@
InputEmbeddingPatcher,
InternLM2Patcher,
InternLMModelPatcher,
InternVL2ChatLangModelPatcher,
InternVLChatImageEmbeddingModelPatcher,
JaisModelPatcher,
LlamaModelPatcher,
@@ -1642,7 +1643,11 @@ def with_behavior(
if behavior == InternVLChatConfigBehavior.LANGUAGE:
model_type = self._orig_config.llm_config.model_type
return get_vlm_text_generation_config(
model_type, self._orig_config.llm_config, self.int_dtype, self.float_dtype
model_type,
self._orig_config.llm_config,
self.int_dtype,
self.float_dtype,
InternVL2ChatLangModelPatcher,
)

if behavior == InternVLChatConfigBehavior.VISION_EMBEDDINGS:
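
The config-side change only forwards InternVL2ChatLangModelPatcher to get_vlm_text_generation_config for the LANGUAGE behavior; the helper is assumed to accept an optional patcher class and use it instead of the per-architecture default. A rough sketch of that dispatch-with-override idea (DEFAULT_PATCHERS and make_language_export_config are illustrative names, not the actual optimum API; the import assumes this PR is applied):

# Illustrative only: pick a default patcher from the language model's type,
# unless an explicit patcher class (here InternVL2ChatLangModelPatcher) is given.
from optimum.exporters.openvino.model_patcher import (
    DecoderModelPatcher,
    InternLM2Patcher,
    InternVL2ChatLangModelPatcher,
    LlamaModelPatcher,
    Phi3ModelPatcher,
    UpdateCausalMaskModelPatcher,
)

DEFAULT_PATCHERS = {
    "llama": LlamaModelPatcher,
    "qwen2": UpdateCausalMaskModelPatcher,
    "phi3": Phi3ModelPatcher,
    "internlm2": InternLM2Patcher,
}


def make_language_export_config(model_type, llm_config, int_dtype, float_dtype, model_patcher=None):
    # Explicit patcher wins; otherwise fall back to the architecture default.
    patcher_cls = model_patcher or DEFAULT_PATCHERS.get(model_type, DecoderModelPatcher)
    return {
        "config": llm_config,
        "int_dtype": int_dtype,
        "float_dtype": float_dtype,
        "patcher_cls": patcher_cls,
    }
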
81 changes: 81 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -21,6 +21,7 @@

import torch
import torch.nn.functional as F
from transformers import PreTrainedModel, TFPreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
from transformers.utils import is_tf_available

@@ -2992,11 +2993,91 @@ def __init__(
model.__orig_forward = model.forward
model.forward = model.extract_feature

if model.vision_model.encoder.layers[0].attn.use_flash_attn:
for layer in model.vision_model.encoder.layers:
layer.attn._orig_use_flash_attn = layer.attn.use_flash_attn
layer.attn.use_flash_attn = False

super().__init__(config, model, model_kwargs)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
if hasattr(self._model.vision_model.encoder.layers[0].attn, "_orig_use_flash_attn"):
for layer in self._model.vision_model.encoder.layers:
layer.attn.use_flash_attn = layer.attn._orig_use_flash_attn


class InternVL2ChatLangModelPatcher(DecoderModelPatcher):
def __init__(
self, config: "OnnxConfig", model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Dict[str, Any]
):
model_type = model.config.model_type
patcher_for_model_type = {
"llama": LlamaModelPatcher,
"qwen2": UpdateCausalMaskModelPatcher,
"phi3": Phi3ModelPatcher,
"internlm2": InternLM2Patcher,
}
self._internal_patcher = None
self._patched_forward = None
internal_patcher_cls = patcher_for_model_type.get(model_type)
if internal_patcher_cls is not None:
self._internal_patcher = internal_patcher_cls(config, model, model_kwargs)
self._patched_forward = self._internal_patcher.patched_forward
super().__init__(config, model, model_kwargs)

@property
def patched_forward(self):
if self._internal_patcher is not None:
return self._internal_patcher.patched_forward
return self._patched_forward

@patched_forward.setter
def patched_forward(self, fn):
self._patched_forward = fn
if self._internal_patcher is not None:
self._internal_patcher.patched_forward = fn

def __enter__(self):
if is_torch_version(">=", "2.1.0"):
if self._model.config.model_type == "qwen2" and self._model.config._attn_implementation != "sdpa":
from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES

sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
self._model.config._orig_attn_implementation = self._model.config._attn_implementation
self._model.config._attn_implementation = "sdpa"

for layer in self._model.model.layers:
layer.self_attn._orig_forward = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)

if self._model.config.model_type == "llama" and self._model.config._attn_implementation != "sdpa":
self._model.config._orig_attn_implementation = self._model.config._attn_implementation
self._model.config._attn_implementation = "sdpa"
if is_transformers_version("<", "4.47"):
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES

sdpa_attn = LLAMA_ATTENTION_CLASSES["sdpa"]
for layer in self._model.model.layers:
layer.self_attn._orig_forward = layer.self_attn.forward
layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)

if self._internal_patcher is not None:
return self._internal_patcher.__enter__()
return super().__enter__()

def __exit__(self, exc_type, exc_value, traceback):
if self._internal_patcher:
self._internal_patcher.__exit__(exc_type, exc_value, traceback)
else:
super().__exit__(exc_type, exc_value, traceback)

if hasattr(self._model.config, "_orig_attn_implementation"):
self._model.config._attn_implementation = self._model.config._orig_attn_implementation
for layer in self._model.model.layers:
if hasattr(layer.self_attn, "_orig_forward"):
layer.self_attn.forward = layer.self_attn._orig_forward


def llava_vision_embed_forward(self, pixel_values):
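
For context, these patchers follow the usual ModelPatcher context-manager protocol: the attention overrides are applied in __enter__ and reverted in __exit__, so only the traced graph sees SDPA instead of flash attention while the in-memory model is left untouched. A rough usage illustration (export_config, language_model, example_inputs and convert_to_openvino are placeholders, not actual API):

# Rough usage sketch, not part of the PR.
patcher = InternVL2ChatLangModelPatcher(export_config, language_model, model_kwargs={})
with patcher:
    # Inside the context, qwen2/llama layers run their SDPA forward and the
    # internal per-architecture patcher (if any) is active as well.
    ov_model = convert_to_openvino(language_model, example_inputs)  # placeholder for tracing + OV conversion
# After the block, _attn_implementation and the original per-layer forwards are restored.
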