Skip to content

Commit

Permalink
Add phi3 export openvino (#686)
Browse files Browse the repository at this point in the history
* support phi3 export openvino

* fix inv_freq tracing based on latest changes in model

* add test model

* Update optimum/exporters/openvino/model_patcher.py

Co-authored-by: Ella Charlaix <[email protected]>

---------

Co-authored-by: Ella Charlaix <[email protected]>
  • Loading branch information
eaidova and echarlaix authored Apr 25, 2024
1 parent 33fc7b7 commit 41876fb
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 2 deletions.
21 changes: 20 additions & 1 deletion optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from transformers.utils import is_tf_available

from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
from optimum.exporters.onnx.model_configs import FalconOnnxConfig, GemmaOnnxConfig, LlamaOnnxConfig
from optimum.exporters.onnx.model_configs import FalconOnnxConfig, GemmaOnnxConfig, LlamaOnnxConfig, PhiOnnxConfig
from optimum.exporters.tasks import TasksManager
from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.utils.input_generators import (
Expand All @@ -37,6 +37,7 @@
GemmaModelPatcher,
LlamaModelPatcher,
MixtralModelPatcher,
Phi3ModelPatcher,
QwenModelPatcher,
)

Expand Down Expand Up @@ -440,6 +441,24 @@ class OrionOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


@register_in_tasks_manager(
    "phi3",
    "feature-extraction",
    "feature-extraction-with-past",
    "text-generation",
    "text-generation-with-past",
    "text-classification",
    library_name="transformers",
)
class Phi3OpenVINOConfig(PhiOnnxConfig):
    """OpenVINO export config for phi3.

    Reuses the Phi ONNX configuration and only swaps in a phi3-specific
    model patcher for export.
    """

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        # Phi3ModelPatcher applies the tracing workarounds (e.g. inv_freq
        # initialization) the stock model needs before export.
        patcher = Phi3ModelPatcher(self, model, model_kwargs=model_kwargs)
        return patcher


class OVFalconDummyPastKeyValuesGenerator(FalconDummyPastKeyValuesGenerator):
def __init__(
self,
Expand Down
14 changes: 14 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,3 +623,17 @@ def __init__(
# model has first inference buffers initialization
if hasattr(self._model.lm_head, "first_flag"):
self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64))


class Phi3ModelPatcher(DecoderModelPatcher):
    """Decoder patcher that makes phi3 models torchscript-traceable."""

    def __enter__(self):
        super().__enter__()

        # https://github.com/huggingface/transformers/blob/30ee508c6c92a1c0aa0281d193c7c0fb815b8d2f/src/transformers/models/phi3/modeling_phi3.py#L113
        # The model initializes inv_freq lazily, which breaks torchscript
        # tracing; precompute it for every attention layer that still has None.
        for decoder_layer in self._model.model.layers:
            rotary = decoder_layer.self_attn.rotary_emb
            if rotary.inv_freq is not None:
                continue
            exponents = torch.arange(0, rotary.dim, 2, dtype=torch.int64).float() / rotary.dim
            rotary.inv_freq = 1.0 / rotary.base**exponents
3 changes: 2 additions & 1 deletion tests/openvino/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,13 +535,14 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
"stablelm",
"starcoder2",
"phi",
"phi3",
"internlm2",
"orion",
"falcon",
"falcon-40b",
)
GENERATION_LENGTH = 100
REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais", "qwen", "internlm2", "olmo", "orion")
REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais", "qwen", "internlm2", "olmo", "orion", "phi3")

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
Expand Down
1 change: 1 addition & 0 deletions tests/openvino/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
"pegasus": "hf-internal-testing/tiny-random-pegasus",
"pix2struct": "fxmarty/pix2struct-tiny-random",
"phi": "echarlaix/tiny-random-PhiForCausalLM",
"phi3": "katuni4ka/tiny-random-phi3",
"poolformer": "hf-internal-testing/tiny-random-PoolFormerModel",
"qwen": "katuni4ka/tiny-random-qwen",
"qwen2": "Qwen/Qwen1.5-0.5B",
Expand Down

0 comments on commit 41876fb

Please sign in to comment.