diff --git a/src/careamics/model_io/bioimage/model_description.py b/src/careamics/model_io/bioimage/model_description.py index 21ed50b8f..dbd1dfe08 100644 --- a/src/careamics/model_io/bioimage/model_description.py +++ b/src/careamics/model_io/bioimage/model_description.py @@ -4,6 +4,7 @@ from typing import List, Optional, Tuple, Union import numpy as np +from bioimageio.spec._internal.io import resolve_and_extract from bioimageio.spec.model.v0_5 import ( ArchitectureFromLibraryDescr, Author, @@ -280,6 +281,16 @@ def create_model_description( "https://careamics.github.io/latest/", ], license="BSD-3-Clause", + config={ + "bioimageio": { + "test_kwargs": { + "pytorch_state_dict": { + "absolute_tolerance": 1e-2, + "relative_tolerance": 1e-2, + } + } + } + }, version="0.1.0", weights=weights_descr, attachments=[FileDescr(source=config_path)], @@ -304,7 +315,9 @@ def extract_model_path(model_desc: ModelDescr) -> tuple[Path, Path]: """ if model_desc.weights.pytorch_state_dict is None: raise ValueError("No model weights found in model description.") - weights_path = model_desc.weights.pytorch_state_dict.download().path + weights_path = resolve_and_extract( + model_desc.weights.pytorch_state_dict.source + ).path for file in model_desc.attachments: file_path = file.source if isinstance(file.source, Path) else file.source.path @@ -312,7 +325,7 @@ def extract_model_path(model_desc: ModelDescr) -> tuple[Path, Path]: continue file_path = Path(file_path) if file_path.name == "careamics.yaml": - config_path = file.download().path + config_path = resolve_and_extract(file.source).path break else: raise ValueError("Configuration file not found.") diff --git a/src/careamics/model_io/bmz_io.py b/src/careamics/model_io/bmz_io.py index dc4564ecc..65a3ea99c 100644 --- a/src/careamics/model_io/bmz_io.py +++ b/src/careamics/model_io/bmz_io.py @@ -21,7 +21,6 @@ create_env_text, create_model_description, extract_model_path, - get_unzip_path, ) @@ -185,7 +184,12 @@ def export_to_bmz( ) # test model description - summary: ValidationSummary = test_model(model_description) + test_kwargs = ( + model_description.config.get("bioimageio", {}) + .get("test_kwargs", {}) + .get("pytorch_state_dict", {}) + ) + summary: ValidationSummary = test_model(model_description, **test_kwargs) if summary.status == "failed": raise ValueError(f"Model description test failed: {summary}") @@ -219,14 +223,9 @@ def load_from_bmz( # load description, this creates an unzipped folder next to the archive model_desc = load_model_description(path) - # extract relative paths + # extract paths weights_path, config_path = extract_model_path(model_desc) - # create folder path and absolute paths - unzip_path = get_unzip_path(path) - weights_path = unzip_path / weights_path - config_path = unzip_path / config_path - # load configuration config = load_configuration(config_path)