Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring: lightning API package and smoke tests #161

Merged
merged 15 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 52 additions & 14 deletions src/careamics/careamist.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@

from careamics.config import (
Configuration,
create_inference_parameters,
load_configuration,
)
from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
from careamics.config.support import (
SupportedAlgorithm,
SupportedArchitecture,
SupportedData,
SupportedLogger,
)
from careamics.dataset.dataset_utils import reshape_array
from careamics.lightning import (
CAREamicsModule,
Expand Down Expand Up @@ -594,25 +598,59 @@ def predict(
-------
list of NDArray or NDArray
Predictions made by the model.

Raises
------
ValueError
If mean and std are not provided in the configuration.
ValueError
If tile size is not divisible by 2**depth for UNet models.
ValueError
If tile overlap is not specified.
"""
# create inference configuration using the main config
inference_dict: dict = create_inference_parameters(
configuration=self.cfg,
tile_size=tile_size,
tile_overlap=tile_overlap,
data_type=data_type,
axes=axes,
tta_transforms=tta_transforms,
batch_size=batch_size,
)
if (
self.cfg.data_config.image_means is None
or self.cfg.data_config.image_stds is None
):
raise ValueError("Mean and std must be provided in the configuration.")

# tile size for UNets
if tile_size is not None:
model = self.cfg.algorithm_config.model

if model.architecture == SupportedArchitecture.UNET.value:
# tile size must be equal to k*2^n, where n is the number of pooling
# layers (equal to the depth) and k is an integer
Comment on lines +621 to +623
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I realise now the InferenceConfig doesn't have access to the type of architecture so it can't be validated in there 😅, I guess it might be nice to isolate this check as a separate function somewhere but otherwise looks good!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not only the type of architecture, but also the depth of the UNet...

But you had a good point with the double Inference instantiation, that was convoluted. Let's see later if we want to move that to some utils package or so.

depth = model.depth
tile_increment = 2**depth

for i, t in enumerate(tile_size):
if t % tile_increment != 0:
raise ValueError(
f"Tile size must be divisible by {tile_increment} along "
f"all axes (got {t} for axis {i}). If your image size is "
f"smaller along one axis (e.g. Z), consider padding the "
f"image."
)

# tile overlaps must be specified
if tile_overlap is None:
raise ValueError("Tile overlap must be specified.")

# create the prediction
self.pred_datamodule = create_predict_datamodule(
pred_data=source,
dataloader_params=dataloader_params,
data_type=data_type or self.cfg.data_config.data_type,
axes=axes or self.cfg.data_config.axes,
image_means=self.cfg.data_config.image_means,
image_stds=self.cfg.data_config.image_stds,
tile_size=tile_size,
tile_overlap=tile_overlap,
batch_size=batch_size or self.cfg.data_config.batch_size,
tta_transforms=tta_transforms,
read_source_func=read_source_func,
extension_filter=extension_filter,
**inference_dict,
dataloader_params=dataloader_params,
)

# predict
Expand Down
2 changes: 0 additions & 2 deletions src/careamics/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
"create_n2v_configuration",
"create_n2n_configuration",
"create_care_configuration",
"create_inference_parameters",
"register_model",
"CustomModel",
"clear_custom_models",
Expand All @@ -23,7 +22,6 @@
from .callback_model import CheckpointModel
from .configuration_factory import (
create_care_configuration,
create_inference_parameters,
create_n2n_configuration,
create_n2v_configuration,
)
Expand Down
79 changes: 0 additions & 79 deletions src/careamics/config/configuration_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,82 +573,3 @@ def create_n2v_configuration(
)

return configuration


def create_inference_parameters(
    configuration: Configuration,
    tile_size: Optional[tuple[int, ...]] = None,
    tile_overlap: Optional[tuple[int, ...]] = None,
    data_type: Optional[Literal["array", "tiff", "custom"]] = None,
    axes: Optional[str] = None,
    tta_transforms: bool = True,
    batch_size: Optional[int] = None,
) -> dict[str, Any]:
    """Return inference parameters based on a full configuration.

    If not provided, `data_type`, `axes` and `batch_size` are taken from the data
    configuration of `configuration`. The image means and stds are always taken
    from it.

    When a tile size is given, it is checked against the UNet constraint (only if
    the model is a UNet) and a tile overlap becomes mandatory.

    Parameters
    ----------
    configuration : Configuration
        Global configuration.
    tile_size : tuple of int, optional
        Size of the tiles, by default None (no tiling checks are performed).
    tile_overlap : tuple of int, optional
        Overlap of the tiles, by default None. Required if `tile_size` is given.
    data_type : str, optional
        Type of the data, by default None (value from the configuration).
    axes : str, optional
        Axes of the data, by default None (value from the configuration).
    tta_transforms : bool, optional
        Whether to apply test-time augmentations, by default True.
    batch_size : int, optional
        Batch size, by default None (value from the configuration).

    Returns
    -------
    dict
        Inference parameters, with keys `data_type`, `tile_size`, `tile_overlap`,
        `axes`, `image_means`, `image_stds`, `tta_transforms` and `batch_size`.

    Raises
    ------
    ValueError
        If mean or std are not provided in the configuration.
    ValueError
        If a tile size is not divisible by 2**depth for UNet models.
    ValueError
        If a tile size is specified without a tile overlap.
    """
    data_config = configuration.data_config

    if data_config.image_means is None or data_config.image_stds is None:
        raise ValueError("Mean and std must be provided in the configuration.")

    if tile_size is not None:
        model = configuration.algorithm_config.model

        # tile size for UNets
        if model.architecture == SupportedArchitecture.UNET.value:
            # tile size must be equal to k*2^n, where n is the number of pooling
            # layers (equal to the depth) and k is an integer
            tile_increment = 2**model.depth

            for axis, size in enumerate(tile_size):
                if size % tile_increment != 0:
                    raise ValueError(
                        f"Tile size must be divisible by {tile_increment} along all "
                        f"axes (got {size} for axis {axis}). If your image size is "
                        f"smaller "
                        f"along one axis (e.g. Z), consider padding the image."
                    )

        # tile overlaps must be specified, whatever the architecture
        if tile_overlap is None:
            raise ValueError("Tile overlap must be specified.")

    # `or` falls back to the configuration value for None (or any falsy) argument
    return {
        "data_type": data_type or data_config.data_type,
        "tile_size": tile_size,
        "tile_overlap": tile_overlap,
        "axes": axes or data_config.axes,
        "image_means": data_config.image_means,
        "image_stds": data_config.image_stds,
        "tta_transforms": tta_transforms,
        "batch_size": batch_size or data_config.batch_size,
    }
89 changes: 0 additions & 89 deletions tests/config/test_configuration_factory.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import pytest

from careamics.config import (
InferenceConfig,
create_care_configuration,
create_inference_parameters,
create_n2n_configuration,
create_n2v_configuration,
)
Expand Down Expand Up @@ -523,90 +521,3 @@ def test_structn2v():
== SupportedStructAxis.HORIZONTAL.value
)
assert config.data_config.transforms[-1].struct_mask_span == 7


def test_inference_config_no_stats():
    """Check that inference parameters cannot be created without statistics."""
    # no means/stds are set on this configuration before it is used
    n2v_kwargs = {
        "experiment_name": "test",
        "data_type": "tiff",
        "axes": "YX",
        "patch_size": [64, 64],
        "batch_size": 8,
        "num_epochs": 100,
    }
    cfg_without_stats = create_n2v_configuration(**n2v_kwargs)

    with pytest.raises(ValueError):
        create_inference_parameters(configuration=cfg_without_stats)


def test_inference_config():
    """Check that inference parameters can be built and fed to InferenceConfig."""
    n2v_kwargs = {
        "experiment_name": "test",
        "data_type": "tiff",
        "axes": "YX",
        "patch_size": [64, 64],
        "batch_size": 8,
        "num_epochs": 100,
    }
    cfg = create_n2v_configuration(**n2v_kwargs)

    # statistics are required to create the inference parameters
    cfg.data_config.set_means_and_stds([0.5], [0.2])

    # the resulting dictionary must be accepted by InferenceConfig
    params = create_inference_parameters(configuration=cfg)
    InferenceConfig(**params)


def test_inference_tile_size():
    """Check the tile size constraint of UNet models on inference parameters."""
    care_kwargs = {
        "experiment_name": "test",
        "data_type": "tiff",
        "axes": "YX",
        "patch_size": [64, 64],
        "batch_size": 8,
        "num_epochs": 100,
    }
    cfg = create_care_configuration(**care_kwargs)
    cfg.data_config.set_means_and_stds([0.5], [0.2])

    # with a UNet depth of 2, tile sizes must be multiples of 2**2 = 4
    assert cfg.algorithm_config.model.depth == 2

    # 6 is not a multiple of 4: an error is expected
    with pytest.raises(ValueError):
        create_inference_parameters(
            configuration=cfg, tile_size=[6, 6], tile_overlap=[2, 2]
        )

    # 8 is a multiple of 4: the parameters are created and accepted
    params = create_inference_parameters(
        configuration=cfg, tile_size=[8, 8], tile_overlap=[2, 2]
    )
    InferenceConfig(**params)


def test_inference_tile_no_overlap():
    """Check that a tile size given without a tile overlap raises an error."""
    care_kwargs = {
        "experiment_name": "test",
        "data_type": "tiff",
        "axes": "YX",
        "patch_size": [64, 64],
        "batch_size": 8,
        "num_epochs": 100,
    }
    cfg = create_care_configuration(**care_kwargs)
    cfg.data_config.set_means_and_stds([0.5], [0.2])

    # tile size is specified, tile overlap is left out on purpose
    with pytest.raises(ValueError):
        create_inference_parameters(configuration=cfg, tile_size=[8, 8])
Loading