Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify CAREamist BMZ export #144

Merged
merged 1 commit into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 11 additions & 96 deletions src/careamics/careamist.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from careamics.lightning_prediction_datamodule import CAREamicsPredictData
from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
from careamics.model_io import export_to_bmz, load_pretrained
from careamics.transforms import Denormalize
from careamics.utils import check_path_exists, get_logger

from .callbacks import HyperParametersCallback
Expand Down Expand Up @@ -656,122 +655,39 @@ def predict(
f"NDArray (got {type(source)})."
)

def _create_data_for_bmz(
self,
input_array: Optional[NDArray] = None,
seed: Optional[int] = None,
) -> NDArray:
"""Create data for BMZ export.

If no `input_array` is provided, this method checks if there is a prediction
datamodule, or a training data module, to extract a patch. If none exists,
then a random array is created.

The method returns a denormalized array.

If there is a non-singleton batch dimension, this method returns only the first
element.

Parameters
----------
input_array : NDArray, optional
Input array, which should not be normalized, by default None.
seed : int, optional
Seed for the random number generator used when no input array is given nor
are there data in the dataloaders, by default None.

Returns
-------
NDArray
Input data for BMZ export.

Raises
------
ValueError
If mean and std are not provided in the configuration.
"""
if input_array is None:
if self.cfg.data_config.mean is None or self.cfg.data_config.std is None:
raise ValueError(
"Mean and std cannot be None in the configuration in order to"
"export to the BMZ format. Was the model trained?"
)

# generate images, priority is given to the prediction data module
if self.pred_datamodule is not None:
# unpack a batch, ignore masks or targets
input_patch, *_ = next(iter(self.pred_datamodule.predict_dataloader()))

# convert torch.Tensor to numpy
input_patch = input_patch.numpy()

# denormalize
denormalize = Denormalize(
mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
)
input_patch_denorm = denormalize(input_patch)

elif self.train_datamodule is not None:
input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
input_patch = input_patch.numpy()

# denormalize
denormalize = Denormalize(
mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
)
input_patch_denorm = denormalize(input_patch)
else:
# create a random input array
rng = np.random.default_rng(seed)
input_patch_denorm = rng.normal(
loc=self.cfg.data_config.mean,
scale=self.cfg.data_config.std,
size=self.cfg.data_config.patch_size,
).astype(np.float32)[
np.newaxis, np.newaxis, ...
] # add S & C dimensions
else:
# potentially correct shape
input_patch_denorm = reshape_array(input_array, self.cfg.data_config.axes)

# if this a batch
if input_patch_denorm.shape[0] > 1:
input_patch_denorm = input_patch_denorm[[0], ...] # keep singleton dim

return input_patch_denorm

def export_to_bmz(
self,
path: Union[Path, str],
name: str,
input_array: NDArray,
authors: List[dict],
input_array: Optional[NDArray] = None,
general_description: str = "",
channel_names: Optional[List[str]] = None,
data_description: Optional[str] = None,
) -> None:
"""Export the model to the BioImage Model Zoo format.

Input array must be of shape SC(Z)YX, with S and C singleton dimensions.
Input array must be of the same dimensions as the axes recorded in the
configuration of the `CAREamist`.

Parameters
----------
path : pathlib.Path or str
Path to save the model.
name : str
Name of the model.
input_array : NDArray
Input array used to validate the model and as example.
authors : list of dict
List of authors of the model.
input_array : NDArray, optional
Input array for the model, must be of shape SC(Z)YX, by default None.
general_description : str
General description of the model, used in the metadata of the BMZ archive.
channel_names : list of str, optional
Channel names, by default None.
data_description : str, optional
Description of the data, by default None.
"""
input_patch = self._create_data_for_bmz(input_array)
input_patch = reshape_array(input_array, self.cfg.data_config.axes)

# axes need to be reformatted for the export because reshaping was done in the
# datamodule
Expand All @@ -788,11 +704,10 @@ def export_to_bmz(
tta_transforms=False,
)

if not isinstance(output_patch, np.ndarray):
raise ValueError(
f"Numpy array required for export to BioImage Model Zoo, got "
f"{type(output_patch)}."
)
if isinstance(output_patch, list):
output = np.concatenate(output_patch, axis=0)
else:
output = output_patch

export_to_bmz(
model=self.model,
Expand All @@ -802,7 +717,7 @@ def export_to_bmz(
general_description=general_description,
authors=authors,
input_array=input_patch,
output_array=output_patch,
output_array=output,
channel_names=channel_names,
data_description=data_description,
)
Loading
Loading