Skip to content

Commit

Permalink
(refac): passing array to BMZ export is now mandatory
Browse files Browse the repository at this point in the history
  • Loading branch information
jdeschamps committed Jun 13, 2024
1 parent 955d1e9 commit 140daab
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 273 deletions.
107 changes: 11 additions & 96 deletions src/careamics/careamist.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from careamics.lightning_prediction_datamodule import CAREamicsPredictData
from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
from careamics.model_io import export_to_bmz, load_pretrained
from careamics.transforms import Denormalize
from careamics.utils import check_path_exists, get_logger

from .callbacks import HyperParametersCallback
Expand Down Expand Up @@ -656,122 +655,39 @@ def predict(
f"NDArray (got {type(source)})."
)

def _create_data_for_bmz(
    self,
    input_array: Optional[NDArray] = None,
    seed: Optional[int] = None,
) -> NDArray:
    """Create data for BMZ export.

    If an `input_array` is provided, it is reshaped with `reshape_array` using
    the axes stored in the data configuration. Otherwise, a patch is drawn from
    the prediction datamodule if there is one, else from the training
    datamodule, and denormalized with the configured mean and std. If neither
    datamodule exists, a random array is drawn from a normal distribution
    parameterized by the configured mean and std.

    If the first dimension of the result is larger than one, only its first
    element is returned (the singleton dimension is kept).

    Parameters
    ----------
    input_array : NDArray, optional
        Input array, which should not be normalized, by default None.
    seed : int, optional
        Seed for the random number generator used when no input array is given nor
        are there data in the dataloaders, by default None.

    Returns
    -------
    NDArray
        Input data for BMZ export.

    Raises
    ------
    ValueError
        If no `input_array` is given and mean or std is missing from the
        configuration.
    """
    data_config = self.cfg.data_config

    if input_array is not None:
        # potentially correct shape
        input_patch_denorm = reshape_array(input_array, data_config.axes)
    else:
        # mean and std are needed both to denormalize and to draw random data
        if data_config.mean is None or data_config.std is None:
            raise ValueError(
                "Mean and std cannot be None in the configuration in order to "
                "export to the BMZ format. Was the model trained?"
            )

        # priority is given to the prediction data module
        if self.pred_datamodule is not None:
            dataloader = self.pred_datamodule.predict_dataloader()
        elif self.train_datamodule is not None:
            dataloader = self.train_datamodule.train_dataloader()
        else:
            dataloader = None

        if dataloader is not None:
            # unpack a batch, ignore masks or targets
            input_patch, *_ = next(iter(dataloader))

            # convert torch.Tensor to numpy, then undo the normalization
            denormalize = Denormalize(mean=data_config.mean, std=data_config.std)
            input_patch_denorm = denormalize(input_patch.numpy())
        else:
            # create a random input array
            rng = np.random.default_rng(seed)
            input_patch_denorm = rng.normal(
                loc=data_config.mean,
                scale=data_config.std,
                size=data_config.patch_size,
            ).astype(np.float32)[
                np.newaxis, np.newaxis, ...
            ]  # add S & C dimensions

    # if this is a batch, only keep the first element
    if input_patch_denorm.shape[0] > 1:
        input_patch_denorm = input_patch_denorm[[0], ...]  # keep singleton dim

    return input_patch_denorm

def export_to_bmz(
self,
path: Union[Path, str],
name: str,
input_array: NDArray,
authors: List[dict],
input_array: Optional[NDArray] = None,
general_description: str = "",
channel_names: Optional[List[str]] = None,
data_description: Optional[str] = None,
) -> None:
"""Export the model to the BioImage Model Zoo format.
Input array must be of shape SC(Z)YX, with S and C singleton dimensions.
Input array must be of the same dimensions as the axes recorded in the
configuration of the `CAREamist`.
Parameters
----------
path : pathlib.Path or str
Path to save the model.
name : str
Name of the model.
input_array : NDArray
Input array used to validate the model and as example.
authors : list of dict
List of authors of the model.
input_array : NDArray, optional
Input array for the model, must be of shape SC(Z)YX, by default None.
general_description : str
General description of the model, used in the metadata of the BMZ archive.
channel_names : list of str, optional
Channel names, by default None.
data_description : str, optional
Description of the data, by default None.
"""
input_patch = self._create_data_for_bmz(input_array)
input_patch = reshape_array(input_array, self.cfg.data_config.axes)

# axes need to be reformatted for the export because reshaping was done in the
# datamodule
Expand All @@ -788,11 +704,10 @@ def export_to_bmz(
tta_transforms=False,
)

if not isinstance(output_patch, np.ndarray):
raise ValueError(
f"Numpy array required for export to BioImage Model Zoo, got "
f"{type(output_patch)}."
)
if isinstance(output_patch, list):
output = np.concatenate(output_patch, axis=0)
else:
output = output_patch

export_to_bmz(
model=self.model,
Expand All @@ -802,7 +717,7 @@ def export_to_bmz(
general_description=general_description,
authors=authors,
input_array=input_patch,
output_array=output_patch,
output_array=output,
channel_names=channel_names,
data_description=data_description,
)
Loading

0 comments on commit 140daab

Please sign in to comment.