try to remove onnx fallback #1116

Draft · wants to merge 1 commit into base: main
optimum/exporters/openvino/convert.py (105 changes: 34 additions & 71 deletions)
```diff
@@ -415,80 +415,43 @@ def export_pytorch(
     dummy_inputs = config.rename_ambiguous_inputs(dummy_inputs)
     dummy_inputs, dict_inputs = remove_none_from_dummy_inputs(dummy_inputs)

-    try:
-        # TorchScript used behind OpenVINO conversion. Optimum supports only return_dict=True models for patching,
-        # while TorchScript do not support dictionary with values of mixed types (e.g. Tensor and None) in model input/output
-        # To handle it, additional wrapper on patcher forward applied.
-        # model.config.torchscript = True can not be used for patching, because it overrides return_dict to False
-        patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
-        patched_forward = patcher.patched_forward
-
-        @functools.wraps(patched_forward)
-        def ts_patched_forward(*args, **kwargs):
-            for i in range(len(dict_inputs)):
-                input_name, keys = dict_inputs[i]
-                tuple_input = kwargs[input_name]
-                input_dict = dict(zip(keys, tuple_input))
-                kwargs[input_name] = input_dict
-            outputs = patched_forward(*args, **kwargs)
-            return tuple([value if not isinstance(value, list) else tuple(value) for value in outputs.values()])
-
-        patcher.patched_forward = ts_patched_forward
-
-        ts_decoder_kwargs = {}
-        if library_name == "diffusers" and is_openvino_version(">=", "2025.0"):
-            ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}
-
-        with patcher:
-            if patch_16bit_model:
-                from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
-
-                __make_16bit_traceable(model)
-            check_dummy_inputs_are_allowed(model, dummy_inputs)
-            input_info = _get_input_info(model, config, dummy_inputs)
-            ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
-            ov_model = convert_model(
-                ts_decoder,
-                example_input=dummy_inputs,
-                input=[(item.shape, item.type) for item in input_info],
-            )
-
-    except Exception as ex:
-        logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
-
-        if stateful:
-            # cannot raise because stateful is enabled by default and it would break backward compatibility for models that couldn't convert to OV directly
-            # TODO: Implement stateful for ONNX path as well, not doing it right now because of lack of validation
-            logger.warning(
-                "[ WARNING ] Making stateful models is not supported when exporting to ONNX as an intermediate step. "
-                "A stateless model will be exported instead. It may result in sub-optimal inference performance."
-                "Provide a model that can be converted to OpenVINO without fallback to ONNX conversion path."
-            )
-
-        if patch_16bit_model:
-            from openvino.frontend.pytorch.patch_model import unpatch_model
-
-            unpatch_model(model, "_openvino_module_extension_patch_orig_forward")
-            for m in model.modules():
-                if any(p.dtype in [torch.float16, torch.bfloat16] for p in m.parameters(False)) or any(
-                    b.dtype in [torch.float16, torch.bfloat16] for b in m.buffers(False)
-                ):
-                    m.float()
-
-        return export_pytorch_via_onnx(
-            model,
-            config,
-            opset,
-            output,
-            device,
-            input_shapes,
-            model_kwargs,
-            ov_config=ov_config,
-            library_name=library_name,
-        )
+    # TorchScript is used behind the OpenVINO conversion. Optimum supports patching only for return_dict=True models,
+    # while TorchScript does not support dictionaries with mixed value types (e.g. Tensor and None) in model inputs/outputs.
+    # To handle this, an additional wrapper is applied to the patcher's forward.
+    # model.config.torchscript = True cannot be used for patching, because it overrides return_dict to False.
+    patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
+    patched_forward = patcher.patched_forward
+
+    @functools.wraps(patched_forward)
+    def ts_patched_forward(*args, **kwargs):
+        for i in range(len(dict_inputs)):
+            input_name, keys = dict_inputs[i]
+            tuple_input = kwargs[input_name]
+            input_dict = dict(zip(keys, tuple_input))
+            kwargs[input_name] = input_dict
+        outputs = patched_forward(*args, **kwargs)
+        return tuple([value if not isinstance(value, list) else tuple(value) for value in outputs.values()])
+
+    patcher.patched_forward = ts_patched_forward
+
+    ts_decoder_kwargs = {}
+    if library_name == "diffusers" and is_openvino_version(">=", "2025.0"):
+        ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}
+
+    with patcher:
+        if patch_16bit_model:
+            from openvino.frontend.pytorch.patch_model import __make_16bit_traceable
+
+            __make_16bit_traceable(model)
+        check_dummy_inputs_are_allowed(model, dummy_inputs)
+        input_info = _get_input_info(model, config, dummy_inputs)
+        ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
+        ov_model = convert_model(
+            ts_decoder,
+            example_input=dummy_inputs,
+            input=[(item.shape, item.type) for item in input_info],
+        )

     ov_model.validate_nodes_and_infer_types()  # TODO: remove as unnecessary validation?

     output_names = list(config.outputs.keys())
     for idx, out_tensor in enumerate(ov_model.outputs):
         if idx < len(output_names):
```
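Note on the `ts_patched_forward` wrapper this change keeps: TorchScript cannot trace dictionary inputs whose values mix types (e.g. Tensor and None), so `remove_none_from_dummy_inputs` flattens such dicts into tuples and the wrapper rebuilds them before calling the patched forward. A minimal standalone sketch of the same round-trip (the toy model and names below are illustrative, not from this codebase):

```python
import torch


class ToyModel(torch.nn.Module):
    # Hypothetical stand-in for a model that expects a dict input
    # (e.g. a cache of named tensors).
    def forward(self, x, cache):
        return x + cache["k"] * cache["v"]


model = ToyModel()
dict_inputs = [("cache", ["k", "v"])]  # (input name, dict keys), recorded before tracing


def ts_patched_forward(x, cache):
    # Rebuild dicts from the tuples TorchScript can handle,
    # so the original forward signature stays unchanged.
    kwargs = {"x": x, "cache": cache}
    for input_name, keys in dict_inputs:
        kwargs[input_name] = dict(zip(keys, kwargs[input_name]))
    return model(**kwargs)


example = (torch.zeros(2), (torch.ones(2), torch.full((2,), 3.0)))
traced = torch.jit.trace(ts_patched_forward, example)
print(traced(torch.ones(2), (torch.ones(2), torch.ones(2))))  # tensor([2., 2.])
```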
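With the fallback removed, `export_pytorch` relies solely on `openvino.convert_model` tracing the PyTorch module directly, and a conversion failure now propagates to the caller instead of being silently rerouted through ONNX. This also drops the path where `stateful=True` was quietly downgraded to a stateless export. A hedged sketch of the direct path on a toy module (real exports in this file go through `TorchScriptPythonDecoder` with patched dummy inputs):

```python
import numpy as np
import openvino as ov
import torch


class ToyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)


# Direct PyTorch -> OpenVINO conversion, no ONNX intermediate.
ov_model = ov.convert_model(ToyModel(), example_input=torch.zeros(1, 4))
compiled = ov.compile_model(ov_model, "CPU")
print(compiled(np.array([[-1.0, 0.0, 1.0, 2.0]], dtype=np.float32)))
```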
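The two 16-bit code paths in this diff are not equivalent: the deleted ONNX branch upcast every half-precision module to fp32 before export, while the surviving direct path calls `__make_16bit_traceable` and, as I understand it, keeps the weights in 16 bit while making tracing succeed. A small self-contained sketch of the per-module predicate the removed branch applied (toy module, illustrative only):

```python
import torch

m = torch.nn.Linear(4, 4).to(torch.bfloat16)  # hypothetical 16-bit submodule

# Same per-module test the removed fallback used before calling m.float():
# does this module itself own any fp16/bf16 parameter or buffer?
needs_upcast = any(
    p.dtype in (torch.float16, torch.bfloat16) for p in m.parameters(recurse=False)
) or any(b.dtype in (torch.float16, torch.bfloat16) for b in m.buffers(recurse=False))

if needs_upcast:
    m.float()  # upcast in place so ONNX tracing sees fp32 weights

print(m.weight.dtype)  # torch.float32
```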