refac: reuse create_data_module in careamics
jdeschamps committed Jun 26, 2024
1 parent c81038b commit dd08e56
Showing 11 changed files with 104 additions and 375 deletions.
42 changes: 22 additions & 20 deletions src/careamics/careamist.py
@@ -15,20 +15,23 @@

from careamics.config import (
Configuration,
create_inference_parameters,
load_configuration,
)
from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
from careamics.dataset.dataset_utils import reshape_array
from careamics.lightning.callbacks import ProgressBarCallback
from careamics.lightning.lightning_module import CAREamicsModule
from careamics.lightning.train_data_module import TrainDataModule
from careamics.lightning import (
CAREamicsModule,
HyperParametersCallback,
PredictDataModule,
ProgressBarCallback,
TrainDataModule,
create_predict_datamodule,
)
from careamics.model_io import export_to_bmz, load_pretrained
from careamics.prediction_utils import convert_outputs, create_pred_datamodule
from careamics.prediction_utils import convert_outputs
from careamics.utils import check_path_exists, get_logger

from .lightning.callbacks import HyperParametersCallback
from .lightning.predict_data_module import PredictDataModule

logger = get_logger(__name__)

LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]
@@ -598,28 +601,27 @@ def predict(
list of NDArray or NDArray
Predictions made by the model.
"""
# Reuse batch size if not provided explicitly
if batch_size is None:
batch_size = (
self.train_datamodule.batch_size
if self.train_datamodule
else self.cfg.data_config.batch_size
)

self.pred_datamodule = create_pred_datamodule(
source=source,
config=self.cfg,
batch_size=batch_size,
# create inference configuration using the main config
inference_dict: dict = create_inference_parameters(
configuration=self.cfg,
tile_size=tile_size,
tile_overlap=tile_overlap,
axes=axes,
data_type=data_type,
axes=axes,
tta_transforms=tta_transforms,
batch_size=batch_size,
)

# create the prediction
self.pred_datamodule = create_predict_datamodule(
pred_data=source,
dataloader_params=dataloader_params,
read_source_func=read_source_func,
extension_filter=extension_filter,
**inference_dict,
)

# predict
predictions = self.trainer.predict(
model=self.model, datamodule=self.pred_datamodule, ckpt_path=checkpoint
)
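For context, a minimal usage sketch (not part of the commit) of the public `predict()` call whose internals are rewired above. It assumes a `CAREamist` built from a pretrained checkpoint; the checkpoint path, input image and tile values are placeholders.

import numpy as np

from careamics import CAREamist

# Placeholder checkpoint path; assumes CAREamist accepts a path to a pretrained model.
careamist = CAREamist("model.ckpt")

# Placeholder 2D input; data type and axes default to the training configuration.
image = np.random.rand(256, 256).astype(np.float32)

# tile_size, tile_overlap and tta_transforms are forwarded to
# create_inference_parameters, which now also fills in batch_size from the
# training configuration when it is not given explicitly.
predictions = careamist.predict(
    source=image,
    tile_size=(128, 128),
    tile_overlap=(32, 32),
    tta_transforms=True,
)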
5 changes: 2 additions & 3 deletions src/careamics/config/__init__.py
@@ -12,19 +12,18 @@
"create_n2v_configuration",
"create_n2n_configuration",
"create_care_configuration",
"create_inference_parameters",
"register_model",
"CustomModel",
"create_inference_configuration",
"clear_custom_models",
"ConfigurationInformation",
]

from .algorithm_model import AlgorithmConfig
from .architectures import CustomModel, clear_custom_models, register_model
from .callback_model import CheckpointModel
from .configuration_factory import (
create_care_configuration,
create_inference_configuration,
create_inference_parameters,
create_n2n_configuration,
create_n2v_configuration,
)
86 changes: 0 additions & 86 deletions src/careamics/config/configuration_example.py

This file was deleted.

49 changes: 25 additions & 24 deletions src/careamics/config/configuration_factory.py
@@ -1,12 +1,11 @@
"""Convenience functions to create configurations for training and inference."""

from typing import Any, Dict, List, Literal, Optional, Tuple
from typing import Any, Dict, List, Literal, Optional

from .algorithm_model import AlgorithmConfig
from .architectures import UNetModel
from .configuration_model import Configuration
from .data_model import DataConfig
from .inference_model import InferenceConfig
from .support import (
SupportedAlgorithm,
SupportedArchitecture,
@@ -576,28 +575,30 @@ def create_n2v_configuration(
return configuration


def create_inference_configuration(
def create_inference_parameters(
configuration: Configuration,
tile_size: Optional[Tuple[int, ...]] = None,
tile_overlap: Optional[Tuple[int, ...]] = None,
tile_size: Optional[tuple[int, ...]] = None,
tile_overlap: Optional[tuple[int, ...]] = None,
data_type: Optional[Literal["array", "tiff", "custom"]] = None,
axes: Optional[str] = None,
tta_transforms: bool = True,
batch_size: Optional[int] = 1,
) -> InferenceConfig:
"""
Create a configuration for inference with N2V.
batch_size: Optional[int] = None,
) -> dict[str, Any]:
"""Return inference parameters based on a full configuration.
If not provided, `data_type` and `axes` are taken from the training
If not provided, `data_type`, `axes` and `batch_size` are taken from the training
configuration.
The tile and overlap sizes are compared to their constraints in the case of UNet
models.
Parameters
----------
configuration : Configuration
Global configuration.
tile_size : Tuple[int, ...], optional
tile_size : tuple of int, optional
Size of the tiles.
tile_overlap : Tuple[int, ...], optional
tile_overlap : tuple of int, optional
Overlap of the tiles.
data_type : str, optional
Type of the data, by default "tiff".
@@ -610,8 +611,8 @@ def create_inference_configuration(
Returns
-------
InferenceConfiguration
Configuration used to configure CAREamicsPredictData.
dict
Dictionary of values used to configure a `TrainDataModule`.
"""
if (
configuration.data_config.image_means is None
@@ -641,13 +642,13 @@
if tile_overlap is None:
raise ValueError("Tile overlap must be specified.")

return InferenceConfig(
data_type=data_type or configuration.data_config.data_type,
tile_size=tile_size,
tile_overlap=tile_overlap,
axes=axes or configuration.data_config.axes,
image_means=configuration.data_config.image_means,
image_stds=configuration.data_config.image_stds,
tta_transforms=tta_transforms,
batch_size=batch_size,
)
return {
"data_type": data_type or configuration.data_config.data_type,
"tile_size": tile_size,
"tile_overlap": tile_overlap,
"axes": axes or configuration.data_config.axes,
"image_means": configuration.data_config.image_means,
"image_stds": configuration.data_config.image_stds,
"tta_transforms": tta_transforms,
"batch_size": batch_size or configuration.data_config.batch_size,
}
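For illustration, a minimal sketch of how the returned dictionary is consumed, mirroring the careamist.py change above. It assumes an existing configuration saved to `config.yml` with image statistics already set; the path, input array and tile values are placeholders.

import numpy as np

from careamics.config import create_inference_parameters, load_configuration
from careamics.lightning import create_predict_datamodule

# Placeholder path; the configuration must carry image_means/image_stds,
# as enforced by the check at the top of create_inference_parameters.
config = load_configuration("config.yml")

image = np.random.rand(256, 256).astype(np.float32)

# data_type, axes and batch_size are omitted here and are therefore taken
# from the training configuration.
inference_dict = create_inference_parameters(
    configuration=config,
    tile_size=(128, 128),
    tile_overlap=(32, 32),
)

# The dictionary is unpacked straight into the prediction datamodule.
pred_datamodule = create_predict_datamodule(
    pred_data=image,
    **inference_dict,
)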
14 changes: 12 additions & 2 deletions src/careamics/config/inference_model.py
@@ -15,25 +15,35 @@ class InferenceConfig(BaseModel):

model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)

# Mandatory fields
data_type: Literal["array", "tiff", "custom"] # As defined in SupportedData
"""Type of input data: numpy.ndarray (array) or path (tiff or custom)."""

tile_size: Optional[Union[list[int]]] = Field(
default=None, min_length=2, max_length=3
)
"""Tile size of prediction, only effective if `tile_overlap` is specified."""

tile_overlap: Optional[Union[list[int]]] = Field(
default=None, min_length=2, max_length=3
)
"""Overlap between tiles, only effective if `tile_size` is specified."""

axes: str
"""Data axes (TSCZYX) in the order of the input data."""

image_means: list = Field(..., min_length=0, max_length=32)
"""Mean values for each input channel."""

image_stds: list = Field(..., min_length=0, max_length=32)
"""Standard deviation values for each input channel."""

# only default TTAs are supported for now
# TODO only default TTAs are supported for now
tta_transforms: bool = Field(default=True)
"""Whether to apply test-time augmentation (all 90 degrees rotations and flips)."""

# Dataloader parameters
batch_size: int = Field(default=1, ge=1)
"""Batch size for prediction."""

@field_validator("tile_overlap")
@classmethod
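For reference, a minimal sketch of building an `InferenceConfig` directly, using only the fields visible in this diff; all values are placeholders and must also satisfy the validators not shown in this excerpt.

from careamics.config.inference_model import InferenceConfig

inference_config = InferenceConfig(
    data_type="array",       # "array", "tiff" or "custom"
    tile_size=[128, 128],    # only effective together with tile_overlap
    tile_overlap=[32, 32],
    axes="YX",
    image_means=[0.485],     # one value per input channel
    image_stds=[0.229],
    tta_transforms=True,     # default: all 90 degree rotations and flips
    batch_size=1,            # default, must be >= 1
)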
