Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sana support #1106

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions optimum/commands/export/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ def parse_args_openvino(parser: "ArgumentParser"):
"This is needed by some models, for some tasks. If not provided, will attempt to use the tokenizer to guess it."
),
)
optional_group.add_argument(
"--variant",
type=str,
default=None,
help=("Select a variant of the model to export."),
)
optional_group.add_argument(
"--ratio",
type=float,
Expand Down Expand Up @@ -410,6 +416,10 @@ def run(self):
from optimum.intel import OVFluxPipeline

model_cls = OVFluxPipeline
elif class_name == "SanaPipeline":
from optimum.intel import OVSanaPipeline

model_cls = OVSanaPipeline
else:
raise NotImplementedError(f"Quantization in hybrid mode isn't supported for class {class_name}.")

Expand Down Expand Up @@ -442,6 +452,8 @@ def run(self):
quantization_config=quantization_config,
stateful=not self.args.disable_stateful,
trust_remote_code=self.args.trust_remote_code,
variant=self.args.variant,
cache_dir=self.args.cache_dir,
)
model.save_pretrained(self.args.output)

Expand All @@ -463,5 +475,6 @@ def run(self):
stateful=not self.args.disable_stateful,
convert_tokenizer=not self.args.disable_convert_tokenizer,
library_name=library_name,
model_variant=self.args.variant,
# **input_shapes,
)
5 changes: 5 additions & 0 deletions optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def main_export(
convert_tokenizer: bool = False,
library_name: Optional[str] = None,
model_loading_kwargs: Optional[Dict[str, Any]] = None,
model_variant: Optional[str] = None,
**kwargs_shapes,
):
"""
Expand Down Expand Up @@ -237,6 +238,8 @@ def main_export(
custom_architecture = False
patch_16bit = False
loading_kwargs = model_loading_kwargs or {}
if model_variant is not None:
loading_kwargs["variant"] = model_variant
if library_name == "transformers":
config = AutoConfig.from_pretrained(
model_name_or_path,
Expand Down Expand Up @@ -347,6 +350,7 @@ class StoreAttr(object):

GPTQQuantizer.post_init_model = post_init_model
elif library_name == "diffusers" and is_openvino_version(">=", "2024.6"):
_loading_kwargs = {} if model_variant is None else {"variant": model_variant}
dtype = deduce_diffusers_dtype(
model_name_or_path,
revision=revision,
Expand All @@ -355,6 +359,7 @@ class StoreAttr(object):
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
**_loading_kwargs,
)
if dtype in [torch.float16, torch.bfloat16]:
loading_kwargs["torch_dtype"] = dtype
Expand Down
68 changes: 68 additions & 0 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,7 @@ def get_diffusion_models_for_export_ext(
is_sdxl = pipeline.__class__.__name__.startswith("StableDiffusionXL")
is_sd3 = pipeline.__class__.__name__.startswith("StableDiffusion3")
is_flux = pipeline.__class__.__name__.startswith("Flux")
is_sana = pipeline.__class__.__name__.startswith("Sana")
is_sd = pipeline.__class__.__name__.startswith("StableDiffusion") and not is_sd3
is_lcm = pipeline.__class__.__name__.startswith("LatentConsistencyModel")

Expand All @@ -1034,11 +1035,78 @@ def get_diffusion_models_for_export_ext(
models_for_export = get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype)
elif is_flux:
models_for_export = get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype)
elif is_sana:
models_for_export = get_sana_models_for_export(pipeline, exporter, int_dtype, float_dtype)
else:
raise ValueError(f"Unsupported pipeline type `{pipeline.__class__.__name__}` provided")
return None, models_for_export


def get_sana_models_for_export(pipeline, exporter, int_dtype, float_dtype):
    """
    Collect the sub-models of a Sana pipeline together with their export configs.

    Args:
        pipeline: Sana pipeline; its `text_encoder`, `transformer` and `vae` attributes are read. The transformer
            config is modified in place (see below) and the VAE is deep-copied twice.
        exporter: Exporter name forwarded to `TasksManager.get_exporter_config_constructor`.
        int_dtype: Integer dtype forwarded to every export config.
        float_dtype: Floating point dtype forwarded to every export config.

    Returns:
        Dict mapping `"text_encoder"`, `"transformer"`, `"vae_encoder"` and `"vae_decoder"` (inserted in this order)
        to `(model, export_config)` tuples.
    """

    def _make_export_config(model, config, task, model_type, scale_activations=False):
        # Resolve the config class registered for `model_type` and instantiate it for `config`.
        constructor = TasksManager.get_exporter_config_constructor(
            model=model,
            exporter=exporter,
            library_name="diffusers",
            task=task,
            model_type=model_type,
        )
        export_config = constructor(config, int_dtype=int_dtype, float_dtype=float_dtype)
        if scale_activations:
            # A fresh dict per config so that the configs never share the same options object.
            export_config.runtime_options = {"ACTIVATIONS_SCALE_FACTOR": "8.0"}
        return export_config

    exported = {}

    # Text encoder
    encoder_model = pipeline.text_encoder
    exported["text_encoder"] = (
        encoder_model,
        _make_export_config(
            encoder_model,
            pipeline.text_encoder.config,
            task="feature-extraction",
            model_type="gemma2-text-encoder",
            scale_activations=True,
        ),
    )

    # Transformer: these config fields are set before the export config is built from the same config object.
    dit = pipeline.transformer
    dit.config.text_encoder_projection_dim = dit.config.caption_channels
    dit.config.requires_aesthetics_score = False
    dit.config.time_cond_proj_dim = None
    exported["transformer"] = (
        dit,
        _make_export_config(
            dit,
            pipeline.transformer.config,
            task="semantic-segmentation",
            model_type="sana-transformer",
        ),
    )

    # VAE encoder: a copy of the VAE whose forward returns only the "latent" entry of `encode`.
    encoder_vae = copy.deepcopy(pipeline.vae)
    encoder_vae.forward = lambda sample: {"latent": encoder_vae.encode(x=sample)["latent"]}
    exported["vae_encoder"] = (
        encoder_vae,
        _make_export_config(
            encoder_vae,
            encoder_vae.config,
            task="semantic-segmentation",
            model_type="dcae-encoder",
            scale_activations=True,
        ),
    )

    # VAE decoder: a second copy of the VAE whose forward delegates to `decode`.
    decoder_vae = copy.deepcopy(pipeline.vae)
    decoder_vae.forward = lambda latent_sample: decoder_vae.decode(z=latent_sample)
    exported["vae_decoder"] = (
        decoder_vae,
        _make_export_config(
            decoder_vae,
            decoder_vae.config,
            task="semantic-segmentation",
            model_type="vae-decoder",
            scale_activations=True,
        ),
    )

    return exported


def get_sd3_models_for_export(pipeline, exporter, int_dtype, float_dtype):
models_for_export = {}

Expand Down
81 changes: 81 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
PhiOnnxConfig,
T5OnnxConfig,
UNetOnnxConfig,
VaeEncoderOnnxConfig,
VisionOnnxConfig,
WhisperOnnxConfig,
)
Expand Down Expand Up @@ -105,6 +106,7 @@
Qwen2VLVisionEmbMergerPatcher,
QwenModelPatcher,
RotaryEmbPatcher,
SanaTextEncoderModelPatcher,
StatefulSeq2SeqDecoderPatcher,
UpdateCausalMaskModelPatcher,
XverseModelPatcher,
Expand Down Expand Up @@ -133,6 +135,8 @@ def init_model_configs():
if is_diffusers_available() and "fill" not in TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS:
TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS["fill"] = "FluxFillPipeline"
TasksManager._DIFFUSERS_TASKS_TO_MODEL_MAPPINGS["fill"] = {"flux": "FluxFillPipeline"}
TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS["text-to-image"] = ("AutoPipelineForText2Image", "SanaPipeline")
TasksManager._DIFFUSERS_TASKS_TO_MODEL_MAPPINGS["text-to-image"]["sana"] = "SanaPipeline"

supported_model_types = [
"_SUPPORTED_MODEL_TYPE",
Expand Down Expand Up @@ -1891,6 +1895,83 @@ class T5EncoderOpenVINOConfig(CLIPTextOpenVINOConfig):
pass


@register_in_tasks_manager("gemma2-text-encoder", "feature-extraction", library_name="diffusers")
class Gemma2TextEncoderOpenVINOConfig(CLIPTextOpenVINOConfig):
    """Export config registered as `gemma2-text-encoder` for the diffusers `feature-extraction` task."""

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic axes of `input_ids` and `attention_mask` (batch on axis 0, sequence on axis 1)."""
        return {
            input_name: {0: "batch_size", 1: "sequence_length"} for input_name in ("input_ids", "attention_mask")
        }

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> ModelPatcher:
        """Return a `SanaTextEncoderModelPatcher` built from this config, `model` and `model_kwargs`."""
        patcher = SanaTextEncoderModelPatcher(self, model, model_kwargs)
        return patcher


class DummySanaSeq2SeqDecoderTextWithEncMaskInputGenerator(DummySeq2SeqDecoderTextInputGenerator):
    """Seq2seq decoder text input generator that lists `encoder_attention_mask` among its supported inputs."""

    # Input names this generator declares as supported.
    SUPPORTED_INPUT_NAMES = (
        "decoder_input_ids",
        "decoder_attention_mask",
        "encoder_outputs",
        "encoder_hidden_states",
        "encoder_attention_mask",
    )


class DummySanaTransformerVisionInputGenerator(DummyUnetVisionInputGenerator):
    """UNet vision input generator whose default width and height are the default dummy shapes divided by 8."""

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
        # Default spatial size is the standard dummy size integer-divided by 8
        # (presumably to keep memory usage low during conversion - TODO confirm).
        width: int = DEFAULT_DUMMY_SHAPES["width"] // 8,
        height: int = DEFAULT_DUMMY_SHAPES["height"] // 8,
        **kwargs,
    ):
        super().__init__(
            task,
            normalized_config,
            batch_size,
            num_channels,
            width=width,
            height=height,
            **kwargs,
        )


@register_in_tasks_manager("sana-transformer", "semantic-segmentation", library_name="diffusers")
class SanaTransformerOpenVINOConfig(UNetOpenVINOConfig):
    """Export config registered as `sana-transformer` for the diffusers `semantic-segmentation` task."""

    # Presumably maps normalized attribute names to the fields of the transformer config - verify against
    # `NormalizedConfig.with_args`.
    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        image_size="sample_size",
        num_channels="in_channels",
        hidden_size="caption_channels",
        vocab_size="attention_head_dim",
        allow_new=True,
    )
    # Sana generators first, followed by the UNet generators without their first and last entries.
    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummySanaTransformerVisionInputGenerator,
        DummySanaSeq2SeqDecoderTextWithEncMaskInputGenerator,
        *UNetOpenVINOConfig.DUMMY_INPUT_GENERATOR_CLASSES[1:-1],
    )

    @property
    def inputs(self):
        """Return the parent inputs extended with the dynamic axes of `encoder_attention_mask`."""
        input_axes = super().inputs
        input_axes["encoder_attention_mask"] = {0: "batch_size", 1: "sequence_length"}
        return input_axes

    def rename_ambiguous_inputs(self, inputs):
        """Move the value stored under `sample` to `hidden_states`; a missing or `None` value adds no new key."""
        sample = inputs.pop("sample", None)
        if sample is None:
            return inputs
        # The exported input is named `hidden_states` instead of `sample`.
        inputs["hidden_states"] = sample
        return inputs


@register_in_tasks_manager("dcae-encoder", "semantic-segmentation", library_name="diffusers")
class DcaeEncoderOpenVINOConfig(VaeEncoderOnnxConfig):
    """Export config registered as `dcae-encoder` for the diffusers `semantic-segmentation` task."""

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the dynamic axes of the single `latent` output (batch, latent height, latent width)."""
        latent_axes = {0: "batch_size", 2: "height_latent", 3: "width_latent"}
        return {"latent": latent_axes}


class DummyFluxTransformerInputGenerator(DummyVisionInputGenerator):
SUPPORTED_INPUT_NAMES = (
"pixel_values",
Expand Down
38 changes: 34 additions & 4 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@

import torch
import torch.nn.functional as F
from transformers import PreTrainedModel, TFPreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
from transformers.utils import is_tf_available

from optimum.exporters.onnx.base import OnnxConfig
from optimum.exporters.onnx.model_patcher import (
DecoderModelPatcher,
ModelPatcher,
Expand Down Expand Up @@ -114,18 +116,20 @@ def patch_model_with_bettertransformer(model):
return model


def patch_update_causal_mask(
    model, transformers_version, inner_model_name="model", patch_fn=None, patch_extrnal_model=False
):
    """
    Replace `_update_causal_mask` on a model with `patch_fn`, keeping the original for later restoration.

    Nothing happens when the installed transformers version is lower than `transformers_version`, when the target
    object cannot be found, or when it has no `_update_causal_mask` attribute.

    Args:
        model: Object that either owns the inner model or is itself the object to patch.
        transformers_version: Minimal transformers version for which the patch is applied.
        inner_model_name: Attribute of `model` holding the object to patch; ignored if `patch_extrnal_model` is set.
        patch_fn: Replacement function, bound as a method; defaults to `_llama_gemma_update_causal_mask`.
        patch_extrnal_model: If true, `model` itself is patched instead of one of its attributes.
    """
    if not is_transformers_version(">=", transformers_version):
        return

    target = model if patch_extrnal_model else getattr(model, inner_model_name, None)
    if target is None or not hasattr(target, "_update_causal_mask"):
        return

    # Keep a handle on the original method so that it can be restored after export.
    target._orig_update_causal_mask = target._update_causal_mask
    patch_fn = patch_fn or _llama_gemma_update_causal_mask
    target._update_causal_mask = types.MethodType(patch_fn, target)


def unpatch_update_causal_mask(model, inner_model_name="model", patch_extrnal_model=False):
    """
    Restore the `_update_causal_mask` method that was saved as `_orig_update_causal_mask`.

    Nothing happens when the target object cannot be found or holds no saved original method.

    Args:
        model: Object that either owns the inner model or is itself the patched object.
        inner_model_name: Attribute of `model` holding the patched object; ignored if `patch_extrnal_model` is set.
        patch_extrnal_model: If true, `model` itself is restored instead of one of its attributes.
    """
    inner_model = model if patch_extrnal_model else getattr(model, inner_model_name, None)
    # The attribute name must not start with a dot: `hasattr(obj, "._orig_update_causal_mask")` is always False,
    # which would silently leave the model patched.
    if inner_model is not None and hasattr(inner_model, "_orig_update_causal_mask"):
        inner_model._update_causal_mask = inner_model._orig_update_causal_mask

Expand Down Expand Up @@ -3791,3 +3795,29 @@ def patched_forward(*args, **kwargs):
model.forward = patched_forward

super().__init__(config, model, model_kwargs)


class SanaTextEncoderModelPatcher(ModelPatcher):
    """
    Patcher that, while active, patches `_update_causal_mask` on the text encoder and forces SDPA attention.

    On exit the causal mask update, the attention implementation and the per-layer attention forwards are restored.
    """

    def __enter__(self):
        super().__enter__()
        # The model itself (not one of its attributes) owns `_update_causal_mask`.
        patch_update_causal_mask(self._model, "4.39.0", None, patch_extrnal_model=True)

        config = self._model.config
        if config._attn_implementation == "sdpa":
            return

        # Remember the configured implementation so that `__exit__` can restore it.
        config._orig_attn_implementation = config._attn_implementation
        config._attn_implementation = "sdpa"
        if is_transformers_version("<", "4.47.0"):
            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES

            sdpa_forward = GEMMA2_ATTENTION_CLASSES["sdpa"].forward
            for decoder_layer in self._model.layers:
                attention = decoder_layer.self_attn
                attention._orig_forward = attention.forward
                attention.forward = types.MethodType(sdpa_forward, attention)

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
        unpatch_update_causal_mask(self._model, None, True)

        config = self._model.config
        if not hasattr(config, "_orig_attn_implementation"):
            return

        config._attn_implementation = config._orig_attn_implementation
        for decoder_layer in self._model.layers:
            attention = decoder_layer.self_attn
            if hasattr(attention, "_orig_forward"):
                attention.forward = attention._orig_forward
12 changes: 9 additions & 3 deletions optimum/exporters/openvino/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,15 @@ def deduce_diffusers_dtype(model_name_or_path, **loading_kwargs):
model_part_name = "unet"
if model_part_name:
directory = path / model_part_name
safetensors_files = [
filename for filename in directory.glob("*.safetensors") if len(filename.suffixes) == 1
]

pattern = "*.safetensors"
if "variant" in loading_kwargs:
variant = loading_kwargs["variant"]
pattern = f"*.{variant}.safetensors"
safetensors_files = list(directory.glob(pattern))
else:
# filter out variant files
safetensors_files = [filename for filename in directory.glob(pattern) if len(filename.suffixes) == 1]
safetensors_file = None
if len(safetensors_files) > 0:
safetensors_file = safetensors_files.pop(0)
Expand Down
2 changes: 2 additions & 0 deletions optimum/intel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
"OVFluxImg2ImgPipeline",
"OVFluxInpaintPipeline",
"OVFluxFillPipeline",
"OVSanaPipeline",
"OVPipelineForImage2Image",
"OVPipelineForText2Image",
"OVPipelineForInpainting",
Expand All @@ -150,6 +151,7 @@
"OVFluxImg2ImgPipeline",
"OVFluxInpaintPipeline",
"OVFluxFillPipeline",
"OVSanaPipeline",
"OVPipelineForImage2Image",
"OVPipelineForText2Image",
"OVPipelineForInpainting",
Expand Down
1 change: 1 addition & 0 deletions optimum/intel/openvino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
OVPipelineForImage2Image,
OVPipelineForInpainting,
OVPipelineForText2Image,
OVSanaPipeline,
OVStableDiffusion3Img2ImgPipeline,
OVStableDiffusion3InpaintPipeline,
OVStableDiffusion3Pipeline,
Expand Down
3 changes: 3 additions & 0 deletions optimum/intel/openvino/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,8 @@ def _from_transformers(
else:
ov_config = OVConfig(dtype="fp32")

variant = kwargs.pop("variant", None)

main_export(
model_name_or_path=model_id,
output=save_dir_path,
Expand All @@ -607,6 +609,7 @@ def _from_transformers(
trust_remote_code=trust_remote_code,
ov_config=ov_config,
library_name=cls._library_name,
model_variant=variant,
)

return cls._from_pretrained(
Expand Down
Loading