From 73231e68f7b7e21b76b4b842eab39d057528af4e Mon Sep 17 00:00:00 2001
From: Martin Kozlovsky <martin.kozlovsky@luxonis.com>
Date: Thu, 10 Oct 2024 19:14:43 +0200
Subject: [PATCH] Fix NN archive export for RVC2 (OpenVINO 2021.4)

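The RVC2 conversion environment runs OpenVINO 2021.4, which ships only
the legacy Inference Engine Python API and has no `openvino.runtime`
module, so the NN archive export broke there while reading IR metadata.

`_get_metadata_ir` now dispatches on the installed OpenVINO version:
on 2021.4.0 it reads the network through
`openvino.inference_engine.IECore`, on newer versions it keeps using
`openvino.runtime.Core`. The legacy API reports precisions under
different names (`FP16`, `I8`, ...), so a separate
`DataType.from_ir_ie_dtype` mapping is added and the existing one is
renamed to `DataType.from_ir_runtime_dtype`. The archive config also
takes the inference model path from the exporter (or from the main
stage's exporter in multi-stage setups) instead of `out_models[0]`.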
---
 modelconverter/__main__.py       |  4 ++-
 modelconverter/utils/metadata.py | 57 ++++++++++++++++++++++++++++++--
 modelconverter/utils/types.py    | 26 +++++++++++++--
 3 files changed, 81 insertions(+), 6 deletions(-)
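
A quick way to sanity-check the dispatch (not applied by this patch;
the IR paths below are hypothetical):

    from importlib.metadata import version
    from pathlib import Path

    from modelconverter.utils.metadata import _get_metadata_ir
    from modelconverter.utils.types import DataType

    # The legacy (IE) and runtime precision maps resolve to the same members.
    assert DataType.from_ir_ie_dtype("FP16") is DataType.from_ir_runtime_dtype("f16")

    # Hypothetical IR pair produced by an earlier conversion stage.
    bin_path = Path("model.bin")
    xml_path = Path("model.xml")

    # On openvino==2021.4.0 this goes through IECore, otherwise through
    # openvino.runtime.Core; the Metadata fields are the same either way.
    print("openvino ==", version("openvino"))
    meta = _get_metadata_ir(bin_path, xml_path)
    print(meta.input_shapes, meta.input_dtypes)
    print(meta.output_shapes, meta.output_dtypes)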

diff --git a/modelconverter/__main__.py b/modelconverter/__main__.py
index b7bebe6..fac2772 100644
--- a/modelconverter/__main__.py
+++ b/modelconverter/__main__.py
@@ -479,7 +479,9 @@ def convert(
                     archive_cfg,
                     preprocessing,
                     main_stage,
-                    out_models[0],
+                    exporter.inference_model_path
+                    if isinstance(exporter, Exporter)
+                    else exporter.exporters[main_stage].inference_model_path,
                 )
                 generator = ArchiveGenerator(
                     archive_name=f"{cfg.name}.{target.value.lower()}",
diff --git a/modelconverter/utils/metadata.py b/modelconverter/utils/metadata.py
index 685ac6b..8d020c7 100644
--- a/modelconverter/utils/metadata.py
+++ b/modelconverter/utils/metadata.py
@@ -1,4 +1,5 @@
 import io
 from dataclasses import dataclass
+from importlib.metadata import version
 from pathlib import Path
 from typing import Dict, List
@@ -76,6 +77,58 @@ def _get_metadata_dlc(model_path: Path) -> Metadata:
 
 
 def _get_metadata_ir(bin_path: Path, xml_path: Path) -> Metadata:
+    if version("openvino") == "2021.4.0":  # 2021.4 has no openvino.runtime API
+        return _get_metadata_ir_ie(bin_path, xml_path)
+    return _get_metadata_ir_runtime(bin_path, xml_path)
+
+
+def _get_metadata_ir_ie(bin_path: Path, xml_path: Path) -> Metadata:
+    """
+    Extracts metadata from an OpenVINO IR model using the Inference Engine API.
+
+    Args:
+        bin_path (Path): Path to the model's .bin file.
+        xml_path (Path): Path to the model's .xml file.
+
+    Returns:
+        Metadata: An object containing input/output shapes and data types.
+    """
+    from openvino.inference_engine import IECore
+
+    ie = IECore()
+    try:
+        network = ie.read_network(model=str(xml_path), weights=str(bin_path))
+    except Exception as e:
+        raise ValueError(
+            f"Failed to load IR model: `{bin_path}` and `{xml_path}`"
+        ) from e
+
+    input_shapes = {}
+    input_dtypes = {}
+    output_shapes = {}
+    output_dtypes = {}
+
+    for input_name, input_info in network.input_info.items():
+        input_shapes[input_name] = list(input_info.input_data.shape)
+
+        ie_precision = input_info.input_data.precision
+        input_dtypes[input_name] = DataType.from_ir_ie_dtype(ie_precision)
+
+    for output_name, output_data in network.outputs.items():
+        output_shapes[output_name] = list(output_data.shape)
+
+        ie_precision = output_data.precision
+        output_dtypes[output_name] = DataType.from_ir_ie_dtype(ie_precision)
+
+    return Metadata(
+        input_shapes=input_shapes,
+        input_dtypes=input_dtypes,
+        output_shapes=output_shapes,
+        output_dtypes=output_dtypes,
+    )
+
+
+def _get_metadata_ir_runtime(bin_path: Path, xml_path: Path) -> Metadata:
     from openvino.runtime import Core
 
     ie = Core()
@@ -94,13 +147,13 @@ def _get_metadata_ir(bin_path: Path, xml_path: Path) -> Metadata:
     for inp in model.inputs:
         name = list(inp.names)[0]
         input_shapes[name] = list(inp.shape)
-        input_dtypes[name] = DataType.from_ir_dtype(
+        input_dtypes[name] = DataType.from_ir_runtime_dtype(
             inp.element_type.get_type_name()
         )
     for output in model.outputs:
         name = list(output.names)[0]
         output_shapes[name] = list(output.shape)
-        output_dtypes[name] = DataType.from_ir_dtype(
+        output_dtypes[name] = DataType.from_ir_runtime_dtype(
             output.element_type.get_type_name()
         )
 
diff --git a/modelconverter/utils/types.py b/modelconverter/utils/types.py
index 26c77f7..fee1d89 100644
--- a/modelconverter/utils/types.py
+++ b/modelconverter/utils/types.py
@@ -1,7 +1,7 @@
 from enum import Enum
 from pathlib import Path
-
 from typing import Union
+
 import numpy as np
 from onnx.onnx_pb import TensorProto
 
@@ -131,7 +131,27 @@ def from_numpy_dtype(cls, dtype: np.dtype) -> "DataType":
         return cls(dtype_map[dtype])
 
     @classmethod
-    def from_ir_dtype(cls, dtype: str) -> "DataType":
+    def from_ir_ie_dtype(cls, dtype: str) -> "DataType":
+        dtype_map = {
+            "FP16": "float16",
+            "FP32": "float32",
+            "FP64": "float64",
+            "I8": "int8",
+            "I16": "int16",
+            "I32": "int32",
+            "I64": "int64",
+            "U8": "uint8",
+            "U16": "uint16",
+            "U32": "uint32",
+            "U64": "uint64",
+            "BOOL": "boolean",
+        }
+        if dtype not in dtype_map:
+            raise ValueError(f"Unsupported IR data type: `{dtype}`")
+        return cls(dtype_map[dtype])
+
+    @classmethod
+    def from_ir_runtime_dtype(cls, dtype: str) -> "DataType":
         dtype_map = {
             "f16": "float16",
             "f32": "float32",
@@ -147,7 +167,7 @@ def from_ir_dtype(cls, dtype: str) -> "DataType":
             "boolean": "boolean",
         }
         if dtype not in dtype_map:
-            raise ValueError(f"Unsupported IR data type: `{dtype}`")
+            raise ValueError(f"Unsupported IR runtime data type: `{dtype}`")
         return cls(dtype_map[dtype])
 
     def as_numpy_dtype(self) -> np.dtype: