diff --git a/modelconverter/__main__.py b/modelconverter/__main__.py index b7bebe6..fac2772 100644 --- a/modelconverter/__main__.py +++ b/modelconverter/__main__.py @@ -479,7 +479,9 @@ def convert( archive_cfg, preprocessing, main_stage, - out_models[0], + exporter.inference_model_path + if isinstance(exporter, Exporter) + else exporter.exporters[main_stage].inference_model_path, ) generator = ArchiveGenerator( archive_name=f"{cfg.name}.{target.value.lower()}", diff --git a/modelconverter/utils/metadata.py b/modelconverter/utils/metadata.py index 685ac6b..8d020c7 100644 --- a/modelconverter/utils/metadata.py +++ b/modelconverter/utils/metadata.py @@ -1,4 +1,5 @@ import io +from importlib.metadata import version from dataclasses import dataclass from pathlib import Path from typing import Dict, List @@ -76,6 +77,58 @@ def _get_metadata_dlc(model_path: Path) -> Metadata: def _get_metadata_ir(bin_path: Path, xml_path: Path) -> Metadata: + if version("openvino") == "2021.4.0": + return _get_metadata_ir_ie(bin_path, xml_path) + return _get_metadata_ir_runtime(bin_path, xml_path) + + +def _get_metadata_ir_ie(bin_path: Path, xml_path: Path) -> Metadata: + """ + Extracts metadata from an OpenVINO IR model using the Inference Engine API. + + Args: + bin_path (Path): Path to the model's .bin file. + xml_path (Path): Path to the model's .xml file. + + Returns: + Metadata: An object containing input/output shapes and data types. + """ + from openvino.inference_engine import IECore + + ie = IECore() + try: + network = ie.read_network(model=str(xml_path), weights=str(bin_path)) + except Exception as e: + raise ValueError( + f"Failed to load IR model: `{bin_path}` and `{xml_path}`" + ) from e + + input_shapes = {} + input_dtypes = {} + output_shapes = {} + output_dtypes = {} + + for input_name, input_info in network.input_info.items(): + input_shapes[input_name] = list(input_info.input_data.shape) + + ie_precision = input_info.input_data.precision + input_dtypes[input_name] = DataType.from_ir_ie_dtype(ie_precision) + + for output_name, output_data in network.outputs.items(): + output_shapes[output_name] = list(output_data.shape) + + ie_precision = output_data.precision + output_dtypes[output_name] = DataType.from_ir_ie_dtype(ie_precision) + + return Metadata( + input_shapes=input_shapes, + input_dtypes=input_dtypes, + output_shapes=output_shapes, + output_dtypes=output_dtypes, + ) + + +def _get_metadata_ir_runtime(bin_path: Path, xml_path: Path) -> Metadata: from openvino.runtime import Core ie = Core() @@ -94,13 +147,13 @@ def _get_metadata_ir(bin_path: Path, xml_path: Path) -> Metadata: for inp in model.inputs: name = list(inp.names)[0] input_shapes[name] = list(inp.shape) - input_dtypes[name] = DataType.from_ir_dtype( + input_dtypes[name] = DataType.from_ir_runtime_dtype( inp.element_type.get_type_name() ) for output in model.outputs: name = list(output.names)[0] output_shapes[name] = list(output.shape) - output_dtypes[name] = DataType.from_ir_dtype( + output_dtypes[name] = DataType.from_ir_runtime_dtype( output.element_type.get_type_name() ) diff --git a/modelconverter/utils/types.py b/modelconverter/utils/types.py index 26c77f7..fee1d89 100644 --- a/modelconverter/utils/types.py +++ b/modelconverter/utils/types.py @@ -1,7 +1,7 @@ from enum import Enum from pathlib import Path - from typing import Union + import numpy as np from onnx.onnx_pb import TensorProto @@ -131,7 +131,27 @@ def from_numpy_dtype(cls, dtype: np.dtype) -> "DataType": return cls(dtype_map[dtype]) @classmethod - def from_ir_dtype(cls, dtype: str) -> "DataType": + def from_ir_ie_dtype(cls, dtype: str) -> "DataType": + dtype_map = { + "FP16": "float16", + "FP32": "float32", + "FP64": "float64", + "I8": "int8", + "I16": "int16", + "I32": "int32", + "I64": "int64", + "U8": "uint8", + "U16": "uint16", + "U32": "uint32", + "U64": "uint64", + "BOOL": "boolean", + } + if dtype not in dtype_map: + raise ValueError(f"Unsupported IR data type: `{dtype}`") + return cls(dtype_map[dtype]) + + @classmethod + def from_ir_runtime_dtype(cls, dtype: str) -> "DataType": dtype_map = { "f16": "float16", "f32": "float32", @@ -147,7 +167,7 @@ def from_ir_dtype(cls, dtype: str) -> "DataType": "boolean": "boolean", } if dtype not in dtype_map: - raise ValueError(f"Unsupported IR data type: `{dtype}`") + raise ValueError(f"Unsupported IR runtime data type: `{dtype}`") return cls(dtype_map[dtype]) def as_numpy_dtype(self) -> np.dtype: