Simplify CAREamist BMZ export (#144)
### Description

Following a chat with @CatEek, I've come to realize that my previous
attempt to pull patches from the dataloaders in order to export the
model to the BMZ format was difficult to read (patches needed to be
denormalized in order to be put through the BMZ pipeline).

I've decided to just force users to input an array. This can be one of
the training or prediction arrays. It is a simple solution that leaves
us with clearer code to maintain.

- **What**: Removed `CAREamist._create_data_for_bmz`, `input_array` now
mandatory for `CAREamist.export_to_bmz`.
- **Why**: Pulling patches from the dataloaders in order to avoid the
input of an array led to complex code.
- **How**: See what.

### Changes Made

- **Modified**: `CAREamist.export_to_bmz`.
- **Removed**: `CAREamist._create_data_for_bmz` and corresponding tests.


### Breaking changes

Any code that does not pass `input_array`, e.g. all notebook examples.

Currently, if the array does not have the same dimensions/axes as what
the configuration states, users should get an error from the reshape
function.
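
As an illustration only, here is a minimal sketch of a check a caller could run before exporting, so that a mismatch is reported with an explicit message. It assumes the configuration axes are a string with one letter per dimension (e.g. `"YX"` or `"SCZYX"`); the helper `check_shape_matches_axes` is hypothetical and not part of CAREamics.

```python
from typing import Tuple


def check_shape_matches_axes(shape: Tuple[int, ...], axes: str) -> None:
    """Raise a ValueError if `shape` does not have one dimension per axis letter.

    Hypothetical helper, not part of CAREamics.
    """
    if len(shape) != len(axes):
        raise ValueError(
            f"Array has {len(shape)} dimensions, but axes '{axes}' "
            f"describe {len(axes)}."
        )


# A 2D array matches the axes "YX", so this call does not raise.
check_shape_matches_axes((64, 64), "YX")
```

In practice this would be called with `some_array.shape` and the axes stored in the configuration.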


### Additional Notes and Examples

Previously, after training or loading a model, the following call
would run:

```python
careamist.export_to_bmz(
    path="sem_n2v_model.zip",
    name="SEM_N2V",
    authors=[{"name": "CAREamics authors", "affiliation": "Human Technopole"}],
)
```

It would create input data for the BMZ, using the first of the
following steps that applied:

- If there is a prediction dataloader, pull a patch from it and
denormalize it
- If there is a training dataloader, pull a patch from it and
denormalize it
- If there is no dataloader, create a random array using the `mean` and
`std` in the configuration

Now, users have to provide an input array:

```python
careamist.export_to_bmz(
    path="sem_n2v_model.zip",
    name="SEM_N2V",
    input_array=some_array,
    authors=[{"name": "CAREamics authors", "affiliation": "Human Technopole"}],
)
```
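
For reference, a minimal sketch of one way to build `some_array` from data that is already in memory. It assumes the configuration axes are `"SYX"`; `train_array` and the crop size are hypothetical stand-ins, not values taken from CAREamics.

```python
import numpy as np

# Hypothetical stand-in for the training data, with axes "SYX" (assumption).
rng = np.random.default_rng(0)
train_array = rng.normal(size=(8, 512, 512)).astype(np.float32)

# Keep a single sample (slicing with `:1` preserves the S dimension)
# and a 256x256 crop, so the example array stays small.
some_array = train_array[:1, :256, :256]
```

Slicing keeps the number of dimensions unchanged, so the result still matches the axes recorded in the configuration.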

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)
jdeschamps authored Jun 13, 2024
1 parent 955d1e9 commit aa66daf
Showing 2 changed files with 39 additions and 273 deletions.
107 changes: 11 additions & 96 deletions src/careamics/careamist.py
@@ -26,7 +26,6 @@
from careamics.lightning_prediction_datamodule import CAREamicsPredictData
from careamics.lightning_prediction_loop import CAREamicsPredictionLoop
from careamics.model_io import export_to_bmz, load_pretrained
from careamics.transforms import Denormalize
from careamics.utils import check_path_exists, get_logger

from .callbacks import HyperParametersCallback
@@ -656,122 +655,39 @@ def predict(
f"NDArray (got {type(source)})."
)

def _create_data_for_bmz(
self,
input_array: Optional[NDArray] = None,
seed: Optional[int] = None,
) -> NDArray:
"""Create data for BMZ export.
If no `input_array` is provided, this method checks if there is a prediction
datamodule, or a training data module, to extract a patch. If none exists,
then a random array is created.
The method returns a denormalized array.
If there is a non-singleton batch dimension, this method returns only the first
element.
Parameters
----------
input_array : NDArray, optional
Input array, which should not be normalized, by default None.
seed : int, optional
Seed for the random number generator used when no input array is given nor
are there data in the dataloaders, by default None.
Returns
-------
NDArray
Input data for BMZ export.
Raises
------
ValueError
If mean and std are not provided in the configuration.
"""
if input_array is None:
if self.cfg.data_config.mean is None or self.cfg.data_config.std is None:
raise ValueError(
"Mean and std cannot be None in the configuration in order to"
"export to the BMZ format. Was the model trained?"
)

# generate images, priority is given to the prediction data module
if self.pred_datamodule is not None:
# unpack a batch, ignore masks or targets
input_patch, *_ = next(iter(self.pred_datamodule.predict_dataloader()))

# convert torch.Tensor to numpy
input_patch = input_patch.numpy()

# denormalize
denormalize = Denormalize(
mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
)
input_patch_denorm = denormalize(input_patch)

elif self.train_datamodule is not None:
input_patch, *_ = next(iter(self.train_datamodule.train_dataloader()))
input_patch = input_patch.numpy()

# denormalize
denormalize = Denormalize(
mean=self.cfg.data_config.mean, std=self.cfg.data_config.std
)
input_patch_denorm = denormalize(input_patch)
else:
# create a random input array
rng = np.random.default_rng(seed)
input_patch_denorm = rng.normal(
loc=self.cfg.data_config.mean,
scale=self.cfg.data_config.std,
size=self.cfg.data_config.patch_size,
).astype(np.float32)[
np.newaxis, np.newaxis, ...
] # add S & C dimensions
else:
# potentially correct shape
input_patch_denorm = reshape_array(input_array, self.cfg.data_config.axes)

# if this a batch
if input_patch_denorm.shape[0] > 1:
input_patch_denorm = input_patch_denorm[[0], ...] # keep singleton dim

return input_patch_denorm

def export_to_bmz(
self,
path: Union[Path, str],
name: str,
input_array: NDArray,
authors: List[dict],
input_array: Optional[NDArray] = None,
general_description: str = "",
channel_names: Optional[List[str]] = None,
data_description: Optional[str] = None,
) -> None:
"""Export the model to the BioImage Model Zoo format.
Input array must be of shape SC(Z)YX, with S and C singleton dimensions.
Input array must be of the same dimensions as the axes recorded in the
configuration of the `CAREamist`.
Parameters
----------
path : pathlib.Path or str
Path to save the model.
name : str
Name of the model.
input_array : NDArray
Input array used to validate the model and as example.
authors : list of dict
List of authors of the model.
input_array : NDArray, optional
Input array for the model, must be of shape SC(Z)YX, by default None.
general_description : str
General description of the model, used in the metadata of the BMZ archive.
channel_names : list of str, optional
Channel names, by default None.
data_description : str, optional
Description of the data, by default None.
"""
input_patch = self._create_data_for_bmz(input_array)
input_patch = reshape_array(input_array, self.cfg.data_config.axes)

# axes need to be reformated for the export because reshaping was done in the
# datamodule
@@ -788,11 +704,10 @@ def export_to_bmz(
tta_transforms=False,
)

if not isinstance(output_patch, np.ndarray):
raise ValueError(
f"Numpy array required for export to BioImage Model Zoo, got "
f"{type(output_patch)}."
)
if isinstance(output_patch, list):
output = np.concatenate(output_patch, axis=0)
else:
output = output_patch

export_to_bmz(
model=self.model,
@@ -802,7 +717,7 @@ def export_to_bmz(
general_description=general_description,
authors=authors,
input_array=input_patch,
output_array=output_patch,
output_array=output,
channel_names=channel_names,
data_description=data_description,
)