Skip to content

Commit

Permalink
add to the preprocessing block in modelconverter_config_to_nn and arc…
Browse files Browse the repository at this point in the history
…hive_from_model methods
  • Loading branch information
ptoupas committed Nov 8, 2024
1 parent 4033408 commit c8f2c87
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 20 deletions.
9 changes: 3 additions & 6 deletions modelconverter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,10 @@ def extract_preprocessing(

dai_type = encoding.from_.value
if dai_type != "NONE":
if (
inp.data_type == DataType.FLOAT32
or inp.data_type == DataType.INT8
):
type = "888"
elif inp.data_type == DataType.FLOAT16:
if inp.data_type == DataType.FLOAT16:
type = "F16F16F16"
else:
type = "888"
dai_type += type
dai_type += "i" if layout == "NHWC" else "p"

Expand Down
32 changes: 18 additions & 14 deletions modelconverter/utils/nn_archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from modelconverter.utils.constants import MISC_DIR
from modelconverter.utils.layout import guess_new_layout, make_default_layout
from modelconverter.utils.metadata import get_metadata
from modelconverter.utils.types import DataType

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -191,7 +192,7 @@ def modelconverter_config_to_nn(
cfg = config.stages[main_stage_key]

archive_cfg = {
"config_version": CONFIG_VERSION.__args__[-1], # type: ignore
"config_version": CONFIG_VERSION,
"model": {
"metadata": {
"name": model_name.stem,
Expand All @@ -205,21 +206,25 @@ def modelconverter_config_to_nn(

for inp in cfg.inputs:
new_shape = model_metadata.input_shapes[inp.name]
# new_dtype = model_metadata.input_dtypes[inp.name]
if inp.shape is not None and not any(s == 0 for s in inp.shape):
assert inp.layout is not None
layout = guess_new_layout(inp.layout, inp.shape, new_shape)
else:
layout = make_default_layout(new_shape)
dai_type = inp.encoding.to.value
if inp.data_type == DataType.FLOAT16:
type = "F16F16F16"
else:
type = "888"
dai_type += type
dai_type += "i" if layout == "NHWC" else "p"

archive_cfg["model"]["inputs"].append(
{
"name": inp.name,
"shape": new_shape,
"layout": layout,
# "dtype": new_dtype.value,
"dtype": inp.data_type.value,
# "dtype": "float32",
"input_type": "image",
"preprocessing": {
"mean": [0 for _ in inp.mean_values]
Expand All @@ -230,14 +235,14 @@ def modelconverter_config_to_nn(
if inp.scale_values
else None
),
"reverse_channels": False,
"interleaved_to_planar": False,
"reverse_channels": inp.encoding.from_ != inp.encoding.to,
"interleaved_to_planar": layout == "NHWC",
"dai_type": dai_type,
},
}
)
for out in cfg.outputs:
new_shape = model_metadata.output_shapes[out.name]
# new_dtype = model_metadata.output_dtypes[out.name]
if out.shape is not None and not any(s == 0 for s in out.shape):
assert out.layout is not None
layout = guess_new_layout(out.layout, out.shape, new_shape)
Expand All @@ -249,9 +254,7 @@ def modelconverter_config_to_nn(
"name": out.name,
"shape": new_shape,
"layout": layout,
# "dtype": new_dtype.value,
"dtype": out.data_type.value,
# "dtype": "float32",
}
)

Expand Down Expand Up @@ -284,7 +287,7 @@ def archive_from_model(model_path: Path) -> NNArchiveConfig:
metadata = get_metadata(model_path)

archive_cfg = {
"config_version": "1.0",
"config_version": CONFIG_VERSION,
"model": {
"metadata": {
"name": model_path.stem,
Expand All @@ -302,13 +305,14 @@ def archive_from_model(model_path: Path) -> NNArchiveConfig:
"name": name,
"shape": shape,
"layout": make_default_layout(shape),
"dtype": "float32",
"dtype": metadata.input_dtypes[name].value,
"input_type": "image",
"preprocessing": {
"mean": None,
"scale": None,
"reverse_channels": False,
"interleaved_to_planar": False,
"reverse_channels": None,
"interleaved_to_planar": None,
"dai_type": None,
},
}
)
Expand All @@ -319,7 +323,7 @@ def archive_from_model(model_path: Path) -> NNArchiveConfig:
"name": name,
"shape": shape,
"layout": make_default_layout(shape),
"dtype": "float32",
"dtype": metadata.output_dtypes[name].value,
}
)

Expand Down

0 comments on commit c8f2c87

Please sign in to comment.