Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for the nanoLLaVA model
Browse files Browse the repository at this point in the history
eaidova committed Oct 24, 2024

Verified

This commit was signed with the committer’s verified signature.
W-Mai Benign X
1 parent 86598a6 commit b69a83d
Showing 6 changed files with 434 additions and 33 deletions.
162 changes: 161 additions & 1 deletion optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from packaging import version
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel
from transformers.utils import is_tf_available

from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
@@ -75,6 +75,7 @@
JaisModelPatcher,
LlamaModelPatcher,
LlavaImageEmbeddingModelPatcher,
LlavaQwen2ImageEmbeddingsModelPatcher,
MistralModelPatcher,
MixtralModelPatcher,
MPTModelPatcher,
@@ -1577,6 +1578,165 @@ def patch_model_for_export(
return InternVLChatImageEmbeddingModelPatcher(self, model, model_kwargs)


@register_in_tasks_manager(
    "llava-qwen2", *["image-text-to-text", "text-generation", "text-generation-with-past"], library_name="transformers"
)
class LlavaQwen2OpenVINOConfig(OnnxConfig):
    """Export config for ``llava-qwen2`` models, split by `LlavaConfigBehavior`.

    The config itself only describes the vision-embeddings part (``pixel_values`` in,
    ``last_hidden_state`` out). For the text-embeddings and language parts,
    `with_behavior` returns a config built on top of the export config registered in
    `TasksManager` for the underlying language model type.
    """

    SUPPORTS_PAST = True
    MIN_TRANSFORMERS_VERSION = version.parse("4.40.0")
    SUPPORTED_BEHAVIORS = [model_type.value for model_type in LlavaConfigBehavior]
    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,)

    def __init__(
        self,
        config: "PretrainedConfig",
        task: str = "feature-extraction",
        int_dtype: str = "int64",
        float_dtype: str = "fp32",
        behavior: LlavaConfigBehavior = LlavaConfigBehavior.VISION_EMBEDDINGS,
        preprocessors: Optional[List[Any]] = None,
        use_past: bool = False,
    ):
        """
        Args:
            config ([`PretrainedConfig`]):
                Config of the full multimodal model. It is kept as ``self._orig_config``.
            task (`str`):
                Task name forwarded to the parent config.
            int_dtype (`str`):
                Integer dtype forwarded to the parent config.
            float_dtype (`str`):
                Float dtype forwarded to the parent config.
            behavior ([`LlavaConfigBehavior`]):
                Which part of the model this config describes.
            preprocessors (`Optional[List[Any]]`):
                Preprocessors forwarded to the parent config.
            use_past (`bool`):
                Accepted in the signature but not used by this constructor.
        """
        self._behavior = behavior
        self._orig_config = config
        if self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
            # The vision tower config is loaded from the reference stored in `mm_vision_tower`.
            # NOTE(review): `trust_remote_code=True` lets the referenced repository execute its own
            # Python code while loading - confirm this is acceptable for the sources being exported.
            config = AutoConfig.from_pretrained(config.mm_vision_tower, trust_remote_code=True)
            if hasattr(config, "vision_config"):
                config = config.vision_config
        super().__init__(
            config=config,
            task=task,
            int_dtype=int_dtype,
            float_dtype=float_dtype,
            preprocessors=preprocessors,
        )

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Dynamic axes of the vision-embeddings inputs; empty for every other behavior."""
        if self._behavior != LlavaConfigBehavior.VISION_EMBEDDINGS:
            return {}
        return {"pixel_values": {0: "batch_size", 2: "height", 3: "width"}}

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Dynamic axes of the vision-embeddings outputs; empty for every other behavior."""
        if self._behavior != LlavaConfigBehavior.VISION_EMBEDDINGS:
            return {}
        return {"last_hidden_state": {0: "batch_size"}}

    def get_model_for_behavior(self, model, behavior: Union[str, LlavaConfigBehavior]):
        """
        Returns the (sub)model that has to be exported for the given behavior.

        Args:
            model:
                The full multimodal model.
            behavior (`Union[str, LlavaConfigBehavior]`):
                The part of the model to export.
        """
        if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior):
            behavior = LlavaConfigBehavior(behavior)

        if behavior == LlavaConfigBehavior.LANGUAGE:
            # Rebind `forward` to the implementation found after the model's own class in the MRO,
            # bypassing the override defined on `type(model)` itself.
            model.forward = super(type(model), model).forward
            return model

        if behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
            return model

        if behavior == LlavaConfigBehavior.TEXT_EMBEDDINGS:
            text_embedding = model.model.embed_tokens
            # The embedding layer is exported on its own, so it gets the inner model's config attached.
            text_embedding.config = model.model.config
            return text_embedding

    def _get_language_export_config(self):
        """
        Builds the ``text-generation-with-past`` export config of the underlying language model.

        The language model type is derived from the original config's ``model_type`` by dropping
        the ``llava-`` prefix and replacing underscores with dashes.

        Raises:
            ValueError: If the language model type is unknown to `TasksManager`, or no OpenVINO
                ``text-generation-with-past`` export config is registered for it.
        """
        model_type = self._orig_config.model_type.replace("llava-", "").replace("_", "-")
        if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
            raise ValueError(
                f"Unsupported language model type provided `{model_type}`. Please define custom export config"
            )

        # `.get` keeps a model type that is registered only for other backends on the ValueError path
        # below instead of failing with a bare KeyError.
        openvino_configs = TasksManager._SUPPORTED_MODEL_TYPE[model_type].get("openvino", {})
        if "text-generation-with-past" not in openvino_configs:
            raise ValueError(
                f"Export config for text generation for `{model_type}` is not available. Please define custom export config"
            )

        internal_export_config_class = openvino_configs["text-generation-with-past"]
        return internal_export_config_class(
            self._orig_config,
            use_past=True,
            use_past_in_inputs=True,
            int_dtype=self.int_dtype,
            float_dtype=self.float_dtype,
        )

    def with_behavior(
        self,
        behavior: Union[str, LlavaConfigBehavior],
    ):
        """
        Creates a config for a different behavior.

        Args:
            behavior ([`LlavaConfigBehavior`]):
                The behavior to use for the new instance.

        Raises:
            ValueError: For the text-embeddings and language behaviors, if no suitable export config
                is registered for the underlying language model type.
        """
        if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior):
            behavior = LlavaConfigBehavior(behavior)

        if behavior == LlavaConfigBehavior.TEXT_EMBEDDINGS:
            internal_export_config = self._get_language_export_config()
            # NOTE: this assigns a class attribute, so it affects every `InputEmbedOpenvVINOConfig`
            # created afterwards, not only the instance returned here.
            InputEmbedOpenvVINOConfig.NORMALIZED_CONFIG_CLASS = internal_export_config.NORMALIZED_CONFIG_CLASS
            export_config = InputEmbedOpenvVINOConfig(
                self._orig_config,
                task="feature-extraction",
                int_dtype=self.int_dtype,
                float_dtype=self.float_dtype,
            )
            return export_config

        if behavior == LlavaConfigBehavior.LANGUAGE:
            internal_export_config = self._get_language_export_config()
            export_config = LMInputEmbedsConfigHelper(internal_export_config)
            export_config._normalized_config = internal_export_config._normalized_config
            return export_config

        if behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
            return self.__class__(
                self._orig_config,
                task=self.task,
                int_dtype=self.int_dtype,
                float_dtype=self.float_dtype,
                behavior=behavior,
                preprocessors=self._preprocessors,
            )

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ):
        """
        Returns the patcher to use while exporting `model`.

        The vision-embeddings behavior uses `LlavaQwen2ImageEmbeddingsModelPatcher`; any other
        behavior falls back to the parent implementation.
        """
        model_kwargs = model_kwargs or {}
        if self._behavior != LlavaConfigBehavior.VISION_EMBEDDINGS:
            return super().patch_model_for_export(model, model_kwargs)
        return LlavaQwen2ImageEmbeddingsModelPatcher(self, model, model_kwargs)

    def rename_ambiguous_inputs(self, inputs):
        """
        Maps export input names to the argument names used by the patched model.

        For the vision-embeddings behavior only ``pixel_values`` is kept and it is passed as
        ``images`` (presumably the argument name of the model's `encode_images` - confirm against
        the model code). Other behaviors defer to the parent implementation.
        """
        if self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
            model_inputs = {}
            model_inputs["images"] = inputs["pixel_values"]
            return model_inputs
        return super().rename_ambiguous_inputs(inputs)


class PooledProjectionsDummyInputGenerator(DummyInputGenerator):
SUPPORTED_INPUT_NAMES = ["pooled_projections"]

18 changes: 18 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
@@ -2743,3 +2743,21 @@ def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if hasattr(self._model.pos_embed, "_orig_forward"):
self._model.pos_embed.forward = self._model.pos_embed._orig_forward


class LlavaQwen2ImageEmbeddingsModelPatcher(ModelPatcher):
    """Patcher that makes the model's `encode_images` act as its `forward` during export."""

    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Dict[str, Any],
    ):
        """
        Swaps `model.forward` for `model.encode_images` before the base patcher is initialized,
        then loads the vision tower if it is not loaded yet.
        """
        # Stash the real forward on the model so `__exit__` can put it back.
        # NOTE: the double-underscore attribute is name-mangled with this class' name; `__exit__`
        # below relies on the same mangling, so both accesses must stay inside this class body.
        model.__orig_forward = model.forward
        model.forward = model.encode_images
        super().__init__(config, model, model_kwargs)

        vision_tower = self._model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()

    def __exit__(self, exc_type, exc_value, traceback):
        """Runs the base patcher's exit logic, then restores the model's original `forward`."""
        super().__exit__(exc_type, exc_value, traceback)
        self._model.forward = self._model.__orig_forward
2 changes: 1 addition & 1 deletion optimum/exporters/openvino/utils.py
Original file line number Diff line number Diff line change
@@ -208,4 +208,4 @@ def get_submodels(model):
return custom_export, fn_get_submodels


MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "internvl-chat"]
# Model type names treated as multi-modal text-generation models.
# NOTE(review): presumably consulted by the export path to pick the multi-part export flow -
# confirm against the callers of this constant.
MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "llava-qwen2", "internvl-chat"]
Loading

0 comments on commit b69a83d

Please sign in to comment.