Skip to content

Commit

Permalink
Merge branch 'main' into fc/fix/free_bits_bug
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps authored Nov 20, 2024
2 parents 52ba7ad + 240e4d3 commit 90c3175
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ dependencies = [
'numpy<2.0.0',
'torch>=2.0.0',
'torchvision',
'bioimageio.core>=0.6.9',
'bioimageio.core>=0.7.0',
'tifffile',
'psutil',
'pydantic>=2.5,<2.9',
Expand Down
1 change: 1 addition & 0 deletions src/careamics/lvae_training/dataset/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class DataType(Enum):
OptiMEM100_014 = 10
SeparateTiffData = 11
BioSR_MRC = 12
PunctaRemoval = 13 # for the case when we have a set of differently sized crops for each channel.


class DataSplitType(Enum):
Expand Down
17 changes: 15 additions & 2 deletions src/careamics/model_io/bioimage/model_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List, Optional, Tuple, Union

import numpy as np
from bioimageio.spec._internal.io import resolve_and_extract
from bioimageio.spec.model.v0_5 import (
ArchitectureFromLibraryDescr,
Author,
Expand Down Expand Up @@ -280,6 +281,16 @@ def create_model_description(
"https://careamics.github.io/latest/",
],
license="BSD-3-Clause",
config={
"bioimageio": {
"test_kwargs": {
"pytorch_state_dict": {
"absolute_tolerance": 1e-2,
"relative_tolerance": 1e-2,
}
}
}
},
version="0.1.0",
weights=weights_descr,
attachments=[FileDescr(source=config_path)],
Expand All @@ -304,15 +315,17 @@ def extract_model_path(model_desc: ModelDescr) -> tuple[Path, Path]:
"""
if model_desc.weights.pytorch_state_dict is None:
raise ValueError("No model weights found in model description.")
weights_path = model_desc.weights.pytorch_state_dict.download().path
weights_path = resolve_and_extract(
model_desc.weights.pytorch_state_dict.source
).path

for file in model_desc.attachments:
file_path = file.source if isinstance(file.source, Path) else file.source.path
if file_path is None:
continue
file_path = Path(file_path)
if file_path.name == "careamics.yaml":
config_path = file.download().path
config_path = resolve_and_extract(file.source).path
break
else:
raise ValueError("Configuration file not found.")
Expand Down
15 changes: 7 additions & 8 deletions src/careamics/model_io/bmz_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
create_env_text,
create_model_description,
extract_model_path,
get_unzip_path,
)


Expand Down Expand Up @@ -185,7 +184,12 @@ def export_to_bmz(
)

# test model description
summary: ValidationSummary = test_model(model_description)
test_kwargs = (
model_description.config.get("bioimageio", {})
.get("test_kwargs", {})
.get("pytorch_state_dict", {})
)
summary: ValidationSummary = test_model(model_description, **test_kwargs)
if summary.status == "failed":
raise ValueError(f"Model description test failed: {summary}")

Expand Down Expand Up @@ -219,14 +223,9 @@ def load_from_bmz(
# load description, this creates an unzipped folder next to the archive
model_desc = load_model_description(path)

# extract relative paths
# extract paths
weights_path, config_path = extract_model_path(model_desc)

# create folder path and absolute paths
unzip_path = get_unzip_path(path)
weights_path = unzip_path / weights_path
config_path = unzip_path / config_path

# load configuration
config = load_configuration(config_path)

Expand Down

0 comments on commit 90c3175

Please sign in to comment.