Fix test_modifier test error with EfficientVIT model and change the API calls.
ptoupas committed Jan 10, 2025
1 parent 82a7044 commit a34b9ed
Showing 2 changed files with 43 additions and 30 deletions.
19 changes: 15 additions & 4 deletions modelconverter/utils/onnx_tools.py
@@ -191,11 +191,19 @@ class ONNXModifier:
         Path to the base ONNX model
     output_path : Path
         Path to save the modified ONNX model
+    skip_optimisation : bool
+        Flag to skip optimization of the ONNX model
     """
 
-    def __init__(self, model_path: Path, output_path: Path) -> None:
+    def __init__(
+        self,
+        model_path: Path,
+        output_path: Path,
+        skip_optimisation: bool = False,
+    ) -> None:
         self.model_path = model_path
         self.output_path = output_path
+        self.skip_optimisation = skip_optimisation
         self.load_onnx()
         self.prev_onnx_model = self.onnx_model
         self.prev_onnx_gs = self.onnx_gs
@@ -207,7 +215,8 @@ def load_onnx(self) -> None:
         logger.info(f"Loading model: {self.model_path.stem}")
 
         self.onnx_model, _ = simplify(
-            self.model_path.as_posix(), perform_optimization=True
+            self.model_path.as_posix(),
+            perform_optimization=True and not self.skip_optimisation,
         )
 
         self.dtype = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[
@@ -232,8 +241,10 @@ def optimize_onnx(self, passes: Optional[List[str]] = None) -> None:
         @type passes: Optional[List[str]]
         """
 
-        optimised_onnx_model = onnxoptimizer.optimize(
-            self.onnx_model, passes=passes
+        optimised_onnx_model = (
+            self.onnx_model
+            if self.skip_optimisation
+            else onnxoptimizer.optimize(self.onnx_model, passes=passes)
         )
 
         optimised_onnx_model, _ = simplify(
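
For reference, the new skip_optimisation flag gates both optimisation entry points of ONNXModifier. The sketch below is illustrative only and not part of the commit; the helper name conditionally_optimise is invented, while the simplify and onnxoptimizer.optimize calls are the same ones used in the hunks above.

from pathlib import Path
from typing import List, Optional

import onnx
import onnxoptimizer
from onnxsim import simplify


def conditionally_optimise(
    model_path: Path,
    skip_optimisation: bool = False,
    passes: Optional[List[str]] = None,
) -> onnx.ModelProto:
    # Mirrors the load_onnx change: graph simplification still runs, but
    # the extra optimisation pass is disabled when the flag is set.
    model, _ = simplify(
        model_path.as_posix(),
        perform_optimization=not skip_optimisation,
    )
    # Mirrors the optimize_onnx change: return the model untouched when
    # optimisation is skipped, otherwise run the onnxoptimizer passes.
    return (
        model
        if skip_optimisation
        else onnxoptimizer.optimize(model, passes=passes)
    )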
54 changes: 28 additions & 26 deletions tests/test_utils/test_modifier.py
@@ -4,11 +4,11 @@
 from pathlib import Path
 from typing import Tuple
 
-import requests
 import wget
 from luxonis_ml.nn_archive.config import Config as NNArchiveConfig
 from luxonis_ml.nn_archive.config_building_blocks import InputType
 
+from modelconverter.cli import Request
 from modelconverter.utils import ONNXModifier
 from modelconverter.utils.config import Config
 from modelconverter.utils.onnx_tools import onnx_attach_normalization_to_inputs
@@ -25,18 +25,18 @@
     "mult_512x288",
 ]
 
+EXCEMPT_OPTIMISATION = [
+    "efficientvit-b1-224",
+]
+
 
 def download_onnx_models():
     if not os.path.exists(DATA_DIR):
         os.makedirs(DATA_DIR)
 
-    url = "https://easyml.cloud.luxonis.com/models/api/v1/models?is_public=true&limit=1000"
-    response = requests.get(url, headers=HEADERS)
-    if response.status_code != 200:
-        raise ValueError(
-            f"Failed to get models. Status code: {response.status_code}"
-        )
-    hub_ai_models = response.json()
+    hub_ai_models = Request.get(
+        "models/", params={"is_public": True, "limit": 1000}
+    )
 
     for model in hub_ai_models:
         if "ONNX" in model["exportable_types"]:
@@ -46,25 +46,22 @@ def download_onnx_models():
                 os.makedirs(model_dir)
             model_id = model["id"]
 
-            url = f"https://easyml.cloud.luxonis.com/models/api/v1/modelVersions?model_id={model_id}"
-            response = requests.get(url, headers=HEADERS)
-            if response.status_code != 200:
-                raise ValueError(
-                    f"Failed to get model versions. Status code: {response.status_code}"
-                )
-            model_versions = response.json()
+            model_variants = Request.get(
+                "modelVersions/",
+                params={
+                    "model_id": model_id,
+                    "is_public": True,
+                    "limit": 1000,
+                },
+            )
 
-            for version in model_versions:
-                if "ONNX" in version["exportable_types"]:
-                    model_version_id = version["id"]
+            for variant in model_variants:
+                if "ONNX" in variant["exportable_types"]:
+                    model_version_id = variant["id"]
                     break
-            url = f"https://easyml.cloud.luxonis.com/models/api/v1/modelVersions/{model_version_id}/download"
-            response = requests.get(url, headers=HEADERS)
-            if response.status_code != 200:
-                raise ValueError(
-                    f"Failed to download model. Status code: {response.status_code}"
-                )
-            download_info = response.json()
+            download_info = Request.get(
+                f"modelVersions/{model_version_id}/download"
+            )
 
             model_download_link = download_info[0]["download_link"]
 
@@ -210,6 +207,9 @@ def pytest_generate_tests(metafunc):
 
 
 def test_onnx_model(onnx_file):
+    skip_optimisation = (
+        True if onnx_file.stem in EXCEMPT_OPTIMISATION else False
+    )
     nn_config = onnx_file.parent / f"{onnx_file.stem}_config.json"
     cfg, main_stage_key = get_config(nn_config)
 
@@ -229,7 +229,9 @@ def test_onnx_model(onnx_file):
         onnx_file.parent / f"{onnx_file.stem}_modified_optimised.onnx"
     )
     onnx_modifier = ONNXModifier(
-        model_path=modified_onnx, output_path=modified_optimised_onnx
+        model_path=modified_onnx,
+        output_path=modified_optimised_onnx,
+        skip_optimisation=skip_optimisation,
     )
 
     if onnx_modifier.has_dynamic_shape:
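
Putting the two changes together: the test now resolves models, variants and download links through the Request helper from modelconverter.cli instead of raw requests calls, and any model whose stem appears in EXCEMPT_OPTIMISATION is handed to ONNXModifier with optimisation disabled. The sketch below is a condensed illustration of that flow, not code from the repository; fetch_and_modify and its arguments are invented, while the Request, wget and ONNXModifier calls mirror the diff above.

from pathlib import Path

import wget

from modelconverter.cli import Request
from modelconverter.utils import ONNXModifier

EXCEMPT_OPTIMISATION = ["efficientvit-b1-224"]


def fetch_and_modify(model: dict, data_dir: Path) -> None:
    # Resolve the first ONNX-exportable variant of the model via the API
    # (assumes at least one such variant exists, as the test does).
    variants = Request.get(
        "modelVersions/",
        params={"model_id": model["id"], "is_public": True, "limit": 1000},
    )
    variant_id = next(
        v["id"] for v in variants if "ONNX" in v["exportable_types"]
    )

    # Ask for a download link and fetch the ONNX file with wget.
    download_info = Request.get(f"modelVersions/{variant_id}/download")
    link = download_info[0]["download_link"]
    onnx_path = Path(wget.download(link, out=str(data_dir)))

    # Models in EXCEMPT_OPTIMISATION (e.g. efficientvit-b1-224) are loaded
    # with optimisation disabled to avoid the failure this commit fixes.
    ONNXModifier(
        model_path=onnx_path,
        output_path=onnx_path.with_name(f"{onnx_path.stem}_optimised.onnx"),
        skip_optimisation=onnx_path.stem in EXCEMPT_OPTIMISATION,
    )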
