From e81ec85c19fb952f4da64aca69b5e84712a93bc1 Mon Sep 17 00:00:00 2001 From: klemen1999 Date: Wed, 6 Nov 2024 09:33:44 +0100 Subject: [PATCH] added tests, fix in docstring --- .../base_models/input.py | 4 +- tests/test_nn_archive/test_nn_archive.py | 45 +++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/luxonis_ml/nn_archive/config_building_blocks/base_models/input.py b/luxonis_ml/nn_archive/config_building_blocks/base_models/input.py index 9f97d39e..2ff2bf74 100644 --- a/luxonis_ml/nn_archive/config_building_blocks/base_models/input.py +++ b/luxonis_ml/nn_archive/config_building_blocks/base_models/input.py @@ -24,8 +24,8 @@ class PreprocessingBlock(BaseModelExtraForbid): @type interleaved_to_planar: bool | None @ivar interleaved_to_planar: If True input to the model is interleaved (NHWC) else planar (NCHW). - @type layout: str | None - @ivar layout: DepthAI input type which is read by DepthAI to + @type dai_type: str | None + @ivar dai_type: DepthAI input type which is read by DepthAI to automatically setup the pipeline. """ diff --git a/tests/test_nn_archive/test_nn_archive.py b/tests/test_nn_archive/test_nn_archive.py index d407d681..bcf838f3 100644 --- a/tests/test_nn_archive/test_nn_archive.py +++ b/tests/test_nn_archive/test_nn_archive.py @@ -132,6 +132,51 @@ def test_archive_generator( assert "config.json" in tar.getnames() +def test_config_version(): + from luxonis_ml.nn_archive import Config + + cfg_dict = { + "config_version": "1.0", + "model": { + "metadata": { + "name": "test_model", + "path": "test_model.onnx", + }, + "inputs": [ + { + "name": "input", + "shape": [1, 3, 224, 224], + "input_type": "image", + "layout": "nchw", + "dtype": "float32", + "preprocessing": { + "mean": [0.485, 0.456, 0.406], + "scale": [0.229, 0.224, 0.225], + "reverse_channels": False, + "interleaved_to_planar": False, + }, + } + ], + "outputs": [ + { + "name": "output", + "dtype": "float32", + } + ], + "heads": [], + }, + } + Config(**cfg_dict) + cfg_dict["config_version"] = "1.2" + Config(**cfg_dict) + cfg_dict["config_version"] = "1.2.2" + with pytest.raises(ValidationError): + Config(**cfg_dict) + cfg_dict["config_version"] = "1.a" + with pytest.raises(ValidationError): + Config(**cfg_dict) + + def test_optional_head_name(): from luxonis_ml.nn_archive.config_building_blocks.base_models.head_metadata import ( HeadMetadata,