Refactoring: lightning API package and smoke tests #161

Merged 15 commits on Jul 4, 2024.
pyproject.toml: 1 addition, 0 deletions

@@ -147,6 +147,7 @@ exclude_lines = [
     "except ImportError",
     "\\.\\.\\.",
     "raise NotImplementedError()",
+    "except PackageNotFoundError:",
 ]
 
 [tool.coverage.run]
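
The new exclude_lines pattern keeps "except PackageNotFoundError:" branches out
of coverage reports. A minimal sketch of the idiom it targets, assuming the
conventional importlib.metadata version lookup (the actual block is in
src/careamics/__init__.py, diffed below):

    from importlib.metadata import PackageNotFoundError, version

    try:
        __version__ = version("careamics")
    except PackageNotFoundError:  # matched by the new coverage exclusion
        __version__ = "uninstalled"
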
src/careamics/__init__.py: 1 addition, 14 deletions

@@ -7,20 +7,7 @@
 except PackageNotFoundError:
     __version__ = "uninstalled"
 
-__all__ = [
-    "CAREamist",
-    "CAREamicsModuleWrapper",
-    "CAREamicsPredictData",
-    "CAREamicsTrainData",
-    "Configuration",
-    "load_configuration",
-    "save_configuration",
-    "TrainingDataWrapper",
-    "PredictDataWrapper",
-]
+__all__ = ["CAREamist", "Configuration", "load_configuration", "save_configuration"]
 
 from .careamist import CAREamist
 from .config import Configuration, load_configuration, save_configuration
-from .lightning_datamodule import CAREamicsTrainData, TrainingDataWrapper
-from .lightning_module import CAREamicsModuleWrapper
-from .lightning_prediction_datamodule import CAREamicsPredictData, PredictDataWrapper
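
Only the high-level names now remain importable from the package root; the
Lightning-oriented classes move to the careamics.lightning package introduced
by this PR (see the careamist.py imports below). A sketch of the resulting
import surface at this commit:

    # top-level API
    from careamics import CAREamist, Configuration, load_configuration, save_configuration

    # Lightning-specific components now live in the dedicated package
    from careamics.lightning import CAREamicsModule, PredictDataModule, TrainDataModule
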
src/careamics/careamist.py: 76 additions, 36 deletions

@@ -13,22 +13,29 @@
 )
 from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
 
-from careamics.callbacks import ProgressBarCallback
 from careamics.config import (
     Configuration,
     load_configuration,
 )
-from careamics.config.support import SupportedAlgorithm, SupportedData, SupportedLogger
+from careamics.config.support import (
+    SupportedAlgorithm,
+    SupportedArchitecture,
+    SupportedData,
+    SupportedLogger,
+)
 from careamics.dataset.dataset_utils import reshape_array
-from careamics.lightning_datamodule import CAREamicsTrainData
-from careamics.lightning_module import CAREamicsModule
+from careamics.lightning import (
+    CAREamicsModule,
+    HyperParametersCallback,
+    PredictDataModule,
+    ProgressBarCallback,
+    TrainDataModule,
+    create_predict_datamodule,
+)
 from careamics.model_io import export_to_bmz, load_pretrained
-from careamics.prediction_utils import convert_outputs, create_pred_datamodule
+from careamics.prediction_utils import convert_outputs
 from careamics.utils import check_path_exists, get_logger
 
-from .callbacks import HyperParametersCallback
-from .lightning_prediction_datamodule import CAREamicsPredictData
-
 logger = get_logger(__name__)
 
 LOGGER_TYPES = Optional[Union[TensorBoardLogger, WandbLogger]]
@@ -61,9 +68,9 @@ class CAREamist:
         Experiment logger, "wandb" or "tensorboard".
     work_dir : pathlib.Path
         Working directory.
-    train_datamodule : CAREamicsTrainData
+    train_datamodule : TrainDataModule
         Training datamodule.
-    pred_datamodule : CAREamicsPredictData
+    pred_datamodule : PredictDataModule
         Prediction datamodule.
     """
 
@@ -193,8 +200,8 @@ def __init__(
         )
 
         # place holder for the datamodules
-        self.train_datamodule: Optional[CAREamicsTrainData] = None
-        self.pred_datamodule: Optional[CAREamicsPredictData] = None
+        self.train_datamodule: Optional[TrainDataModule] = None
+        self.pred_datamodule: Optional[PredictDataModule] = None
 
     def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
         """
@@ -246,7 +253,7 @@ def _define_callbacks(self, callbacks: Optional[list[Callback]] = None) -> None:
     def train(
         self,
         *,
-        datamodule: Optional[CAREamicsTrainData] = None,
+        datamodule: Optional[TrainDataModule] = None,
         train_source: Optional[Union[Path, str, NDArray]] = None,
         val_source: Optional[Union[Path, str, NDArray]] = None,
         train_target: Optional[Union[Path, str, NDArray]] = None,
Expand All @@ -273,7 +280,7 @@ def train(

Parameters
----------
datamodule : CAREamicsTrainData, optional
datamodule : TrainDataModule, optional
Datamodule to train on, by default None.
train_source : pathlib.Path or str or NDArray, optional
Train source, if no datamodule is provided, by default None.
@@ -375,17 +382,17 @@
 
         else:
             raise ValueError(
-                f"Invalid input, expected a str, Path, array or CAREamicsTrainData "
+                f"Invalid input, expected a str, Path, array or TrainDataModule "
                 f"instance (got {type(train_source)})."
             )
 
-    def _train_on_datamodule(self, datamodule: CAREamicsTrainData) -> None:
+    def _train_on_datamodule(self, datamodule: TrainDataModule) -> None:
         """
         Train the model on the provided datamodule.
 
         Parameters
         ----------
-        datamodule : CAREamicsTrainData
+        datamodule : TrainDataModule
             Datamodule to train on.
         """
         # record datamodule
@@ -421,7 +428,7 @@ def _train_on_array(
             Minimum number of patches to use for validation, by default 5.
         """
         # create datamodule
-        datamodule = CAREamicsTrainData(
+        datamodule = TrainDataModule(
             data_config=self.cfg.data_config,
             train_data=train_data,
             val_data=val_data,
@@ -477,7 +484,7 @@ def _train_on_path(
         path_to_val_target = check_path_exists(path_to_val_target)
 
         # create datamodule
-        datamodule = CAREamicsTrainData(
+        datamodule = TrainDataModule(
             data_config=self.cfg.data_config,
             train_data=path_to_train_data,
             val_data=path_to_val_data,
@@ -493,7 +500,7 @@
 
     @overload
     def predict(  # numpydoc ignore=GL08
-        self, source: CAREamicsPredictData
+        self, source: PredictDataModule
     ) -> Union[list[NDArray], NDArray]: ...
 
     @overload
@@ -528,7 +535,7 @@ def predict(  # numpydoc ignore=GL08
 
     def predict(
         self,
-        source: Union[CAREamicsPredictData, Path, str, NDArray],
+        source: Union[PredictDataModule, Path, str, NDArray],
         *,
         batch_size: Optional[int] = None,
         tile_size: Optional[tuple[int, ...]] = None,
@@ -591,29 +598,62 @@ def predict(
         -------
         list of NDArray or NDArray
             Predictions made by the model.
-        """
-        # Reuse batch size if not provided explicitly
-        if batch_size is None:
-            batch_size = (
-                self.train_datamodule.batch_size
-                if self.train_datamodule
-                else self.cfg.data_config.batch_size
-            )
-
-        self.pred_datamodule = create_pred_datamodule(
-            source=source,
-            config=self.cfg,
-            batch_size=batch_size,
+
+        Raises
+        ------
+        ValueError
+            If mean and std are not provided in the configuration.
+        ValueError
+            If tile size is not divisible by 2**depth for UNet models.
+        ValueError
+            If tile overlap is not specified.
+        """
+        if (
+            self.cfg.data_config.image_means is None
+            or self.cfg.data_config.image_stds is None
+        ):
+            raise ValueError("Mean and std must be provided in the configuration.")
+
+        # tile size for UNets
+        if tile_size is not None:
+            model = self.cfg.algorithm_config.model
+
+            if model.architecture == SupportedArchitecture.UNET.value:
+                # tile size must be equal to k*2^n, where n is the number of pooling
+                # layers (equal to the depth) and k is an integer

[Review conversation on lines +621 to +623]

Member:
    Ah, I realise now the InferenceConfig doesn't have access to the type of
    architecture so it can't be validated in there 😅, I guess it might be nice
    to isolate this check as a separate function somewhere but otherwise looks
    good!

Member (Author):
    Not only the type of architecture, but also the depth of the UNet...

    But you had a good point with the double Inference instantiation, that was
    convoluted. Let's see later if we want to move that to some utils package
    or so.

+                depth = model.depth
+                tile_increment = 2**depth
+
+                for i, t in enumerate(tile_size):
+                    if t % tile_increment != 0:
+                        raise ValueError(
+                            f"Tile size must be divisible by {tile_increment} along "
+                            f"all axes (got {t} for axis {i}). If your image size is "
+                            f"smaller along one axis (e.g. Z), consider padding the "
+                            f"image."
+                        )
+
+            # tile overlaps must be specified
+            if tile_overlap is None:
+                raise ValueError("Tile overlap must be specified.")
+
+        # create the prediction
+        self.pred_datamodule = create_predict_datamodule(
+            pred_data=source,
+            data_type=data_type or self.cfg.data_config.data_type,
+            axes=axes or self.cfg.data_config.axes,
+            image_means=self.cfg.data_config.image_means,
+            image_stds=self.cfg.data_config.image_stds,
             tile_size=tile_size,
             tile_overlap=tile_overlap,
-            axes=axes,
-            data_type=data_type,
+            batch_size=batch_size or self.cfg.data_config.batch_size,
             tta_transforms=tta_transforms,
-            dataloader_params=dataloader_params,
             read_source_func=read_source_func,
             extension_filter=extension_filter,
+            dataloader_params=dataloader_params,
         )
 
         # predict
         predictions = self.trainer.predict(
             model=self.model, datamodule=self.pred_datamodule
         )
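
For reference, a sketch of calling the refactored predict with tiling; the
array shape and tile parameters are illustrative only, and careamist stands
for an already-trained CAREamist instance:

    import numpy as np

    image = np.random.rand(512, 512).astype(np.float32)
    prediction = careamist.predict(
        source=image,
        tile_size=(256, 256),   # each axis divisible by 2**depth for UNets
        tile_overlap=(48, 48),  # required whenever tile_size is given
    )
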
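Following the review discussion above about isolating the UNet tile-size
check, a minimal sketch of what a standalone helper could look like; the
function name, signature, and placement are hypothetical, not part of this PR:

    def validate_unet_tile_size(tile_size: tuple[int, ...], depth: int) -> None:
        """Raise if any tile dimension is not a multiple of 2**depth.

        A UNet with `depth` pooling layers halves each spatial axis `depth`
        times, so tiles must be k * 2**depth (k an integer) to downsample
        cleanly.
        """
        tile_increment = 2**depth
        for i, t in enumerate(tile_size):
            if t % tile_increment != 0:
                raise ValueError(
                    f"Tile size must be divisible by {tile_increment} along "
                    f"all axes (got {t} for axis {i})."
                )
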
src/careamics/config/__init__.py: 0 additions, 3 deletions

@@ -14,17 +14,14 @@
     "create_care_configuration",
     "register_model",
     "CustomModel",
-    "create_inference_configuration",
     "clear_custom_models",
     "ConfigurationInformation",
 ]
 
 from .algorithm_model import AlgorithmConfig
 from .architectures import CustomModel, clear_custom_models, register_model
 from .callback_model import CheckpointModel
 from .configuration_factory import (
     create_care_configuration,
-    create_inference_configuration,
     create_n2n_configuration,
     create_n2v_configuration,
 )
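
With create_inference_configuration removed, prediction set-up goes through
careamics.lightning.create_predict_datamodule, whose parameters mirror the
call in careamist.py above. A sketch with illustrative values; the path, data
type, and normalization statistics below are assumptions, not taken from this
PR:

    from careamics.lightning import create_predict_datamodule

    datamodule = create_predict_datamodule(
        pred_data="data/test_images",  # hypothetical folder of images
        data_type="tiff",              # assumed supported data type
        axes="YX",
        image_means=[0.5],             # per-channel normalization statistics
        image_stds=[0.2],
        tile_size=(256, 256),
        tile_overlap=(48, 48),
    )
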
src/careamics/config/configuration_example.py: 86 deletions

This file was deleted.
