Skip to content

Commit

Permalink
fixed nn archive export for rvc2 2021.4
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 committed Oct 10, 2024
1 parent 031b46a commit 73231e6
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 6 deletions.
4 changes: 3 additions & 1 deletion modelconverter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,9 @@ def convert(
archive_cfg,
preprocessing,
main_stage,
out_models[0],
exporter.inference_model_path
if isinstance(exporter, Exporter)
else exporter.exporters[main_stage].inference_model_path,
)
generator = ArchiveGenerator(
archive_name=f"{cfg.name}.{target.value.lower()}",
Expand Down
57 changes: 55 additions & 2 deletions modelconverter/utils/metadata.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import io
from importlib.metadata import version
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List
Expand Down Expand Up @@ -76,6 +77,58 @@ def _get_metadata_dlc(model_path: Path) -> Metadata:


def _get_metadata_ir(bin_path: Path, xml_path: Path) -> Metadata:
if version("openvino") == "2021.4.0":
return _get_metadata_ir_ie(bin_path, xml_path)
return _get_metadata_ir_runtime(bin_path, xml_path)


def _get_metadata_ir_ie(bin_path: Path, xml_path: Path) -> Metadata:
"""
Extracts metadata from an OpenVINO IR model using the Inference Engine API.
Args:
bin_path (Path): Path to the model's .bin file.
xml_path (Path): Path to the model's .xml file.
Returns:
Metadata: An object containing input/output shapes and data types.
"""
from openvino.inference_engine import IECore

ie = IECore()
try:
network = ie.read_network(model=str(xml_path), weights=str(bin_path))
except Exception as e:
raise ValueError(
f"Failed to load IR model: `{bin_path}` and `{xml_path}`"
) from e

input_shapes = {}
input_dtypes = {}
output_shapes = {}
output_dtypes = {}

for input_name, input_info in network.input_info.items():
input_shapes[input_name] = list(input_info.input_data.shape)

ie_precision = input_info.input_data.precision
input_dtypes[input_name] = DataType.from_ir_ie_dtype(ie_precision)

for output_name, output_data in network.outputs.items():
output_shapes[output_name] = list(output_data.shape)

ie_precision = output_data.precision
output_dtypes[output_name] = DataType.from_ir_ie_dtype(ie_precision)

return Metadata(
input_shapes=input_shapes,
input_dtypes=input_dtypes,
output_shapes=output_shapes,
output_dtypes=output_dtypes,
)


def _get_metadata_ir_runtime(bin_path: Path, xml_path: Path) -> Metadata:
from openvino.runtime import Core

ie = Core()
Expand All @@ -94,13 +147,13 @@ def _get_metadata_ir(bin_path: Path, xml_path: Path) -> Metadata:
for inp in model.inputs:
name = list(inp.names)[0]
input_shapes[name] = list(inp.shape)
input_dtypes[name] = DataType.from_ir_dtype(
input_dtypes[name] = DataType.from_ir_runtime_dtype(
inp.element_type.get_type_name()
)
for output in model.outputs:
name = list(output.names)[0]
output_shapes[name] = list(output.shape)
output_dtypes[name] = DataType.from_ir_dtype(
output_dtypes[name] = DataType.from_ir_runtime_dtype(
output.element_type.get_type_name()
)

Expand Down
26 changes: 23 additions & 3 deletions modelconverter/utils/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from pathlib import Path

from typing import Union

import numpy as np
from onnx.onnx_pb import TensorProto

Expand Down Expand Up @@ -131,7 +131,27 @@ def from_numpy_dtype(cls, dtype: np.dtype) -> "DataType":
return cls(dtype_map[dtype])

@classmethod
def from_ir_dtype(cls, dtype: str) -> "DataType":
def from_ir_ie_dtype(cls, dtype: str) -> "DataType":
dtype_map = {
"FP16": "float16",
"FP32": "float32",
"FP64": "float64",
"I8": "int8",
"I16": "int16",
"I32": "int32",
"I64": "int64",
"U8": "uint8",
"U16": "uint16",
"U32": "uint32",
"U64": "uint64",
"BOOL": "boolean",
}
if dtype not in dtype_map:
raise ValueError(f"Unsupported IR data type: `{dtype}`")
return cls(dtype_map[dtype])

@classmethod
def from_ir_runtime_dtype(cls, dtype: str) -> "DataType":
dtype_map = {
"f16": "float16",
"f32": "float32",
Expand All @@ -147,7 +167,7 @@ def from_ir_dtype(cls, dtype: str) -> "DataType":
"boolean": "boolean",
}
if dtype not in dtype_map:
raise ValueError(f"Unsupported IR data type: `{dtype}`")
raise ValueError(f"Unsupported IR runtime data type: `{dtype}`")
return cls(dtype_map[dtype])

def as_numpy_dtype(self) -> np.dtype:
Expand Down

0 comments on commit 73231e6

Please sign in to comment.