From 87c431c9eb777a220a417214df1b9e6a1b957108 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Mon, 23 Dec 2024 19:42:41 +0400 Subject: [PATCH] restore input format for stable diffusion and export configs mapping (#1091) * restore input format for stable diffusion * update configs registration * fix shapes for timestep * align names for t5 --- optimum/exporters/openvino/model_configs.py | 24 ++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 02a8c300a8..aca2359864 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -1250,6 +1250,7 @@ def patch_model_for_export( @register_in_tasks_manager("clip-text-model", *["feature-extraction"], library_name="transformers") @register_in_tasks_manager("clip-text-model", *["feature-extraction"], library_name="diffusers") +@register_in_tasks_manager("clip-text", *["feature-extraction"], library_name="diffusers") class CLIPTextOpenVINOConfig(CLIPTextOnnxConfig): def patch_model_for_export( self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None @@ -1795,12 +1796,31 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int ) +class DummyUnetTimestepInputGenerator(DummyTimestepInputGenerator): + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name != "timestep": + return super().generate(input_name, framework, int_dtype, float_dtype) + shape = [self.batch_size] + return self.random_int_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=int_dtype) + + @register_in_tasks_manager("unet", *["semantic-segmentation"], library_name="diffusers") +@register_in_tasks_manager("unet-2d-condition", *["semantic-segmentation"], library_name="diffusers") class UnetOpenVINOConfig(UNetOnnxConfig): - DUMMY_INPUT_GENERATOR_CLASSES = (DummyUnetVisionInputGenerator,) + UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:] + DUMMY_INPUT_GENERATOR_CLASSES = ( + DummyUnetVisionInputGenerator, + DummyUnetTimestepInputGenerator, + ) + UNetOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES[2:] + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + common_inputs = super().inputs + common_inputs["timestep"] = {0: "batch_size"} + return common_inputs @register_in_tasks_manager("sd3-transformer", *["semantic-segmentation"], library_name="diffusers") +@register_in_tasks_manager("sd3-transformer-2d", *["semantic-segmentation"], library_name="diffusers") class SD3TransformerOpenVINOConfig(UNetOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = ( (DummyTransformerTimestpsInputGenerator,) @@ -1830,6 +1850,7 @@ def rename_ambiguous_inputs(self, inputs): @register_in_tasks_manager("t5-encoder-model", *["feature-extraction"], library_name="diffusers") +@register_in_tasks_manager("t5-encoder", *["feature-extraction"], library_name="diffusers") class T5EncoderOpenVINOConfig(CLIPTextOpenVINOConfig): pass @@ -1905,6 +1926,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int @register_in_tasks_manager("flux-transformer", *["semantic-segmentation"], library_name="diffusers") +@register_in_tasks_manager("flux-transformer-2d", *["semantic-segmentation"], library_name="diffusers") class FluxTransformerOpenVINOConfig(SD3TransformerOpenVINOConfig): DUMMY_INPUT_GENERATOR_CLASSES = ( DummyTransformerTimestpsInputGenerator,