diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index 4614d9368..0d6c2b436 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -1013,11 +1013,10 @@ def _get_submodels_and_export_configs(
 def get_diffusion_models_for_export_ext(
     pipeline: "DiffusionPipeline", int_dtype: str = "int64", float_dtype: str = "fp32", exporter: str = "openvino"
 ):
-<<<<<<< HEAD
     is_sdxl = pipeline.__class__.__name__.startswith("StableDiffusionXL")
     is_sd3 = pipeline.__class__.__name__.startswith("StableDiffusion3")
     is_flux = pipeline.__class__.__name__.startswith("Flux")
-    is_sana = pipeline.__class__.__name__.startswith("Sana")
+    is_sana = pipeline.__class__.__name__.startswith("Sana")
     is_sd = pipeline.__class__.__name__.startswith("StableDiffusion") and not is_sd3
     is_lcm = pipeline.__class__.__name__.startswith("LatentConsistencyModel")
 
@@ -1036,51 +1035,6 @@ def get_diffusion_models_for_export_ext(
         models_for_export = get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype)
     elif is_flux:
         models_for_export = get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype)
-=======
-    if is_diffusers_version(">=", "0.29.0"):
-        from diffusers import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
-
-        sd3_pipes = [StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline]
-        if is_diffusers_version(">=", "0.30.0"):
-            from diffusers import StableDiffusion3InpaintPipeline
-
-            sd3_pipes.append(StableDiffusion3InpaintPipeline)
-
-        is_sd3 = isinstance(pipeline, tuple(sd3_pipes))
-    else:
-        is_sd3 = False
-
-    if is_diffusers_version(">=", "0.30.0"):
-        from diffusers import FluxPipeline
-
-        flux_pipes = [FluxPipeline]
-
-        if is_diffusers_version(">=", "0.31.0"):
-            from diffusers import FluxImg2ImgPipeline, FluxInpaintPipeline
-
-            flux_pipes.extend([FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline])
-
-        if is_diffusers_version(">=", "0.32.0"):
-            from diffusers import FluxFillPipeline
-
-            flux_pipes.append(FluxFillPipeline)
-
-        is_flux = isinstance(pipeline, tuple(flux_pipes))
-    else:
-        is_flux = False
-
-    if is_diffusers_version(">=", "0.32.0"):
-        from diffusers import SanaPipeline
-
-        is_sana = isinstance(pipeline, SanaPipeline)
-    else:
-        is_sana = False
-
-    if not any([is_sana, is_flux, is_sd3]):
-        return None, get_diffusion_models_for_export(pipeline, int_dtype, float_dtype, exporter)
-    if is_sd3:
-        models_for_export = get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype)
->>>>>>> add pipeline
     elif is_sana:
         models_for_export = get_sana_models_for_export(pipeline, exporter, int_dtype, float_dtype)
    else:
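
The resolved hunk above keeps HEAD's dispatch on the pipeline class name and drops the version-gated isinstance checks from the discarded branch, so Img2Img/Inpaint/Fill variants of each family match by prefix without per-version imports. A minimal sketch of that dispatch, assuming hypothetical stand-in classes (only the class names matter here; these are not the real diffusers pipelines):

# Minimal sketch of the name-prefix dispatch kept by the hunk above.
# The classes are hypothetical stand-ins; only their names matter.
class StableDiffusion3InpaintPipeline: pass
class StableDiffusionXLImg2ImgPipeline: pass
class SanaPipeline: pass
class FluxFillPipeline: pass

def classify(pipeline) -> str:
    name = pipeline.__class__.__name__
    # more specific prefixes must be tested before the generic "StableDiffusion"
    if name.startswith("StableDiffusionXL"):
        return "sdxl"
    if name.startswith("StableDiffusion3"):
        return "sd3"
    if name.startswith("Flux"):
        return "flux"
    if name.startswith("Sana"):
        return "sana"
    if name.startswith("StableDiffusion"):
        return "sd"
    return "unsupported"

assert classify(StableDiffusion3InpaintPipeline()) == "sd3"
assert classify(StableDiffusionXLImg2ImgPipeline()) == "sdxl"
assert classify(SanaPipeline()) == "sana"
assert classify(FluxFillPipeline()) == "flux"

This is why new pipeline variants need no change in this function, whereas the discarded branch had to grow a fresh import and version guard for each one; the real code achieves the same ordering effect by masking is_sd with "and not is_sd3".
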
diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index 1f7695cf8..e73039159 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -106,6 +106,7 @@
     Qwen2VLVisionEmbMergerPatcher,
     QwenModelPatcher,
     RotaryEmbPatcher,
+    SanaTextEncoderModelPatcher,
     StatefulSeq2SeqDecoderPatcher,
     UpdateCausalMaskModelPatcher,
     XverseModelPatcher,
@@ -1903,6 +1904,11 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         "attention_mask": {0: "batch_size", 1: "sequence_length"},
     }
 
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> ModelPatcher:
+        return SanaTextEncoderModelPatcher(self, model, model_kwargs)
+
 
 class DummySanaSeq2SeqDecoderTextWithEncMaskInputGenerator(DummySeq2SeqDecoderTextInputGenerator):
     SUPPORTED_INPUT_NAMES = (
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index e7a777938..08bc14988 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -21,9 +21,11 @@
 import torch
 import torch.nn.functional as F
 
+from transformers import PreTrainedModel, TFPreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
 from transformers.utils import is_tf_available
 
+from optimum.exporters.onnx.base import OnnxConfig
 from optimum.exporters.onnx.model_patcher import (
     DecoderModelPatcher,
     ModelPatcher,
@@ -114,9 +116,11 @@ def patch_model_with_bettertransformer(model):
     return model
 
 
-def patch_update_causal_mask(model, transformers_version, inner_model_name="model", patch_fn=None):
+def patch_update_causal_mask(
+    model, transformers_version, inner_model_name="model", patch_fn=None, patch_external_model=False
+):
     if is_transformers_version(">=", transformers_version):
-        inner_model = getattr(model, inner_model_name, None)
+        inner_model = getattr(model, inner_model_name, None) if not patch_external_model else model
         if inner_model is not None:
             if hasattr(inner_model, "_update_causal_mask"):
                 inner_model._orig_update_causal_mask = inner_model._update_causal_mask
@@ -124,8 +128,8 @@ def patch_update_causal_mask(model, transformers_version, inner_model_name="mode
             inner_model._update_causal_mask = types.MethodType(patch_fn, inner_model)
 
 
-def unpatch_update_causal_mask(model, inner_model_name="model"):
-    inner_model = getattr(model, inner_model_name, None)
+def unpatch_update_causal_mask(model, inner_model_name="model", patch_external_model=False):
+    inner_model = getattr(model, inner_model_name, None) if not patch_external_model else model
     if inner_model is not None and hasattr(inner_model, "._orig_update_causal_mask"):
         inner_model._update_causal_mask = inner_model._orig_update_causal_mask
 
@@ -3791,3 +3795,29 @@ def patched_forward(*args, **kwargs):
         model.forward = patched_forward
 
         super().__init__(config, model, model_kwargs)
+
+
+class SanaTextEncoderModelPatcher(ModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        patch_update_causal_mask(self._model, "4.39.0", None, patch_external_model=True)
+
+        if self._model.config._attn_implementation != "sdpa":
+            self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+            self._model.config._attn_implementation = "sdpa"
+            if is_transformers_version("<", "4.47.0"):
+                from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES
+
+                sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"]
+                for layer in self._model.layers:
+                    layer.self_attn._orig_forward = layer.self_attn.forward
+                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        unpatch_update_causal_mask(self._model, None, True)
+        if hasattr(self._model.config, "_orig_attn_implementation"):
+            self._model.config._attn_implementation = self._model.config._orig_attn_implementation
+        for layer in self._model.layers:
+            if hasattr(layer.self_attn, "_orig_forward"):
+                layer.self_attn.forward = layer.self_attn._orig_forward
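
The new SanaTextEncoderModelPatcher above follows the stash-and-restore discipline used throughout this file: __enter__ keeps the original under an _orig_* attribute and swaps in the export-friendly variant, and __exit__ puts everything back so the model behaves normally outside the context. A self-contained sketch of that pattern on a toy torch module (not the real Gemma2 attention classes), using the same types.MethodType rebinding:

import types

import torch

class ToyAttention(torch.nn.Module):
    def forward(self, hidden_states):
        # stand-in "eager" implementation
        return hidden_states

def sdpa_forward(self, hidden_states):
    # stand-in export-friendly forward, analogous to the SDPA forward the
    # real patcher binds onto each layer.self_attn
    return torch.nn.functional.scaled_dot_product_attention(
        hidden_states, hidden_states, hidden_states
    )

class ToyPatcher:
    def __init__(self, module):
        self._module = module

    def __enter__(self):
        # stash the original bound method so __exit__ can restore it
        self._module._orig_forward = self._module.forward
        self._module.forward = types.MethodType(sdpa_forward, self._module)
        return self._module

    def __exit__(self, exc_type, exc_value, traceback):
        if hasattr(self._module, "_orig_forward"):
            self._module.forward = self._module._orig_forward

attn = ToyAttention()
x = torch.randn(1, 2, 4, 8)  # (batch, heads, seq, head_dim)
with ToyPatcher(attn):
    patched = attn(x)  # runs the swapped-in SDPA forward
assert attn(x) is x  # original forward restored once the context exits
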
diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py
index 46b151e7d..1743dc59b 100644
--- a/optimum/exporters/openvino/utils.py
+++ b/optimum/exporters/openvino/utils.py
@@ -257,9 +257,15 @@ def deduce_diffusers_dtype(model_name_or_path, **loading_kwargs):
             model_part_name = "unet"
         if model_part_name:
             directory = path / model_part_name
-            safetensors_files = [
-                filename for filename in directory.glob("*.safetensors") if len(filename.suffixes) == 1
-            ]
+
+            pattern = "*.safetensors"
+            if "variant" in loading_kwargs:
+                variant = loading_kwargs["variant"]
+                pattern = f"*.{variant}.safetensors"
+                safetensors_files = list(directory.glob(pattern))
+            else:
+                # filter out variant files
+                safetensors_files = [filename for filename in directory.glob(pattern) if len(filename.suffixes) == 1]
             safetensors_file = None
             if len(safetensors_files) > 0:
                 safetensors_file = safetensors_files.pop(0)
diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py
index 99422f1a5..3fd26a6e0 100644
--- a/optimum/intel/openvino/modeling_base.py
+++ b/optimum/intel/openvino/modeling_base.py
@@ -594,6 +594,8 @@ def _from_transformers(
         else:
             ov_config = OVConfig(dtype="fp32")
 
+        variant = kwargs.pop("variant", None)
+
         main_export(
             model_name_or_path=model_id,
             output=save_dir_path,
@@ -607,6 +609,7 @@ def _from_transformers(
             trust_remote_code=trust_remote_code,
             ov_config=ov_config,
             library_name=cls._library_name,
+            model_variant=variant,
         )
 
         return cls._from_pretrained(
diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py
index 11ee8f89a..c60c0ec70 100644
--- a/optimum/intel/openvino/modeling_base_seq2seq.py
+++ b/optimum/intel/openvino/modeling_base_seq2seq.py
@@ -408,6 +408,7 @@ def _from_transformers(
         else:
             ov_config = OVConfig(dtype="fp32")
         stateful = kwargs.get("stateful", True)
+        variant = kwargs.pop("variant", None)
 
         main_export(
             model_name_or_path=model_id,
@@ -422,6 +423,7 @@ def _from_transformers(
             trust_remote_code=trust_remote_code,
             ov_config=ov_config,
             stateful=stateful,
+            model_variant=variant,
         )
 
         return cls._from_pretrained(
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 4897db145..b411bf07d 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -310,6 +310,8 @@ def _from_transformers(
         if torch_dtype is not None:
             model_loading_kwargs["torch_dtype"] = torch_dtype
 
+        variant = kwargs.pop("variant", None)
+
         main_export(
             model_name_or_path=model_id,
             output=save_dir_path,
@@ -325,6 +327,7 @@ def _from_transformers(
             stateful=stateful,
             model_loading_kwargs=model_loading_kwargs,
             library_name=cls._library_name,
+            model_variant=variant,
         )
 
         if config.model_type == "phi3" and config.max_position_embeddings != getattr(
diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py
index 4a3f7104b..c2e245c5e 100644
--- a/optimum/intel/openvino/modeling_diffusion.py
+++ b/optimum/intel/openvino/modeling_diffusion.py
@@ -575,6 +575,7 @@ def _from_transformers(
 
         model_save_dir = TemporaryDirectory()
         model_save_path = Path(model_save_dir.name)
+        variant = kwargs.pop("variant", None)
 
         main_export(
             model_name_or_path=model_id,
@@ -589,6 +590,7 @@ def _from_transformers(
             force_download=force_download,
             ov_config=ov_config,
             library_name=cls._library_name,
+            model_variant=variant,
         )
 
         return cls._from_pretrained(
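
Each _from_transformers entry point above pops a new "variant" kwarg and forwards it to main_export as model_variant, and deduce_diffusers_dtype uses it to pick the matching safetensors shard: with a variant it globs "*.{variant}.safetensors", without one it keeps only single-suffix files so variant shards are skipped. A small demonstration of the two branches against a throwaway directory (the file names follow the diffusers variant convention and are illustrative only):

import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    directory = Path(tmp)
    for name in (
        "diffusion_pytorch_model.safetensors",
        "diffusion_pytorch_model.fp16.safetensors",
    ):
        (directory / name).touch()

    # variant branch: the pattern embeds the variant name
    fp16_files = [f.name for f in directory.glob("*.fp16.safetensors")]
    assert fp16_files == ["diffusion_pytorch_model.fp16.safetensors"]

    # no-variant branch: a variant shard has two suffixes ('.fp16', '.safetensors')
    # and is filtered out by the len(filename.suffixes) == 1 check
    plain_files = [
        f.name for f in directory.glob("*.safetensors") if len(f.suffixes) == 1
    ]
    assert plain_files == ["diffusion_pytorch_model.safetensors"]
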
diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py
index 1c0e35cca..c7cd7227f 100644
--- a/optimum/intel/openvino/modeling_visual_language.py
+++ b/optimum/intel/openvino/modeling_visual_language.py
@@ -615,6 +615,7 @@ def _from_transformers(
 
         ov_config = OVConfig(dtype="fp32" if load_in_8bit is False else "auto")
         stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)
+        variant = kwargs.pop("variant", None)
 
         main_export(
             model_name_or_path=model_id,
@@ -629,6 +630,7 @@ def _from_transformers(
             trust_remote_code=trust_remote_code,
             ov_config=ov_config,
             stateful=stateful,
+            model_variant=variant,
         )
         config = AutoConfig.from_pretrained(save_dir_path, trust_remote_code=trust_remote_code)
         return cls._from_pretrained(
diff --git a/tests/openvino/test_diffusion.py b/tests/openvino/test_diffusion.py
index b155353fc..bff187934 100644
--- a/tests/openvino/test_diffusion.py
+++ b/tests/openvino/test_diffusion.py
@@ -149,12 +149,9 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
         for output_type in ["latent", "np", "pt"]:
             inputs["output_type"] = output_type
             if model_arch == "sana":
-                if output_type == "latent":
-                    continue
+                # resolution binning resizes the output to a standard resolution and back, which can introduce floating-point deviations
                 inputs["use_resolution_binning"] = False
-                atol = 4e-2
-            else:
-                atol = 6e-3
+            atol = 1e-4
 
             ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
             diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
@@ -166,12 +163,9 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
         for output_type in ["latent", "np", "pt"]:
             inputs["output_type"] = output_type
             if model_arch == "sana":
-                if output_type == "latent":
-                    continue
+                # resolution binning resizes the output to a standard resolution and back, which can introduce floating-point deviations
                 inputs["use_resolution_binning"] = False
-                atol = 4e-2
-            else:
-                atol = 6e-3
+            atol = 6e-3
 
             ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
             diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
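
The test change above disables resolution binning for Sana instead of skipping the latent output and loosening atol to 4e-2: without the resize round-trip, the OpenVINO and diffusers outputs agree to 1e-4 in the first block (6e-3 in the second). A toy numpy illustration of what those tolerance bands accept; the deviations are synthetic, not measured pipeline outputs:

import numpy as np

reference = np.linspace(0.0, 1.0, num=16, dtype=np.float32)

# a deviation well inside the tightened tolerance passes at atol=1e-4
close_output = reference + 5e-5
assert np.allclose(close_output, reference, atol=1e-4)

# a resize round-trip can easily perturb values by ~1e-3, which fails the
# tight band but still fits the looser 6e-3 one
binned_output = reference + 2e-3
assert not np.allclose(binned_output, reference, atol=1e-4)
assert np.allclose(binned_output, reference, atol=6e-3)
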